Browse Source

feat:【ai】226 修改知识库和联网搜索仅初始化一次

YunaiV 5 months ago
parent
commit
9a3bfe89e5

+ 18 - 12
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -59,6 +59,8 @@ import reactor.core.publisher.Flux;
 
 import java.time.LocalDateTime;
 import java.util.*;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 
 import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@@ -231,20 +233,24 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         // 4.3 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         StringBuffer reasoningContentBuffer = new StringBuffer();
+
+        // 防止执行多次知识库和联网搜索
+        AtomicBoolean firstExecuteFlag = new AtomicBoolean(true);
+        AtomicReference<List<AiChatMessageRespVO.KnowledgeSegment>> cacheSegments = new AtomicReference<>();
+        AtomicReference<List<AiWebSearchResponse.WebPage>> cacheWebSearchPages = new AtomicReference<>();
         return streamResponse.map(chunk -> {
             // 仅首次:返回知识库、联网搜索
-            List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
-            List<AiWebSearchResponse.WebPage> webSearchPages = null;
             if (StrUtil.isEmpty(contentBuffer)) {
-                Map<Long, AiKnowledgeDocumentDO> documentMap = TenantUtils.executeIgnore(() ->
-                        knowledgeDocumentService.getKnowledgeDocumentMap(
-                                convertSet(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getDocumentId)));
-                segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class, segment ->  {
-                    AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
-                    segment.setDocumentName(document != null ? document.getName() : null);
-                });
-                if (webSearchResponse != null) {
-                    webSearchPages = webSearchResponse.getLists();
+                if (firstExecuteFlag.compareAndSet(true, false)) { // CAS 操作,确保仅执行一次
+                    Map<Long, AiKnowledgeDocumentDO> documentMap = TenantUtils.executeIgnore(() -> knowledgeDocumentService.getKnowledgeDocumentMap(
+                            convertSet(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getDocumentId)));
+                    cacheSegments.set(BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class, segment -> {
+                        AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
+                        segment.setDocumentName(document != null ? document.getName() : null);
+                    }));
+                    if (webSearchResponse != null) {
+                        cacheWebSearchPages.set(webSearchResponse.getLists());
+                    }
                 }
             }
             // 响应结果
@@ -261,7 +267,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                     .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
                             .setContent(StrUtil.nullToDefault(newContent, "")) // 避免 null 的 情况
                             .setReasoningContent(StrUtil.nullToDefault(newReasoningContent, "")) // 避免 null 的 情况
-                            .setSegments(segments).setWebSearchPages(webSearchPages))); // 知识库 + 联网搜索
+                            .setSegments(cacheSegments.get()).setWebSearchPages(cacheWebSearchPages.get()))); // 知识库 + 联网搜索
         }).doOnComplete(() -> {
             // 忽略租户,因为 Flux 异步无法透传租户
             TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(