import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.socket.server.standard.ServerEndpointExporter; /** * websocket 配置 * * @author ruoyi */ @Configuration public class WebSocketConfig { @Bean public ServerEndpointExporter serverEndpointExporter() { return new ServerEndpointExporter(); } }
工具类 SemaphoreUtils
import java.util.concurrent.Semaphore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * 信号量相关处理 * * @author ruoyi */ public class SemaphoreUtils{ /** * SemaphoreUtils 日志控制器 */ private static final Logger LOGGER = LoggerFactory.getLogger(SemaphoreUtils.class); /** * 获取信号量 * * @param semaphore * @return */ public static boolean tryAcquire(Semaphore semaphore) { boolean flag = false; try { flag = semaphore.tryAcquire(); } catch (Exception e) { LOGGER.error("获取信号量异常", e); } return flag; } /** * 释放信号量 * * @param semaphore */ public static void release(Semaphore semaphore) { try { semaphore.release(); } catch (Exception e) { LOGGER.error("释放信号量异常", e); } } }服务端类WebSocketServer
import java.util.concurrent.Semaphore; import javax.websocket.OnClose; import javax.websocket.OnError; import javax.websocket.OnMessage; import javax.websocket.OnOpen; import javax.websocket.Session; import javax.websocket.server.ServerEndpoint; import com.lxh.demo.util.SemaphoreUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Component; /** * websocket 消息处理 * * @author ruoyi */ @Component @ServerEndpoint("/websocket/message") public class WebSocketServer { /** * WebSocketServer 日志控制器 */ private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class); /** * 默认最多允许同时在线人数100 */ public static int socketMaxOnlineCount = 100; private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount); /** * 连接建立成功调用的方法 */ @OnOpen public void onOpen(Session session) throws Exception{ boolean semaphoreFlag = false; // 尝试获取信号量 semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore); if (!semaphoreFlag) { // 未获取到信号量 LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount); WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount); session.close(); } else { // 添加用户 WebSocketUsers.put(session.getId(), session); LOGGER.info("\n 建立连接 - {}", session); LOGGER.info("\n 当前人数 - {}", WebSocketUsers.getUsers().size()); WebSocketUsers.sendMessageToUserByText(session, "连接成功"); } } /** * 连接关闭时处理 */ @OnClose public void onClose(Session session) { LOGGER.info("\n 关闭连接 - {}", session); // 移除用户 WebSocketUsers.remove(session.getId()); // 获取到信号量则需释放 SemaphoreUtils.release(socketSemaphore); } /** * 抛出异常时处理 */ @OnError public void onError(Session session, Throwable exception) throws Exception { if (session.isOpen()) { // 关闭连接 session.close(); } String sessionId = session.getId(); LOGGER.info("\n 连接异常 - {}", sessionId); LOGGER.info("\n 异常信息 - {}", exception); // 移出用户 WebSocketUsers.remove(sessionId); // 获取到信号量则需释放 SemaphoreUtils.release(socketSemaphore); } /** * 服务器接收到客户端消息时调用的方法 */ @OnMessage public void onMessage(String message, Session session) { String msg = message.replace("你", "我").replace("吗", ""); WebSocketUsers.sendMessageToUserByText(session, msg); } }WebSocketUsers工具类
import java.io.IOException; import java.util.Collection; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import javax.websocket.Session; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * websocket 客户端用户集 * * @author ruoyi */ public class WebSocketUsers { /** * WebSocketUsers 日志控制器 */ private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketUsers.class); /** * 用户集 */ private static Map<String, Session> USERS = new ConcurrentHashMap<String, Session>(); /** * 存储用户 * * @param key 唯一键 * @param session 用户信息 */ public static void put(String key, Session session) { USERS.put(key, session); } /** * 移除用户 * * @param session 用户信息 * * @return 移除结果 */ public static boolean remove(Session session) { String key = null; boolean flag = USERS.containsValue(session); if (flag) { Set<Map.Entry<String, Session>> entries = USERS.entrySet(); for (Map.Entry<String, Session> entry : entries) { Session value = entry.getValue(); if (value.equals(session)) { key = entry.getKey(); break; } } } else { return true; } return remove(key); } /** * 移出用户 * * @param key 键 */ public static boolean remove(String key) { LOGGER.info("\n 正在移出用户 - {}", key); Session remove = USERS.remove(key); if (remove != null) { boolean containsValue = USERS.containsValue(remove); LOGGER.info("\n 移出结果 - {}", containsValue ? "失败" : "成功"); return containsValue; } else { return true; } } /** * 获取在线用户列表 * * @return 返回用户集合 */ public static Map<String, Session> getUsers() { return USERS; } /** * 群发消息文本消息 * * @param message 消息内容 */ public static void sendMessageToUsersByText(String message) { Collection<Session> values = USERS.values(); for (Session value : values) { sendMessageToUserByText(value, message); } } /** * 发送文本消息 * * @param session 缓存 * @param message 消息内容 */ public static void sendMessageToUserByText(Session session, String message) { if (session != null) { try { session.getBasicRemote().sendText(message); } catch (IOException e) { LOGGER.error("\n[发送消息异常]", e); } } else { LOGGER.info("\n[你已离线]"); } } }Html 页面代码
<!DOCTYPE html> <html lang="zh" xmlns:th=""> <head> <meta charset="utf-8"> <meta http-equiv="X-UA-Compatible" content="IE=edge"> <title>测试界面</title> </head> <body> <div> <input type="text" style="width: 20%" value="ws://127.0.0.1/websocket/message" id="url"> <button id="btn_join">连接</button> <button id="btn_exit">断开</button> </div> <br/> <textarea id="message" cols="100" rows="9"></textarea> <button id="btn_send">发送消息</button> <br/> <br/> <textarea id="text_content" readonly="readonly" cols="100" rows="9"></textarea>返回内容 <br/> <br/> <script th:src="@{/js/jquery.min.js}" ></script> <script type="text/javascript"> $(document).ready(function(){ var ws = null; // 连接 $('#btn_join').click(function() { var url = $("#url").val(); ws = new WebSocket(url); ws.onopen = function(event) { $('#text_content').append('已经打开连接!' + '\n'); } ws.onmessage = function(event) { $('#text_content').append(event.data + '\n'); } ws.onclose = function(event) { $('#text_content').append('已经关闭连接!' + '\n'); } }); // 发送消息 $('#btn_send').click(function() { var message = $('#message').val(); if (ws) { ws.send(message); } else { alert("未连接到服务器"); } }); //断开 $('#btn_exit').click(function() { if (ws) { ws.close(); ws = null; } }); }) </script> </body> </html>成功运行后,页面如下
注意此时没有走用户认证,那么就要对路径放行,因为若依框架用的是SpringSecurity,所以找到文件SecurityConfig.java ,进行路径放行
虽然按着上述步骤我们完成了浏览器(客户端)和Java(服务端)的WebSocket通信,但是我们不能限定哪些用户可以连接我们的服务端获取数据,服务端也不知道应该具体给哪些用户发送消息,在我们框架之前交互我们是通过浏览器传递toke 值来实现用户身份确认的,那么我们的WebSocket可不可以也这样呢?
很不幸的是 ws连接是无法像http一样完全自主定义请求头的,给token认证带来了不便,我们大致可以通过以下集中方式完成用户认证
1、将 token 明文携带在 url 中,例如ws://localhost:8080/weggo/websocket/message?Authorization=Bearer+token
2、通过websocket下的子协议来实现,Stomp这个协议来实现,前端采用SocketJs框架来实现对应定制请求头。实现携带authorization=Bearer +token 的需求,这样就可以正常建立连接
3、利用子协议数组,将 token 携带在 protocols 里,var ws = new WebSocket(url, ["token"]);
这样后端在 onOpen 事件中,就可以从 server 中读取 Sec-WebSocket-Protocol 属性来进行 token 的获取,具体可以参考WebScoket构造函数官方文档
var aWebSocket = new WebSocket(url [, protocols]); url 要连接的URL;这应该是WebSocket服务器将响应的URL。 protocols 可选 一个协议字符串或者一个包含协议字符串的数组。这些字符串用于指定子协议,这样单个服务器可以实现多个WebSocket子协议 (例如,您可能希望一台服务器能够根据指定的协议(protocol)处理不同类型的交互)。如果不指定协议字符串,则假定为空字符串。protocols对应的就是发起ws连接时, 携带在请求头中的Sec-WebSocket-Protocol属性, 服务端可以获取到此属性的值用于通信逻辑(即通信子协议,当然用来进行token认证也是完全没问题的),前端人员在请求头上携带sec-websocket-protocol=Bearer +token后台在请求到达oauth2之前进行拦截,然后将在请求头上添加Authorization=Bearer +token(key首字母大写),然后在响应头(respone)上添加sec-websocket-protocol=Bearer +token(不添加会报错)
方法3部分代码示例
//前端 var aWebSocket = new WebSocket(url ['用户token']); //后端 @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { //这里就是我们所提交的token String submitedToken=session.getHandshakeHeaders().get("sec-websocket-protocol").get(0); //根据token取得登录用户信息(业务逻辑根据你自己的来处理) }另外,如果需要在第一次握手前的时候就取得token,只需要在header里面取得就可以啦
@Override public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map<String, Object> map) throws Exception { System.out.println("准备握手"); String submitedToken = serverHttpRequest.getHeaders().get("sec-websocket-protocol") return true; }因为我的项目是APP 移动端与服务端进行交互,所以后来选择了最简单实现的方案一
首先要解决的就是在拦截器获取url 的token 信息,原框架只从head里面获取,所以需要稍加改动
找到TokenService.java文件里的getToken方法,改成如下,这样就可以获取url 中的token 了又不影响原来的Http 请求
private String getToken(HttpServletRequest request) { String token = Optional.ofNullable(request.getHeader(header)).orElse(request.getParameter(header)); if (StringUtils.isNotEmpty(token) && token.startsWith(Constants.TOKEN_PREFIX)) { token = token.replace(Constants.TOKEN_PREFIX, ""); } return token; }接下来就是需要对我们的WebSocket类进行改造了,为了方便阅读,去除了WebSocketUsers类,添加了类变量webSocketSet来存储客户端对象
import com.alibaba.fastjson2.JSON; import com.tongchuang.common.utils.SecurityUtils; import com.tongchuang.web.mqtt.domain.DeviceInfo; import io.netty.util.HashedWheelTimer; import io.netty.util.Timeout; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.security.core.Authentication; import org.springframework.stereotype.Component; import javax.websocket.*; import javax.websocket.server.PathParam; import javax.websocket.server.ServerEndpoint; import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; /** * websocket 消息处理 * * @author stronger */ @Component @ServerEndpoint("/websocket/message") public class WebSocketServer { /*========================声明类变量,意在所有实例共享=================================================*/ /** * WebSocketServer 日志控制器 */ private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class); /** * 默认最多允许同时在线人数100 */ public static int socketMaxOnlineCount = 100; private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount); HashedWheelTimer timer = new HashedWheelTimer(1, TimeUnit.SECONDS, 8); /** * concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。 */ private static final CopyOnWriteArraySet<WebSocketServer> webSocketSet = new CopyOnWriteArraySet<>(); /** * 连接数 */ private static final AtomicInteger count = new AtomicInteger(); /*========================声明实例变量,意在每个实例独享=======================================================*/ /** * 与某个客户端的连接会话,需要通过它来给客户端发送数据 */ private Session session; /** * 用户id */ private String sid = ""; /** * 连接建立成功调用的方法 */ @OnOpen public void onOpen(Session session) throws Exception { // 尝试获取信号量 boolean semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore); if (!semaphoreFlag) { // 未获取到信号量 LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount); // 给当前Session 登录用户发送消息 sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount); session.close(); } else { // 返回此会话的经过身份验证的用户,如果此会话没有经过身份验证的用户,则返回null Authentication authentication = (Authentication) session.getUserPrincipal(); SecurityUtils.setAuthentication(authentication); String username = SecurityUtils.getUsername(); this.session = session; //如果存在就先删除一个,防止重复推送消息 for (WebSocketServer webSocket : webSocketSet) { if (webSocket.sid.equals(username)) { webSocketSet.remove(webSocket); count.getAndDecrement(); } } count.getAndIncrement(); webSocketSet.add(this); this.sid = username; LOGGER.info("\n 当前人数 - {}", count); sendMessageToUserByText(session, "连接成功"); } } /** * 连接关闭时处理 */ @OnClose public void onClose(Session session) { LOGGER.info("\n 关闭连接 - {}", session); // 移除用户 webSocketSet.remove(session); // 获取到信号量则需释放 SemaphoreUtils.release(socketSemaphore); } /** * 抛出异常时处理 */ @OnError public void onError(Session session, Throwable exception) throws Exception { if (session.isOpen()) { // 关闭连接 session.close(); } String sessionId = session.getId(); LOGGER.info("\n 连接异常 - {}", sessionId); LOGGER.info("\n 异常信息 - {}", exception); // 移出用户 webSocketSet.remove(session); // 获取到信号量则需释放 SemaphoreUtils.release(socketSemaphore); } /** * 服务器接收到客户端消息时调用的方法 */ @OnMessage public void onMessage(String message, Session session) { Authentication authentication = (Authentication) session.getUserPrincipal(); LOGGER.info("收到来自" + sid + "的信息:" + message); // 实时更新 this.refresh(sid, authentication); sendMessageToUserByText(session, "我收到了你的新消息哦"); } /** * 刷新定时任务,发送信息 */ private void refresh(String userId, Authentication authentication) { this.start(5000L, task -> { // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。 if (WebSocketServer.isConn(userId)) { // 因为这里是长链接,不会和普通网页一样,每次发送http 请求可以走拦截器【doFilterInternal】续约,所以需要手动续约 SecurityUtils.setAuthentication(authentication); // 从数据库或者缓存中获取信息,构建自定义的Bean DeviceInfo deviceInfo = DeviceInfo.builder().Macaddress("de5a735951ee").Imei("351517175516665") .Battery("99").Charge("0").Latitude("116.402649").Latitude("39.914859").Altitude("80") .Method(SecurityUtils.getUsername()).build(); // TODO判断数据是否有更新 // 发送最新数据给前端 WebSocketServer.sendInfo("JSON", deviceInfo, userId); // 设置返回值,判断是否需要继续执行 return true; } return false; }); } private void start(long delay, Function<Timeout, Boolean> function) { timer.newTimeout(t -> { // 获取返回值,判断是否执行 Boolean result = function.apply(t); if (result) { timer.newTimeout(t.task(), delay, TimeUnit.MILLISECONDS); } }, delay, TimeUnit.MILLISECONDS); } /** * 判断是否有链接 * * @return */ public static boolean isConn(String sid) { for (WebSocketServer item : webSocketSet) { if (item.sid.equals(sid)) { return true; } } return false; } /** * 群发自定义消息 * 或者指定用户发送消息 */ public static void sendInfo(String type, Object data, @PathParam("sid") String sid) { // 遍历WebSocketServer对象集合,如果符合条件就推送 for (WebSocketServer item : webSocketSet) { try { //这里可以设定只推送给这个sid的,为null则全部推送 if (sid == null) { item.sendMessage(type, data); } else if (item.sid.equals(sid)) { item.sendMessage(type, data); } } catch (IOException ignored) { } } } /** * 实现服务器主动推送 */ private void sendMessage(String type, Object data) throws IOException { Map<String, Object> result = new HashMap<>(); result.put("type", type); result.put("data", data); this.session.getAsyncRemote().sendText(JSON.toJSONString(result)); } /** * 实现服务器主动推送-根据session */ public static void sendMessageToUserByText(Session session, String message) { if (session != null) { try { session.getBasicRemote().sendText(message); } catch (IOException e) { LOGGER.error("\n[发送消息异常]", e); } } else { LOGGER.info("\n[你已离线]"); } } }