/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cn.exlive.video.handler;
import com.google.common.net.HttpHeaders;
import org.apache.commons.lang.StringUtils;
import java.net.InetAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableSet;
/**
*
* WebSocket 代理的核心类,将 Client 的 WS 请求,转发到后台的 WebSocket 服务器,并把服务器的响应返回给 Client
*
*
*
* Created by zhenqin.
* User: zhenqin
* Date: 2023/3/17
* Time: 下午3:31
*
*
*
* @author zhenqin
*/
@Slf4j
public class ProxyWebSocketHandler extends AbstractWebSocketHandler {
/**
* WebSocket Proxy 需要移除的 Header
*/
final static Set WEBSOCKET_EXCLUDE_HEADER_NAME =
ImmutableSet.of("sec-websocket-version", "sec-websocket-extensions");
/**
* 远端 WebSocket 目标点
*/
final String endPoint;
/**
* 代理远端的 websocket
*/
MsgWebSocketClient webSocketClient;
public ProxyWebSocketHandler(String endPoint) {
this.endPoint = endPoint;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
final org.springframework.http.HttpHeaders handshakeHeaders = session.getHandshakeHeaders();
final Map headers = new HashMap<>();
copyRequestHeaders(handshakeHeaders, headers);
try {
addProxyHeaders(handshakeHeaders, headers, session.getRemoteAddress().getHostName());
} catch (Exception ignore) {
}
try {
// 打开远端 websocket
this.webSocketClient = new MsgWebSocketClient(endPoint, session, headers);
this.webSocketClient.connect();
log.info("连接成功。。。" + endPoint);
} catch (Exception e) {
log.error(endPoint + " 连接异常。", e);
// 远端连接失败,则立即关闭
// afterConnectionClosed(session, CloseStatus.SERVER_ERROR);
}
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
// 将客户端发送消息,发送到 远端 websocket
if (!this.webSocketClient.isOpen()) {
try {
afterConnectionEstablished(session);
} catch (Exception ignore) { }
}
if(this.webSocketClient.isOpen()) {
this.webSocketClient.send(message.getPayload());
}
}
@Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
// 将客户端发送消息,发送到 远端 websocket
if (!this.webSocketClient.isOpen()) {
try {
afterConnectionEstablished(session);
} catch (Exception ignore) { }
}
if(this.webSocketClient.isOpen()) {
this.webSocketClient.send(message.getPayload());
}
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
log.info("断开连接。。。" + endPoint);
// 关闭远端 websocket
if (this.webSocketClient != null) {
this.webSocketClient.close();
}
}
/**
* 请求的 Header
* @param httpHeaders
* @param requestHeader
*/
protected void copyRequestHeaders(org.springframework.http.HttpHeaders httpHeaders, Map requestHeader) {
for (Map.Entry> entry : httpHeaders.entrySet()) {
String headerName = entry.getKey();
String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
// Remove hop-by-hop headers.
if (WEBSOCKET_EXCLUDE_HEADER_NAME.contains(lowerHeaderName)) {
continue;
}
final List value = entry.getValue();
if (value != null) {
requestHeader.put(headerName, Joiner.on(", ").join(value));
}
}
}
/**
* 代理的相关配置
* @param httpHeaders
* @param requestHeader
*/
protected void addProxyHeaders(org.springframework.http.HttpHeaders httpHeaders, Map requestHeader, String remoteHostName) {
try {
requestHeader.put(HttpHeaders.VIA, "http/1.1 " + InetAddress.getLocalHost().getHostName());
requestHeader.put(HttpHeaders.X_FORWARDED_HOST, InetAddress.getLocalHost().getHostName());
} catch (Exception ignore) {
}
String xForwardFor = getHeader(httpHeaders, "X-Forwarded-For");
if (StringUtils.isBlank(xForwardFor)) {
// xForwardFor, 第一层代理
requestHeader.put(HttpHeaders.X_FORWARDED_FOR, remoteHostName);
} else {
// xForwardFor,多层代理,将外层 IP全部 copy
requestHeader.put(HttpHeaders.X_FORWARDED_FOR, xForwardFor + ", " + remoteHostName);
}
requestHeader.put(HttpHeaders.X_FORWARDED_HOST, getHeader(httpHeaders, HttpHeaders.HOST));
}
/**
* 返回 Header
* @param httpHeaders
* @param headerName
* @return
*/
protected String getHeader(org.springframework.http.HttpHeaders httpHeaders, String headerName) {
final List valuesAsList = httpHeaders.getValuesAsList(headerName);
return valuesAsList.size() > 0 ? Joiner.on(", ").join(valuesAsList) : "";
}
static class MsgWebSocketClient extends WebSocketClient {
/**
* client ref
*/
final WebSocketSession session;
/**
* 发起请求的 Header
*/
final Map httpHeaders;
/**
* 远端服务器返回的 Header
*/
final Map responseHeaders = new HashMap<>();
public MsgWebSocketClient(String url, WebSocketSession session,
Map httpHeaders) throws URISyntaxException {
super(new URI(url), httpHeaders); // 以 client 的 Header 访问 remote,否则部分有认证的,无法通过认证
this.httpHeaders = httpHeaders;
log.info("======= WebSocket Request Headers =======");
for (Map.Entry entry : httpHeaders.entrySet()) {
log.info(entry.getKey() + ": " + entry.getValue());
}
log.info("========================================");
this.setConnectionLostTimeout(30000);
this.session = session;
}
@Override
public void onOpen(ServerHandshake shake) {
log.info("远端 {} 握手成功...", getURI());
log.info("====== WebSocket Response Headers ======");
for (Iterator it = shake.iterateHttpFields(); it.hasNext();) {
String key = it.next();
responseHeaders.put(key, shake.getFieldValue(key));
log.info(key + ": " + shake.getFieldValue(key));
}
log.info("========================================");
}
@Override
public void onMessage(String paramString) {
log.info("receive message: {} remote: {}", paramString, getURI());
// String result = "【websocket消息】【" + DateTime.now().toString("yyyy-MM-dd HH:mm:ss") + "】收到客户端消息: " +
// paramString;
try {
session.sendMessage(new TextMessage(paramString));
} catch (Exception e) {
log.error("WS发送消息异常。", e);
}
}
@Override
public void onMessage(ByteBuffer bytes) {
log.info("receive binary message length: {} remote: {}", bytes.position(), getURI());
try {
session.sendMessage(new BinaryMessage(bytes));
} catch (Exception e) {
log.error("WS发送消息异常。", e);
}
}
@Override
public void onClose(int paramInt, String paramString, boolean paramBoolean) {
log.info("close remote, reason: {} .", paramString);
if (session != null) {
try {
session.close(CloseStatus.SESSION_NOT_RELIABLE);
} catch (Exception e) {
}
}
}
@Override
public void onError(Exception e) {
log.error("WS:" + getURI() + " 异常。", e);
}
}
}