WebSocketConfig.java 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. package cn.exlive.video.config;
  2. import java.util.Map;
  3. import java.util.concurrent.ConcurrentHashMap;
  4. import cn.exlive.video.handler.ProxyWebSocketHandler;
  5. import lombok.Getter;
  6. import lombok.Setter;
  7. import lombok.extern.slf4j.Slf4j;
  8. import org.springframework.context.annotation.Bean;
  9. import org.springframework.context.annotation.Configuration;
  10. import org.springframework.context.annotation.DependsOn;
  11. import org.springframework.web.socket.config.annotation.EnableWebSocket;
  12. import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
  13. import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
  14. import org.springframework.web.socket.server.standard.ServerEndpointExporter;
  15. /**
  16. * 开启WebSocket支持,支持 WebSocket 路由代理配置类
  17. *
  18. * @author cxf
  19. */
  20. @Slf4j
  21. @Setter
  22. @Getter
  23. @EnableWebSocket
  24. @Configuration
  25. public class WebSocketConfig {
  26. /**
  27. * 所有 websocket 的特殊处理
  28. */
  29. final static Map<String, ApiRoute> WEBSOCKET_ACTION_MAPPING = new ConcurrentHashMap<>();
  30. /**
  31. * 添加 WebSocket 地址
  32. * @param path
  33. * @param route
  34. */
  35. public static void addWebSocketRouter(String path, ApiRoute route) {
  36. WEBSOCKET_ACTION_MAPPING.put(path, route);
  37. }
  38. /**
  39. * 注入ServerEndpointExporter,
  40. * 这个bean会自动注册使用了@ServerEndpoint注解声明的Websocket endpoint
  41. */
  42. // @Bean
  43. public ServerEndpointExporter serverEndpointExporter() {
  44. return new ServerEndpointExporter();
  45. }
  46. @Bean
  47. @DependsOn("simpleRouteLocator")
  48. public WebSocketConfigurer webSocketConfigurer() {
  49. return (WebSocketHandlerRegistry registry) -> {
  50. for (Map.Entry<String, ApiRoute> entry : WEBSOCKET_ACTION_MAPPING.entrySet()) {
  51. registry.addHandler(new ProxyWebSocketHandler(entry.getValue().getUrl()), entry.getValue().getApi()) // 设置连接路径和处理
  52. .setAllowedOrigins("*");
  53. }
  54. };
  55. }
  56. }