/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package cn.exlive.video.handler; import com.google.common.net.HttpHeaders; import org.apache.commons.lang.StringUtils; import java.net.InetAddress; import java.net.URI; import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Set; import lombok.extern.slf4j.Slf4j; import org.java_websocket.client.WebSocketClient; import org.java_websocket.handshake.ServerHandshake; import org.springframework.web.socket.BinaryMessage; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.AbstractWebSocketHandler; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableSet; /** * * WebSocket 代理的核心类,将 Client 的 WS 请求,转发到后台的 WebSocket 服务器,并把服务器的响应返回给 Client * *
 *
 * Created by zhenqin.
 * User: zhenqin
 * Date: 2023/3/17
 * Time: 下午3:31
 *
 * 
*
 * @author zhenqin
 */
@Slf4j
public class ProxyWebSocketHandler extends AbstractWebSocketHandler {

    /**
     * Lower-case names of client handshake headers that are NOT copied to the remote handshake.
     * NOTE(review): presumably excluded so the proxy's own WebSocket client negotiates the protocol
     * version and extensions itself instead of echoing what the browser offered — confirm.
     */
    static final Set<String> WEBSOCKET_EXCLUDE_HEADER_NAME =
            ImmutableSet.of("sec-websocket-version", "sec-websocket-extensions");

    /** Separator used when a multi-valued header is collapsed into a single value. */
    private static final Joiner HEADER_VALUE_JOINER = Joiner.on(", ");

    /** URL of the remote WebSocket endpoint that client sessions are proxied to (parsed by {@link URI}). */
    final String endPoint;

    /**
     * Upstream connection of each client session, keyed by {@link WebSocketSession#getId()}.
     * A registered handler instance can serve many sessions at once, so keeping the upstream in a
     * single field would let concurrent sessions overwrite (and close) each other's connection.
     */
    private final Map<String, MsgWebSocketClient> remoteClients =
            new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * @param endPoint URL of the remote WebSocket endpoint to proxy to
     */
    public ProxyWebSocketHandler(String endPoint) {
        this.endPoint = endPoint;
    }

    /**
     * Opens the upstream connection for a client session, forwarding the client's handshake
     * headers (minus {@link #WEBSOCKET_EXCLUDE_HEADER_NAME}) plus proxy headers. Failures are
     * logged; the client session is intentionally left open so a later message can retry.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        final org.springframework.http.HttpHeaders handshakeHeaders = session.getHandshakeHeaders();
        final Map<String, String> headers = new HashMap<>();
        copyRequestHeaders(handshakeHeaders, headers);

        final java.net.InetSocketAddress remoteAddress = session.getRemoteAddress();
        if (remoteAddress != null) {
            try {
                // getHostString() avoids the reverse DNS lookup that getHostName() may perform.
                addProxyHeaders(handshakeHeaders, headers, remoteAddress.getHostString());
            } catch (Exception e) {
                log.warn("Failed to add proxy headers for " + endPoint, e);
            }
        }

        try {
            // Open the remote websocket. connect() does not block: the handshake runs on the
            // client's own thread and its completion is reported by MsgWebSocketClient#onOpen.
            final MsgWebSocketClient client = new MsgWebSocketClient(endPoint, session, headers);
            client.connect();
            // Register only after connect() returned, so a failed attempt is retried later.
            remoteClients.put(session.getId(), client);
            log.info("连接成功。。。" + endPoint);
        } catch (Exception e) {
            log.error(endPoint + " 连接异常。", e);
        }
    }

    /**
     * Forwards a text message from the client to the remote endpoint. The message is dropped
     * (with a warning) when the upstream connection is not open yet.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        final MsgWebSocketClient client = remoteFor(session);
        if (client != null && client.isOpen()) {
            client.send(message.getPayload());
        } else {
            log.warn("Remote {} is not open, dropping text message from session {}", endPoint, session.getId());
        }
    }

    /**
     * Forwards a binary message from the client to the remote endpoint. The message is dropped
     * (with a warning) when the upstream connection is not open yet.
     */
    @Override
    protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
        final MsgWebSocketClient client = remoteFor(session);
        if (client != null && client.isOpen()) {
            client.send(message.getPayload());
        } else {
            log.warn("Remote {} is not open, dropping binary message from session {}", endPoint, session.getId());
        }
    }

    /**
     * Returns the upstream client of {@code session}. It first re-runs the connect when the
     * session has no client (an earlier attempt failed) or its client has closed. A client that is
     * still connecting is returned as-is rather than replaced, so no duplicate upstream connection
     * is opened.
     *
     * @return the session's client, or {@code null} when none could be created
     */
    private MsgWebSocketClient remoteFor(WebSocketSession session) {
        MsgWebSocketClient client = remoteClients.get(session.getId());
        if (client == null || client.isClosed()) {
            try {
                afterConnectionEstablished(session);
            } catch (Exception e) {
                log.warn("Reconnect to " + endPoint + " failed", e);
            }
            client = remoteClients.get(session.getId());
        }
        return client;
    }

    /** Closes and forgets the upstream connection of a closed client session. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        log.info("断开连接。。。" + endPoint);
        final MsgWebSocketClient client = remoteClients.remove(session.getId());
        if (client != null) {
            client.close();
        }
    }

    /**
     * Copies the client handshake headers into {@code requestHeader}, skipping the names in
     * {@link #WEBSOCKET_EXCLUDE_HEADER_NAME} and joining multiple values with ", ".
     *
     * @param httpHeaders   handshake headers received from the client
     * @param requestHeader destination map used for the remote handshake
     */
    protected void copyRequestHeaders(org.springframework.http.HttpHeaders httpHeaders,
                                      Map<String, String> requestHeader) {
        for (Map.Entry<String, List<String>> entry : httpHeaders.entrySet()) {
            final String headerName = entry.getKey();
            if (WEBSOCKET_EXCLUDE_HEADER_NAME.contains(headerName.toLowerCase(Locale.ENGLISH))) {
                continue;
            }
            final List<String> values = entry.getValue();
            if (values != null) {
                requestHeader.put(headerName, HEADER_VALUE_JOINER.join(values));
            }
        }
    }

    /**
     * Adds proxy headers: {@code Via} (this host), {@code X-Forwarded-For} (the existing chain,
     * if any, plus {@code remoteHostName}) and {@code X-Forwarded-Host} (the client's Host
     * header, if present).
     *
     * @param httpHeaders    handshake headers received from the client
     * @param requestHeader  destination map used for the remote handshake
     * @param remoteHostName address of the connecting client, appended to X-Forwarded-For
     */
    protected void addProxyHeaders(org.springframework.http.HttpHeaders httpHeaders,
                                   Map<String, String> requestHeader,
                                   String remoteHostName) {
        try {
            requestHeader.put(HttpHeaders.VIA, "http/1.1 " + InetAddress.getLocalHost().getHostName());
        } catch (Exception e) {
            log.debug("Cannot resolve local host name for the Via header", e);
        }

        final String forwardedFor = getHeader(httpHeaders, "X-Forwarded-For");
        if (StringUtils.isBlank(forwardedFor)) {
            // First proxy hop.
            requestHeader.put(HttpHeaders.X_FORWARDED_FOR, remoteHostName);
        } else {
            // Behind other proxies: keep the upstream chain and append the direct client.
            requestHeader.put(HttpHeaders.X_FORWARDED_FOR, forwardedFor + ", " + remoteHostName);
        }

        final String originalHost = getHeader(httpHeaders, HttpHeaders.HOST);
        if (!StringUtils.isBlank(originalHost)) {
            requestHeader.put(HttpHeaders.X_FORWARDED_HOST, originalHost);
        }
    }

    /**
     * Returns all values of {@code headerName} joined with ", ", or "" when the header is absent.
     *
     * @param httpHeaders headers to read
     * @param headerName  header to look up
     * @return the joined value, never {@code null}
     */
    protected String getHeader(org.springframework.http.HttpHeaders httpHeaders, String headerName) {
        final List<String> values = httpHeaders.getValuesAsList(headerName);
        return values.isEmpty() ? "" : HEADER_VALUE_JOINER.join(values);
    }

    /**
     * Client side of the proxy. It connects to the remote endpoint and relays every remote message
     * back to the client {@link WebSocketSession}. When the remote side closes, the client session
     * is closed too.
     */
    static class MsgWebSocketClient extends WebSocketClient {

        /**
         * Interval of the lost-connection check. Java-WebSocket's setConnectionLostTimeout takes
         * SECONDS (not milliseconds).
         */
        static final int CONNECTION_LOST_TIMEOUT_SECONDS = 30;

        /** Lower-case header names whose values are masked when request headers are logged. */
        static final Set<String> SENSITIVE_HEADER_NAMES =
                ImmutableSet.of("authorization", "proxy-authorization", "cookie");

        /** Client session that remote messages are relayed to. */
        final WebSocketSession session;

        /** Headers sent with the remote handshake. */
        final Map<String, String> httpHeaders;

        /** Headers of the remote server's handshake response (filled in {@link #onOpen}). */
        final Map<String, String> responseHeaders = new HashMap<>();

        /**
         * @param url         remote endpoint URL
         * @param session     client session to relay remote messages to
         * @param httpHeaders headers for the remote handshake
         * @throws URISyntaxException if {@code url} is not a valid URI
         */
        public MsgWebSocketClient(String url, WebSocketSession session, Map<String, String> httpHeaders)
                throws URISyntaxException {
            // Reuse the client's headers for the remote handshake; otherwise remote endpoints that
            // authenticate the request would reject it.
            super(new URI(url), httpHeaders);
            this.httpHeaders = httpHeaders;
            this.session = session;
            log.info("======= WebSocket Request Headers =======");
            if (httpHeaders != null) {
                for (Map.Entry<String, String> entry : httpHeaders.entrySet()) {
                    log.info(entry.getKey() + ": " + maskIfSensitive(entry.getKey(), entry.getValue()));
                }
            }
            log.info("========================================");
            this.setConnectionLostTimeout(CONNECTION_LOST_TIMEOUT_SECONDS);
        }

        /** Returns {@code value}, or a fixed mask when {@code name} is a credential-carrying header. */
        private static String maskIfSensitive(String name, String value) {
            if (name != null && SENSITIVE_HEADER_NAMES.contains(name.toLowerCase(Locale.ENGLISH))) {
                return "******";
            }
            return value;
        }

        /** Records and logs the remote handshake response headers. */
        @Override
        public void onOpen(ServerHandshake shake) {
            log.info("远端 {} 握手成功...", getURI());
            log.info("====== WebSocket Response Headers ======");
            for (Iterator<String> it = shake.iterateHttpFields(); it.hasNext(); ) {
                final String key = it.next();
                final String value = shake.getFieldValue(key);
                responseHeaders.put(key, value);
                log.info(key + ": " + value);
            }
            log.info("========================================");
        }

        /** Relays a remote text message to the client session. */
        @Override
        public void onMessage(String paramString) {
            log.info("receive message: {} remote: {}", paramString, getURI());
            try {
                session.sendMessage(new TextMessage(paramString));
            } catch (Exception e) {
                log.error("WS发送消息异常。", e);
            }
        }

        /** Relays a remote binary message to the client session. */
        @Override
        public void onMessage(ByteBuffer bytes) {
            // remaining() is the payload size; position() is usually 0 for a freshly received frame.
            log.info("receive binary message length: {} remote: {}", bytes.remaining(), getURI());
            try {
                session.sendMessage(new BinaryMessage(bytes));
            } catch (Exception e) {
                log.error("WS发送消息异常。", e);
            }
        }

        /** Closes the client session (if still open) once the remote connection is closed. */
        @Override
        public void onClose(int code, String reason, boolean remote) {
            log.info("close remote, reason: {} .", reason);
            if (session != null && session.isOpen()) {
                try {
                    session.close(CloseStatus.SESSION_NOT_RELIABLE);
                } catch (Exception e) {
                    log.debug("Failed to close client session {}", session.getId(), e);
                }
            }
        }

        /** Logs errors of the remote connection. */
        @Override
        public void onError(Exception e) {
            log.error("WS:" + getURI() + " 异常。", e);
        }
    }
}