package com.ruoyi.ws.service; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import javax.websocket.Session; import org.springframework.stereotype.Service; import com.alibaba.fastjson2.JSON; import lombok.extern.slf4j.Slf4j; /** * WebSocket工具类 */ @Service @Slf4j public class UserWebSocketService { // 使用ConcurrentHashMap存储会话 private static final Map SESSION_POOL = new ConcurrentHashMap<>(); // 创建线程池 private static final ExecutorService MESSAGE_EXECUTOR = Executors.newFixedThreadPool(10); // 批处理大小 private static final int BATCH_SIZE = 100; // 最大重试次数 private static final int MAX_RETRY = 3; /** * 添加用户WebSocket会话 */ public void addSession(Long userId, Session session) { if (userId != null && session != null) { SESSION_POOL.put(userId, session); log.info("用户{}的WebSocket会话已添加", userId); } } /** * 移除用户WebSocket会话 */ public void removeSession(Long userId) { if (userId != null) { SESSION_POOL.remove(userId); log.info("用户{}的WebSocket会话已移除", userId); } } /** * 获取用户WebSocket会话 */ public Session getSession(Long userId) { return userId != null ? SESSION_POOL.get(userId) : null; } /** * 向指定用户发送消息 * * @param message 消息内容 * @param userIds 目标用户ID数组 */ public void sendMessageToUser(Object message, Long... userIds) { if (message == null) { return; } this.sendMessageToUser(message, Arrays.asList(userIds)); } /** * 向指定用户发送消息 * * @param message 消息内容 * @param userIds 目标用户ID数组 */ public void sendMessageToUser(Object message, Collection userIds) { if (message == null) { return; } String msg = JSON.toJSONString(message); for (Long userId : userIds) { Session session = this.getSession(userId); if (session != null) { try { if (session.isOpen()) { session.getBasicRemote().sendText(msg); log.info("向用户{}发送消息成功: {}", userId, msg); } } catch (Exception e) { log.error("发送消息给用户{}失败", userId, e); } } } } /** * 向所有在线用户发送消息 * * @param message 消息内容 */ public void sendMessageToAll(Object message) { if (message == null) { return; } String msg = JSON.toJSONString(message); // 将在线用户分批处理 List>> batches = new ArrayList<>(); List> currentBatch = new ArrayList<>(); for (Map.Entry entry : SESSION_POOL.entrySet()) { currentBatch.add(entry); if (currentBatch.size() >= BATCH_SIZE) { batches.add(new ArrayList<>(currentBatch)); currentBatch.clear(); } } if (!currentBatch.isEmpty()) { batches.add(currentBatch); } // 并行处理每个批次 List> futures = new ArrayList<>(); for (List> batch : batches) { CompletableFuture future = CompletableFuture.runAsync(() -> { processBatch(batch, msg); }, MESSAGE_EXECUTOR); futures.add(future); } // 等待所有批次处理完成 CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); log.info("向所有在线用户发送消息完成: {}", msg); } /** * 处理一个批次的消息发送 */ private void processBatch(List> batch, String msg) { for (Map.Entry entry : batch) { Long userId = entry.getKey(); Session session = entry.getValue(); if (session != null) { sendMessageWithRetry(userId, session, msg, 0); } } } /** * 带重试机制的消息发送 */ private void sendMessageWithRetry(Long userId, Session session, String msg, int retryCount) { try { synchronized (session) { if (session.isOpen()) { session.getBasicRemote().sendText(msg); } } } catch (Exception e) { if (retryCount < MAX_RETRY) { log.warn("向用户{}发送消息失败,准备第{}次重试", userId, retryCount + 1); try { Thread.sleep(100 * (retryCount + 1)); // 递增延迟重试 sendMessageWithRetry(userId, session, msg, retryCount + 1); } catch (InterruptedException ie) { Thread.currentThread().interrupt(); } } else { log.error("向用户{}发送全局消息失败,已重试{}次", userId, MAX_RETRY, e); } } } }