Browse Source

feat(iot):【协议改造】tcp 初步改造(100%):基于 code review 进一步完善,对应 iot-tcp-fix-plan.md

YunaiV 4 months ago
parent
commit
09041a24d7

+ 9 - 6
yudao-module-iot/yudao-module-iot-gateway/src/main/java/cn/iocoder/yudao/module/iot/gateway/protocol/tcp/IotTcpProtocol.java

@@ -20,6 +20,7 @@ import io.vertx.core.net.NetServerOptions;
 import io.vertx.core.net.PemKeyCertOptions;
 import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.util.Assert;
 
 /**
  * IoT TCP 协议实现
@@ -89,11 +90,12 @@ public class IotTcpProtocol implements IotProtocol {
         this.serializer = serializerManager.get(serializeType);
         // 初始化帧编解码器
         IotTcpConfig tcpConfig = properties.getTcp();
-        IotTcpConfig.CodecConfig codecConfig = tcpConfig != null ? tcpConfig.getCodec() : null;
-        this.frameCodec = IotTcpFrameCodecFactory.create(codecConfig);
+        Assert.notNull(tcpConfig, "TCP 协议配置(tcp)不能为空");
+        Assert.notNull(tcpConfig.getCodec(), "TCP 拆包配置(tcp.codec)不能为空");
+        this.frameCodec = IotTcpFrameCodecFactory.create(tcpConfig.getCodec());
 
         // 初始化连接管理器
-        this.connectionManager = new IotTcpConnectionManager();
+        this.connectionManager = new IotTcpConnectionManager(tcpConfig.getMaxConnections());
 
         // 初始化下行消息订阅者
         IotTcpDownstreamHandler downstreamHandler = new IotTcpDownstreamHandler(connectionManager, frameCodec, serializer);
@@ -117,7 +119,7 @@ public class IotTcpProtocol implements IotProtocol {
             return;
         }
 
-        // 1.1 创建 Vertx 实例(每个 Protocol 独立管理)
+        // 1.1 创建 Vertx 实例
         this.vertx = Vertx.vertx();
 
         // 1.2 创建服务器选项
@@ -126,8 +128,9 @@ public class IotTcpProtocol implements IotProtocol {
                 .setPort(properties.getPort())
                 .setTcpKeepAlive(true)
                 .setTcpNoDelay(true)
-                .setReuseAddress(true);
-        if (tcpConfig != null && Boolean.TRUE.equals(tcpConfig.getSslEnabled())) {
+                .setReuseAddress(true)
+                .setIdleTimeout((int) (tcpConfig.getKeepAliveTimeoutMs() / 1000)); // 设置空闲超时
+        if (Boolean.TRUE.equals(tcpConfig.getSslEnabled())) {
             PemKeyCertOptions pemKeyCertOptions = new PemKeyCertOptions()
                     .setKeyPath(tcpConfig.getSslKeyPath())
                     .setCertPath(tcpConfig.getSslCertPath());

+ 8 - 1
yudao-module-iot/yudao-module-iot-gateway/src/main/java/cn/iocoder/yudao/module/iot/gateway/protocol/tcp/codec/delimiter/IotTcpDelimiterFrameCodec.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.iot.gateway.protocol.tcp.codec.delimiter;
 
+import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.module.iot.gateway.protocol.tcp.IotTcpConfig;
 import cn.iocoder.yudao.module.iot.gateway.protocol.tcp.codec.IotTcpCodecTypeEnum;
 import cn.iocoder.yudao.module.iot.gateway.protocol.tcp.codec.IotTcpFrameCodec;
@@ -27,6 +28,11 @@ import org.springframework.util.Assert;
 @Slf4j
 public class IotTcpDelimiterFrameCodec implements IotTcpFrameCodec {
 
+    /**
+     * 最大记录大小(64KB),防止 DoS 攻击
+     */
+    private static final int MAX_RECORD_SIZE = 65536;
+
     /**
      * 解析后的分隔符字节数组
      */
@@ -45,6 +51,7 @@ public class IotTcpDelimiterFrameCodec implements IotTcpFrameCodec {
     @Override
     public RecordParser createDecodeParser(Handler<Buffer> handler) {
         RecordParser parser = RecordParser.newDelimited(Buffer.buffer(delimiterBytes));
+        parser.maxRecordSize(MAX_RECORD_SIZE); // 设置最大记录大小,防止 DoS 攻击
         // 处理完整消息(不包含分隔符)
         parser.handler(handler);
         parser.exceptionHandler(ex -> {
@@ -76,7 +83,7 @@ public class IotTcpDelimiterFrameCodec implements IotTcpFrameCodec {
                 .replace("\\r", "\r")
                 .replace("\\n", "\n")
                 .replace("\\t", "\t");
-        return parsed.getBytes();
+        return StrUtil.utf8Bytes(parsed);
     }
 
 }

+ 5 - 0
yudao-module-iot/yudao-module-iot-gateway/src/main/java/cn/iocoder/yudao/module/iot/gateway/protocol/tcp/codec/length/IotTcpFixedLengthFrameCodec.java

@@ -46,6 +46,11 @@ public class IotTcpFixedLengthFrameCodec implements IotTcpFrameCodec {
 
     @Override
     public Buffer encode(byte[] data) {
+        // 校验数据长度不能超过固定长度
+        if (data.length > fixedLength) {
+            throw new IllegalArgumentException(String.format(
+                    "数据长度 %d 超过固定长度 %d", data.length, fixedLength));
+        }
         Buffer buffer = Buffer.buffer(fixedLength);
         buffer.appendBytes(data);
         // 如果数据不足固定长度,填充 0(RecordParser.newFixed 解码时按固定长度读取,所以发送端需要填充)

+ 36 - 16
yudao-module-iot/yudao-module-iot-gateway/src/main/java/cn/iocoder/yudao/module/iot/gateway/protocol/tcp/handler/upstream/IotTcpUpstreamHandler.java

@@ -25,7 +25,7 @@ import io.vertx.core.parsetools.RecordParser;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.util.Assert;
 
-import static cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants.SUCCESS;
+import static cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants.*;
 import static cn.iocoder.yudao.framework.common.exception.enums.GlobalErrorCodeConstants.UNAUTHORIZED;
 
 /**
@@ -95,8 +95,8 @@ public class IotTcpUpstreamHandler implements Handler<NetSocket> {
             try {
                 processMessage(clientId, buffer, socket);
             } catch (Exception e) {
-                log.error("[handle][消息处理失败,客户端 ID: {},地址: {},错误: {}]",
-                        clientId, socket.remoteAddress(), e.getMessage());
+                log.error("[handle][消息处理失败,客户端 ID: {},地址: {}]",
+                        clientId, socket.remoteAddress(), e);
                 socket.close();
             }
         };
@@ -114,20 +114,40 @@ public class IotTcpUpstreamHandler implements Handler<NetSocket> {
      * @param socket   网络连接
      */
     private void processMessage(String clientId, Buffer buffer, NetSocket socket) {
-        // 1. 反序列化消息
-        IotDeviceMessage message = serializer.deserialize(buffer.getBytes());
-        Assert.notNull(message, "反序列化后消息为空");
+        IotDeviceMessage message = null;
+        try {
+            // 1. 反序列化消息
+            message = serializer.deserialize(buffer.getBytes());
+            if (message == null) {
+                sendErrorResponse(socket, null, null, BAD_REQUEST.getCode(), "消息反序列化失败");
+                return;
+            }
 
-        // 2. 根据消息类型路由处理
-        if (AUTH_METHOD.equals(message.getMethod())) {
-            // 认证请求
-            handleAuthenticationRequest(clientId, message, socket);
-        } else if (IotDeviceMessageMethodEnum.DEVICE_REGISTER.getMethod().equals(message.getMethod())) {
-            // 设备动态注册请求
-            handleRegisterRequest(clientId, message, socket);
-        } else {
-            // 业务消息
-            handleBusinessRequest(clientId, message, socket);
+            // 2. 根据消息类型路由处理
+            if (AUTH_METHOD.equals(message.getMethod())) {
+                // 认证请求
+                handleAuthenticationRequest(clientId, message, socket);
+            } else if (IotDeviceMessageMethodEnum.DEVICE_REGISTER.getMethod().equals(message.getMethod())) {
+                // 设备动态注册请求
+                handleRegisterRequest(clientId, message, socket);
+            } else {
+                // 业务消息
+                handleBusinessRequest(clientId, message, socket);
+            }
+        } catch (IllegalArgumentException e) {
+            // 参数校验失败,返回 400
+            log.warn("[processMessage][参数校验失败,客户端 ID: {},错误: {}]", clientId, e.getMessage());
+            String requestId = message != null ? message.getRequestId() : null;
+            String method = message != null ? message.getMethod() : null;
+            sendErrorResponse(socket, requestId, method, BAD_REQUEST.getCode(), e.getMessage());
+        } catch (Exception e) {
+            // 其他异常,返回 500 并重新抛出让上层关闭连接
+            log.error("[processMessage][处理消息失败,客户端 ID: {}]", clientId, e);
+            String requestId = message != null ? message.getRequestId() : null;
+            String method = message != null ? message.getMethod() : null;
+            sendErrorResponse(socket, requestId, method,
+                    INTERNAL_SERVER_ERROR.getCode(), INTERNAL_SERVER_ERROR.getMsg());
+            throw e;
         }
     }
 

+ 29 - 1
yudao-module-iot/yudao-module-iot-gateway/src/main/java/cn/iocoder/yudao/module/iot/gateway/protocol/tcp/manager/IotTcpConnectionManager.java

@@ -21,6 +21,11 @@ import java.util.concurrent.ConcurrentHashMap;
 @Slf4j
 public class IotTcpConnectionManager {
 
+    /**
+     * 最大连接数
+     */
+    private final int maxConnections;
+
     /**
      * 连接信息映射:NetSocket -> 连接信息
      */
@@ -31,6 +36,24 @@ public class IotTcpConnectionManager {
      */
     private final Map<Long, NetSocket> deviceSocketMap = new ConcurrentHashMap<>();
 
+    public IotTcpConnectionManager(int maxConnections) {
+        this.maxConnections = maxConnections;
+    }
+
+    /**
+     * 获取当前连接数
+     */
+    public int getConnectionCount() {
+        return connectionMap.size();
+    }
+
+    /**
+     * 检查是否可以接受新连接
+     */
+    public boolean canAcceptConnection() {
+        return connectionMap.size() < maxConnections;
+    }
+
     /**
      * 注册设备连接(包含认证信息)
      *
@@ -39,6 +62,10 @@ public class IotTcpConnectionManager {
      * @param connectionInfo 连接信息
      */
     public void registerConnection(NetSocket socket, Long deviceId, ConnectionInfo connectionInfo) {
+        // 检查连接数是否已达上限
+        if (connectionMap.size() >= maxConnections) {
+            throw new IllegalStateException("连接数已达上限: " + maxConnections);
+        }
         // 如果设备已有其他连接,先清理旧连接
         NetSocket oldSocket = deviceSocketMap.get(deviceId);
         if (oldSocket != null && oldSocket != socket) {
@@ -67,7 +94,8 @@ public class IotTcpConnectionManager {
             return;
         }
         Long deviceId = connectionInfo.getDeviceId();
-        deviceSocketMap.remove(deviceId);
+        // 仅当 deviceSocketMap 中的 socket 是当前 socket 时才移除,避免误删新连接
+        deviceSocketMap.remove(deviceId, socket);
         log.info("[unregisterConnection][注销设备连接,设备 ID: {},连接: {}]", deviceId, socket.remoteAddress());
     }