ProxyWebSocketHandler.java 9.8 KB


  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package cn.exlive.video.handler;
  import java.net.InetAddress;
  import java.net.InetSocketAddress;
  import java.net.URI;
  import java.net.URISyntaxException;
  import java.nio.ByteBuffer;
  import java.util.HashMap;
  import java.util.Iterator;
  import java.util.List;
  import java.util.Locale;
  import java.util.Map;
  import java.util.Set;
  import java.util.TreeMap;
  import java.util.concurrent.ConcurrentHashMap;
  import java.util.concurrent.TimeUnit;

  import com.google.common.base.Joiner;
  import com.google.common.collect.ImmutableSet;
  import com.google.common.net.HttpHeaders;
  import lombok.extern.slf4j.Slf4j;
  import org.apache.commons.lang.StringUtils;
  import org.java_websocket.client.WebSocketClient;
  import org.java_websocket.handshake.ServerHandshake;
  import org.springframework.web.socket.BinaryMessage;
  import org.springframework.web.socket.CloseStatus;
  import org.springframework.web.socket.TextMessage;
  import org.springframework.web.socket.WebSocketSession;
  import org.springframework.web.socket.handler.AbstractWebSocketHandler;
  40. /**
  41. *
  42. * WebSocket 代理的核心类,将 Client 的 WS 请求,转发到后台的 WebSocket 服务器,并把服务器的响应返回给 Client
  43. *
  44. * <pre>
  45. *
  46. * Created by zhenqin.
  47. * User: zhenqin
  48. * Date: 2023/3/17
  49. * Time: 下午3:31
  50. *
  51. * </pre>
  52. *
  53. * @author zhenqin
  54. */
  55. @Slf4j
  56. public class ProxyWebSocketHandler extends AbstractWebSocketHandler {
  57. /**
  58. * WebSocket Proxy 需要移除的 Header
  59. */
  60. final static Set<String> WEBSOCKET_EXCLUDE_HEADER_NAME =
  61. ImmutableSet.of("sec-websocket-version", "sec-websocket-extensions");
  62. /**
  63. * 远端 WebSocket 目标点
  64. */
  65. final String endPoint;
  66. /**
  67. * 代理远端的 websocket
  68. */
  69. MsgWebSocketClient webSocketClient;
  70. public ProxyWebSocketHandler(String endPoint) {
  71. this.endPoint = endPoint;
  72. }
  73. @Override
  74. public void afterConnectionEstablished(WebSocketSession session) throws Exception {
  75. final org.springframework.http.HttpHeaders handshakeHeaders = session.getHandshakeHeaders();
  76. final Map<String, String> headers = new HashMap<>();
  77. copyRequestHeaders(handshakeHeaders, headers);
  78. try {
  79. addProxyHeaders(handshakeHeaders, headers, session.getRemoteAddress().getHostName());
  80. } catch (Exception ignore) {
  81. }
  82. try {
  83. // 打开远端 websocket
  84. this.webSocketClient = new MsgWebSocketClient(endPoint, session, headers);
  85. this.webSocketClient.connect();
  86. log.info("连接成功。。。" + endPoint);
  87. } catch (Exception e) {
  88. log.error(endPoint + " 连接异常。", e);
  89. // 远端连接失败,则立即关闭
  90. // afterConnectionClosed(session, CloseStatus.SERVER_ERROR);
  91. }
  92. }
  93. @Override
  94. protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
  95. // 将客户端发送消息,发送到 远端 websocket
  96. if (!this.webSocketClient.isOpen()) {
  97. try {
  98. afterConnectionEstablished(session);
  99. } catch (Exception ignore) { }
  100. }
  101. if(this.webSocketClient.isOpen()) {
  102. this.webSocketClient.send(message.getPayload());
  103. }
  104. }
  105. @Override
  106. protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
  107. // 将客户端发送消息,发送到 远端 websocket
  108. if (!this.webSocketClient.isOpen()) {
  109. try {
  110. afterConnectionEstablished(session);
  111. } catch (Exception ignore) { }
  112. }
  113. if(this.webSocketClient.isOpen()) {
  114. this.webSocketClient.send(message.getPayload());
  115. }
  116. }
  117. @Override
  118. public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
  119. log.info("断开连接。。。" + endPoint);
  120. // 关闭远端 websocket
  121. if (this.webSocketClient != null) {
  122. this.webSocketClient.close();
  123. }
  124. }
  125. /**
  126. * 请求的 Header
  127. * @param httpHeaders
  128. * @param requestHeader
  129. */
  130. protected void copyRequestHeaders(org.springframework.http.HttpHeaders httpHeaders, Map<String, String> requestHeader) {
  131. for (Map.Entry<String, List<String>> entry : httpHeaders.entrySet()) {
  132. String headerName = entry.getKey();
  133. String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
  134. // Remove hop-by-hop headers.
  135. if (WEBSOCKET_EXCLUDE_HEADER_NAME.contains(lowerHeaderName)) {
  136. continue;
  137. }
  138. final List<String> value = entry.getValue();
  139. if (value != null) {
  140. requestHeader.put(headerName, Joiner.on(", ").join(value));
  141. }
  142. }
  143. }
  144. /**
  145. * 代理的相关配置
  146. * @param httpHeaders
  147. * @param requestHeader
  148. */
  149. protected void addProxyHeaders(org.springframework.http.HttpHeaders httpHeaders, Map<String, String> requestHeader, String remoteHostName) {
  150. try {
  151. requestHeader.put(HttpHeaders.VIA, "http/1.1 " + InetAddress.getLocalHost().getHostName());
  152. requestHeader.put(HttpHeaders.X_FORWARDED_HOST, InetAddress.getLocalHost().getHostName());
  153. } catch (Exception ignore) {
  154. }
  155. String xForwardFor = getHeader(httpHeaders, "X-Forwarded-For");
  156. if (StringUtils.isBlank(xForwardFor)) {
  157. // xForwardFor, 第一层代理
  158. requestHeader.put(HttpHeaders.X_FORWARDED_FOR, remoteHostName);
  159. } else {
  160. // xForwardFor,多层代理,将外层 IP全部 copy
  161. requestHeader.put(HttpHeaders.X_FORWARDED_FOR, xForwardFor + ", " + remoteHostName);
  162. }
  163. requestHeader.put(HttpHeaders.X_FORWARDED_HOST, getHeader(httpHeaders, HttpHeaders.HOST));
  164. }
  165. /**
  166. * 返回 Header
  167. * @param httpHeaders
  168. * @param headerName
  169. * @return
  170. */
  171. protected String getHeader(org.springframework.http.HttpHeaders httpHeaders, String headerName) {
  172. final List<String> valuesAsList = httpHeaders.getValuesAsList(headerName);
  173. return valuesAsList.size() > 0 ? Joiner.on(", ").join(valuesAsList) : "";
  174. }
  175. static class MsgWebSocketClient extends WebSocketClient {
  176. /**
  177. * client ref
  178. */
  179. final WebSocketSession session;
  180. /**
  181. * 发起请求的 Header
  182. */
  183. final Map<String, String> httpHeaders;
  184. /**
  185. * 远端服务器返回的 Header
  186. */
  187. final Map<String, String> responseHeaders = new HashMap<>();
  188. public MsgWebSocketClient(String url, WebSocketSession session,
  189. Map<String, String> httpHeaders) throws URISyntaxException {
  190. super(new URI(url), httpHeaders); // 以 client 的 Header 访问 remote,否则部分有认证的,无法通过认证
  191. this.httpHeaders = httpHeaders;
  192. log.info("======= WebSocket Request Headers =======");
  193. for (Map.Entry<String, String> entry : httpHeaders.entrySet()) {
  194. log.info(entry.getKey() + ": " + entry.getValue());
  195. }
  196. log.info("========================================");
  197. this.setConnectionLostTimeout(30000);
  198. this.session = session;
  199. }
  200. @Override
  201. public void onOpen(ServerHandshake shake) {
  202. log.info("远端 {} 握手成功...", getURI());
  203. log.info("====== WebSocket Response Headers ======");
  204. for (Iterator<String> it = shake.iterateHttpFields(); it.hasNext();) {
  205. String key = it.next();
  206. responseHeaders.put(key, shake.getFieldValue(key));
  207. log.info(key + ": " + shake.getFieldValue(key));
  208. }
  209. log.info("========================================");
  210. }
  211. @Override
  212. public void onMessage(String paramString) {
  213. log.info("receive message: {} remote: {}", paramString, getURI());
  214. // String result = "【websocket消息】【" + DateTime.now().toString("yyyy-MM-dd HH:mm:ss") + "】收到客户端消息: " +
  215. // paramString;
  216. try {
  217. session.sendMessage(new TextMessage(paramString));
  218. } catch (Exception e) {
  219. log.error("WS发送消息异常。", e);
  220. }
  221. }
  222. @Override
  223. public void onMessage(ByteBuffer bytes) {
  224. log.info("receive binary message length: {} remote: {}", bytes.position(), getURI());
  225. try {
  226. session.sendMessage(new BinaryMessage(bytes));
  227. } catch (Exception e) {
  228. log.error("WS发送消息异常。", e);
  229. }
  230. }
  231. @Override
  232. public void onClose(int paramInt, String paramString, boolean paramBoolean) {
  233. log.info("close remote, reason: {} .", paramString);
  234. if (session != null) {
  235. try {
  236. session.close(CloseStatus.SESSION_NOT_RELIABLE);
  237. } catch (Exception e) {
  238. }
  239. }
  240. }
  241. @Override
  242. public void onError(Exception e) {
  243. log.error("WS:" + getURI() + " 异常。", e);
  244. }
  245. }
  246. }