Browse Source

feat:【ai 大模型】增加联网搜索功能

YunaiV 10 months ago
parent
commit
9b2f2f581b

+ 15 - 2
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/AiChatMessageController.http

@@ -39,14 +39,27 @@ Authorization: {{token}}
 tenant-id: {{adminTenantId}}
 
 {
-  "conversationId": "1781604279872581797",
+  "conversationId": "1781604279872581799",
   "content": "说下图片里,有哪些字?",
   "useContext": true
 }
 
+### 发送消息(流式)【联网搜索】
+POST {{baseUrl}}/ai/chat/message/send-stream
+Content-Type: application/json
+Authorization: {{token}}
+tenant-id: {{adminTenantId}}
+
+{
+  "conversationId": "1781604279872581799",
+  "content": "今天是周几?",
+  "useSearch": true
+}
+
 ### 获得指定对话的消息列表
-GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581649
+GET {{baseUrl}}/ai/chat/message/list-by-conversation-id?conversationId=1781604279872581799
 Authorization: {{token}}
+tenant-id: {{adminTenantId}}
 
 ### 删除消息
 DELETE {{baseUrl}}/ai/chat/message/delete?id=50

+ 4 - 0
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageRespVO.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
 
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
 import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.Data;
 
@@ -49,6 +50,9 @@ public class AiChatMessageRespVO {
     @Schema(description = "知识库段落数组")
     private List<KnowledgeSegment> segments;
 
+    @Schema(description = "联网搜索的网页内容数组")
+    private List<AiWebSearchResponse.WebPage> webSearchPages;
+
     @Schema(description = "附件 URL 数组", example = "https://www.iocoder.cn/1.png")
     private List<String> attachmentUrls;
 

+ 4 - 0
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendRespVO.java

@@ -1,5 +1,6 @@
 package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message;
 
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
 import io.swagger.v3.oas.annotations.media.Schema;
 import lombok.Data;
 
@@ -38,6 +39,9 @@ public class AiChatMessageSendRespVO {
         @Schema(description = "知识库段落数组")
         private List<AiChatMessageRespVO.KnowledgeSegment> segments;
 
+        @Schema(description = "联网搜索的网页内容数组")
+        private List<AiWebSearchResponse.WebPage> webSearchPages;
+
         @Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
         private LocalDateTime createTime;
 

+ 12 - 1
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java

@@ -6,11 +6,16 @@ import cn.iocoder.yudao.framework.mybatis.core.type.StringListTypeHandler;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
 import com.baomidou.mybatisplus.annotation.KeySequence;
 import com.baomidou.mybatisplus.annotation.TableField;
 import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
-import lombok.*;
+import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
 import org.springframework.ai.chat.messages.MessageType;
 
 import java.util.List;
@@ -106,6 +111,12 @@ public class AiChatMessageDO extends BaseDO {
     @TableField(typeHandler = LongListTypeHandler.class)
     private List<Long> segmentIds;
 
+    /**
+     * 联网搜索的网页内容数组
+     */
+    @TableField(typeHandler = JacksonTypeHandler.class)
+    private List<AiWebSearchResponse.WebPage> webSearchPages;
+
     /**
      * 附件 URL 数组
      */

+ 69 - 18
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -23,6 +23,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
 import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchClient;
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchRequest;
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
@@ -44,6 +47,7 @@ import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.StreamingChatModel;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 import reactor.core.publisher.Flux;
@@ -69,6 +73,11 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
 @Slf4j
 public class AiChatMessageServiceImpl implements AiChatMessageService {
 
+    /**
+     * 联网搜索的结束数
+     */
+    private static final Integer WEB_SEARCH_COUNT = 10;
+
     // TODO @芋艿:后续优化下对话的 Prompt 整体结构
 
     /**
@@ -78,6 +87,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             "%s\n\n" + // 多个 <Reference></Reference> 的拼接
             "回答要求:\n- 避免提及你是从 <Reference></Reference> 获取的知识。";
 
+    private static final String WEB_SEARCH_USER_MESSAGE_TEMPLATE = "使用 <WebSearch></WebSearch> 标记中的内容作为本次对话的参考:\n\n" +
+            "%s\n\n" + // 多个 <WebSearch></WebSearch> 的拼接
+            "回答要求:\n- 避免提及你是从 <WebSearch></WebSearch> 获取的知识。";
+
     /**
      * 附件转 ${@link UserMessage} 的内容模版
      */
@@ -102,6 +115,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     @Resource
     private AiToolService toolService;
 
+    @SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
+    @Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入
+    private AiWebSearchClient webSearchClient;
+
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
         // 1.1 校验对话存在
@@ -115,30 +132,35 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         AiModelDO model = modalService.validateModel(conversation.getModelId());
         ChatModel chatModel = modalService.getChatModel(model.getId());
 
-        // 2. 知识库找
+        // 2.1 知识库召
         List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
                 sendReqVO.getContent(), conversation);
 
+        // 2.2 联网搜索
+        AiWebSearchResponse webSearchResponse = Boolean.TRUE.equals(sendReqVO.getUseSearch()) && webSearchClient != null ?
+                webSearchClient.search(new AiWebSearchRequest().setQuery(sendReqVO.getContent())
+                        .setSummary(true).setCount(WEB_SEARCH_COUNT)) : null;
+
         // 3. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
                 userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
-                null, sendReqVO.getAttachmentUrls());
+                null, sendReqVO.getAttachmentUrls(), null);
 
-        // 3.1 插入 assistant 接收消息
+        // 4.1 插入 assistant 接收消息
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
                 userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
-                knowledgeSegments, null);
+                knowledgeSegments, null, webSearchResponse);
 
-        // 3.2 创建 chat 需要的 Prompt
-        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        // 4.2 创建 chat 需要的 Prompt
+        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, webSearchResponse, model, sendReqVO);
         ChatResponse chatResponse = chatModel.call(prompt);
 
-        // 3.3 更新响应内容
+        // 4.3 更新响应内容
         String newContent = AiUtils.getChatResponseContent(chatResponse);
         String newReasoningContent = AiUtils.getChatResponseReasoningContent(chatResponse);
         chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId())
                 .setContent(newContent).setReasoningContent(newReasoningContent));
-        // 3.4 响应结果
+        // 4.4 响应结果
         Map<Long, AiKnowledgeDocumentDO> documentMap = knowledgeDocumentService.getKnowledgeDocumentMap(
                 convertSet(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getDocumentId));
         List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments,
@@ -149,7 +171,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return new AiChatMessageSendRespVO()
                 .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
                 .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
-                        .setContent(newContent).setSegments(segments));
+                        .setContent(newContent).setSegments(segments)
+                        .setWebSearchPages(webSearchResponse != null ? webSearchResponse.getLists() : null));
     }
 
     @Override
@@ -166,30 +189,36 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         AiModelDO model = modalService.validateModel(conversation.getModelId());
         StreamingChatModel chatModel = modalService.getChatModel(model.getId());
 
-        // 2. 知识库找回
+        // 2.1 知识库找回
         List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
                 sendReqVO.getContent(), conversation);
 
+        // 2.2 联网搜索
+        AiWebSearchResponse webSearchResponse = Boolean.TRUE.equals(sendReqVO.getUseSearch()) && webSearchClient != null ?
+                webSearchClient.search(new AiWebSearchRequest().setQuery(sendReqVO.getContent())
+                        .setSummary(true).setCount(WEB_SEARCH_COUNT)) : null;
+
         // 3. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
                 userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
-                null, sendReqVO.getAttachmentUrls());
+                null, sendReqVO.getAttachmentUrls(), null);
 
         // 4.1 插入 assistant 接收消息
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
                 userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
-                knowledgeSegments, null);
+                knowledgeSegments, null, webSearchResponse);
 
         // 4.2 构建 Prompt,并进行调用
-        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, webSearchResponse, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 4.3 流式返回
         StringBuffer contentBuffer = new StringBuffer();
         StringBuffer reasoningContentBuffer = new StringBuffer();
         return streamResponse.map(chunk -> {
-            // 处理知识库的返回,只有首次才有
+            // 仅首次:返回知识库、联网搜索
             List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
+            List<AiWebSearchResponse.WebPage> webSearchPages = null;
             if (StrUtil.isEmpty(contentBuffer)) {
                 Map<Long, AiKnowledgeDocumentDO> documentMap = TenantUtils.executeIgnore(() ->
                         knowledgeDocumentService.getKnowledgeDocumentMap(
@@ -198,6 +227,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                     AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
                     segment.setDocumentName(document != null ? document.getName() : null);
                 });
+                if (webSearchResponse != null) {
+                    webSearchPages = webSearchResponse.getLists();
+                }
             }
             // 响应结果
             String newContent = AiUtils.getChatResponseContent(chunk);
@@ -213,7 +245,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                     .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
                             .setContent(StrUtil.nullToDefault(newContent, "")) // 避免 null 的 情况
                             .setReasoningContent(StrUtil.nullToDefault(newReasoningContent, "")) // 避免 null 的 情况
-                            .setSegments(segments))); // 知识库返回
+                            .setSegments(segments).setWebSearchPages(webSearchPages))); // 知识库 + 联网搜索
         }).doOnComplete(() -> {
             // 忽略租户,因为 Flux 异步无法透传租户
             TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
@@ -239,7 +271,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             return Collections.emptyList();
         }
 
-        // 2. 遍历
+        // 2. 遍历
         List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = new ArrayList<>();
         for (Long knowledgeId : role.getKnowledgeIds()) {
             knowledgeSegments.addAll(knowledgeSegmentService.searchKnowledgeSegment(new AiKnowledgeSegmentSearchReqBO()
@@ -250,6 +282,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
 
     private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
                                List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
+                               AiWebSearchResponse webSearchResponse,
                                AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
         List<Message> chatMessages = new ArrayList<>();
         // 1.1 System Context 角色设定
@@ -265,6 +298,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             if (attachmentUserMessage != null) {
                 chatMessages.add(attachmentUserMessage);
             }
+            // TODO @芋艿:历史的知识库;历史的搜索,要不要拼接?
         });
 
         // 1.3 当前 user message 新发送消息
@@ -278,7 +312,20 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
         }
 
-        // 1.5 附件,通过 UserMessage 实现
+        // 1.5 联网搜索,通过 UserMessage 实现
+        if (webSearchResponse != null && CollUtil.isNotEmpty(webSearchResponse.getLists())) {
+            String webSearch = webSearchResponse.getLists().stream()
+                    .map(page -> {
+                        String summary = StrUtil.isNotEmpty(page.getSummary()) ?
+                                "\nSummary: " + page.getSummary() : "";
+                        return "<WebSearch title=\"" + page.getTitle() + "\" url=\"" + page.getUrl() + "\">"
+                                + StrUtil.blankToDefault(page.getSummary(), page.getSnippet()) + "</WebSearch>";
+                    })
+                    .collect(Collectors.joining("\n\n"));
+            chatMessages.add(new UserMessage(String.format(WEB_SEARCH_USER_MESSAGE_TEMPLATE, webSearch)));
+        }
+
+        // 1.6 附件,通过 UserMessage 实现
         if (CollUtil.isNotEmpty(sendReqVO.getAttachmentUrls())) {
             UserMessage attachmentUserMessage = buildAttachmentUserMessage(sendReqVO.getAttachmentUrls());
             if (attachmentUserMessage != null) {
@@ -383,12 +430,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                                               AiModelDO model, Long userId, Long roleId,
                                               MessageType messageType, String content, Boolean useContext,
                                               List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
-                                              List<String> attachmentUrls) {
+                                              List<String> attachmentUrls,
+                                              AiWebSearchResponse webSearchResponse) {
         AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
                 .setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
                 .setType(messageType.getValue()).setContent(content).setUseContext(useContext)
                 .setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId))
                 .setAttachmentUrls(attachmentUrls);
+        if (webSearchResponse != null) {
+            message.setWebSearchPages(webSearchResponse.getLists());
+        }
         message.setCreateTime(LocalDateTime.now());
         chatMessageMapper.insert(message);
         return message;