| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- /*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package cn.exlive.video.handler;
- import com.google.common.net.HttpHeaders;
- import org.apache.commons.lang.StringUtils;
- import java.net.InetAddress;
- import java.net.URI;
- import java.net.URISyntaxException;
- import java.nio.ByteBuffer;
- import java.util.HashMap;
- import java.util.Iterator;
- import java.util.List;
- import java.util.Locale;
- import java.util.Map;
- import java.util.Set;
- import lombok.extern.slf4j.Slf4j;
- import org.java_websocket.client.WebSocketClient;
- import org.java_websocket.handshake.ServerHandshake;
- import org.springframework.web.socket.BinaryMessage;
- import org.springframework.web.socket.CloseStatus;
- import org.springframework.web.socket.TextMessage;
- import org.springframework.web.socket.WebSocketSession;
- import org.springframework.web.socket.handler.AbstractWebSocketHandler;
- import com.google.common.base.Joiner;
- import com.google.common.collect.ImmutableSet;
- /**
- *
- * WebSocket 代理的核心类,将 Client 的 WS 请求,转发到后台的 WebSocket 服务器,并把服务器的响应返回给 Client
- *
- * <pre>
- *
- * Created by zhenqin.
- * User: zhenqin
- * Date: 2023/3/17
- * Time: 下午3:31
- *
- * </pre>
- *
- * @author zhenqin
- */
- @Slf4j
- public class ProxyWebSocketHandler extends AbstractWebSocketHandler {
- /**
- * WebSocket Proxy 需要移除的 Header
- */
- final static Set<String> WEBSOCKET_EXCLUDE_HEADER_NAME =
- ImmutableSet.of("sec-websocket-version", "sec-websocket-extensions");
- /**
- * 远端 WebSocket 目标点
- */
- final String endPoint;
- /**
- * 代理远端的 websocket
- */
- MsgWebSocketClient webSocketClient;
- public ProxyWebSocketHandler(String endPoint) {
- this.endPoint = endPoint;
- }
- @Override
- public void afterConnectionEstablished(WebSocketSession session) throws Exception {
- final org.springframework.http.HttpHeaders handshakeHeaders = session.getHandshakeHeaders();
- final Map<String, String> headers = new HashMap<>();
- copyRequestHeaders(handshakeHeaders, headers);
- try {
- addProxyHeaders(handshakeHeaders, headers, session.getRemoteAddress().getHostName());
- } catch (Exception ignore) {
- }
- try {
- // 打开远端 websocket
- this.webSocketClient = new MsgWebSocketClient(endPoint, session, headers);
- this.webSocketClient.connect();
- log.info("连接成功。。。" + endPoint);
- } catch (Exception e) {
- log.error(endPoint + " 连接异常。", e);
- // 远端连接失败,则立即关闭
- // afterConnectionClosed(session, CloseStatus.SERVER_ERROR);
- }
- }
- @Override
- protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
- // 将客户端发送消息,发送到 远端 websocket
- if (!this.webSocketClient.isOpen()) {
- try {
- afterConnectionEstablished(session);
- } catch (Exception ignore) { }
- }
- if(this.webSocketClient.isOpen()) {
- this.webSocketClient.send(message.getPayload());
- }
- }
- @Override
- protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
- // 将客户端发送消息,发送到 远端 websocket
- if (!this.webSocketClient.isOpen()) {
- try {
- afterConnectionEstablished(session);
- } catch (Exception ignore) { }
- }
- if(this.webSocketClient.isOpen()) {
- this.webSocketClient.send(message.getPayload());
- }
- }
- @Override
- public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
- log.info("断开连接。。。" + endPoint);
- // 关闭远端 websocket
- if (this.webSocketClient != null) {
- this.webSocketClient.close();
- }
- }
- /**
- * 请求的 Header
- * @param httpHeaders
- * @param requestHeader
- */
- protected void copyRequestHeaders(org.springframework.http.HttpHeaders httpHeaders, Map<String, String> requestHeader) {
- for (Map.Entry<String, List<String>> entry : httpHeaders.entrySet()) {
- String headerName = entry.getKey();
- String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
- // Remove hop-by-hop headers.
- if (WEBSOCKET_EXCLUDE_HEADER_NAME.contains(lowerHeaderName)) {
- continue;
- }
- final List<String> value = entry.getValue();
- if (value != null) {
- requestHeader.put(headerName, Joiner.on(", ").join(value));
- }
- }
- }
- /**
- * 代理的相关配置
- * @param httpHeaders
- * @param requestHeader
- */
- protected void addProxyHeaders(org.springframework.http.HttpHeaders httpHeaders, Map<String, String> requestHeader, String remoteHostName) {
- try {
- requestHeader.put(HttpHeaders.VIA, "http/1.1 " + InetAddress.getLocalHost().getHostName());
- requestHeader.put(HttpHeaders.X_FORWARDED_HOST, InetAddress.getLocalHost().getHostName());
- } catch (Exception ignore) {
- }
- String xForwardFor = getHeader(httpHeaders, "X-Forwarded-For");
- if (StringUtils.isBlank(xForwardFor)) {
- // xForwardFor, 第一层代理
- requestHeader.put(HttpHeaders.X_FORWARDED_FOR, remoteHostName);
- } else {
- // xForwardFor,多层代理,将外层 IP全部 copy
- requestHeader.put(HttpHeaders.X_FORWARDED_FOR, xForwardFor + ", " + remoteHostName);
- }
- requestHeader.put(HttpHeaders.X_FORWARDED_HOST, getHeader(httpHeaders, HttpHeaders.HOST));
- }
- /**
- * 返回 Header
- * @param httpHeaders
- * @param headerName
- * @return
- */
- protected String getHeader(org.springframework.http.HttpHeaders httpHeaders, String headerName) {
- final List<String> valuesAsList = httpHeaders.getValuesAsList(headerName);
- return valuesAsList.size() > 0 ? Joiner.on(", ").join(valuesAsList) : "";
- }
- static class MsgWebSocketClient extends WebSocketClient {
- /**
- * client ref
- */
- final WebSocketSession session;
- /**
- * 发起请求的 Header
- */
- final Map<String, String> httpHeaders;
- /**
- * 远端服务器返回的 Header
- */
- final Map<String, String> responseHeaders = new HashMap<>();
- public MsgWebSocketClient(String url, WebSocketSession session,
- Map<String, String> httpHeaders) throws URISyntaxException {
- super(new URI(url), httpHeaders); // 以 client 的 Header 访问 remote,否则部分有认证的,无法通过认证
- this.httpHeaders = httpHeaders;
- log.info("======= WebSocket Request Headers =======");
- for (Map.Entry<String, String> entry : httpHeaders.entrySet()) {
- log.info(entry.getKey() + ": " + entry.getValue());
- }
- log.info("========================================");
- this.setConnectionLostTimeout(30000);
- this.session = session;
- }
- @Override
- public void onOpen(ServerHandshake shake) {
- log.info("远端 {} 握手成功...", getURI());
- log.info("====== WebSocket Response Headers ======");
- for (Iterator<String> it = shake.iterateHttpFields(); it.hasNext();) {
- String key = it.next();
- responseHeaders.put(key, shake.getFieldValue(key));
- log.info(key + ": " + shake.getFieldValue(key));
- }
- log.info("========================================");
- }
- @Override
- public void onMessage(String paramString) {
- log.info("receive message: {} remote: {}", paramString, getURI());
- // String result = "【websocket消息】【" + DateTime.now().toString("yyyy-MM-dd HH:mm:ss") + "】收到客户端消息: " +
- // paramString;
- try {
- session.sendMessage(new TextMessage(paramString));
- } catch (Exception e) {
- log.error("WS发送消息异常。", e);
- }
- }
- @Override
- public void onMessage(ByteBuffer bytes) {
- log.info("receive binary message length: {} remote: {}", bytes.position(), getURI());
- try {
- session.sendMessage(new BinaryMessage(bytes));
- } catch (Exception e) {
- log.error("WS发送消息异常。", e);
- }
- }
- @Override
- public void onClose(int paramInt, String paramString, boolean paramBoolean) {
- log.info("close remote, reason: {} .", paramString);
- if (session != null) {
- try {
- session.close(CloseStatus.SESSION_NOT_RELIABLE);
- } catch (Exception e) {
- }
- }
- }
- @Override
- public void onError(Exception e) {
- log.error("WS:" + getURI() + " 异常。", e);
- }
- }
- }
|