|
|
@@ -23,6 +23,9 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
|
|
|
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
|
|
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
|
|
import cn.iocoder.yudao.module.ai.enums.model.AiPlatformEnum;
|
|
|
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchClient;
|
|
|
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchRequest;
|
|
|
+import cn.iocoder.yudao.module.ai.framework.ai.core.webserch.AiWebSearchResponse;
|
|
|
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
|
|
|
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
|
|
|
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
|
|
|
@@ -44,6 +47,7 @@ import org.springframework.ai.chat.model.ChatResponse;
|
|
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
|
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
@@ -69,6 +73,11 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
|
|
|
@Slf4j
|
|
|
public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
|
|
|
+ /**
|
|
|
+ * 联网搜索的结束数
|
|
|
+ */
|
|
|
+ private static final Integer WEB_SEARCH_COUNT = 10;
|
|
|
+
|
|
|
// TODO @芋艿:后续优化下对话的 Prompt 整体结构
|
|
|
|
|
|
/**
|
|
|
@@ -78,6 +87,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
"%s\n\n" + // 多个 <Reference></Reference> 的拼接
|
|
|
"回答要求:\n- 避免提及你是从 <Reference></Reference> 获取的知识。";
|
|
|
|
|
|
+ private static final String WEB_SEARCH_USER_MESSAGE_TEMPLATE = "使用 <WebSearch></WebSearch> 标记中的内容作为本次对话的参考:\n\n" +
|
|
|
+ "%s\n\n" + // 多个 <WebSearch></WebSearch> 的拼接
|
|
|
+ "回答要求:\n- 避免提及你是从 <WebSearch></WebSearch> 获取的知识。";
|
|
|
+
|
|
|
/**
|
|
|
* 附件转 ${@link UserMessage} 的内容模版
|
|
|
*/
|
|
|
@@ -102,6 +115,10 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
@Resource
|
|
|
private AiToolService toolService;
|
|
|
|
|
|
+ @SuppressWarnings("SpringJavaAutowiredFieldsWarningInspection")
|
|
|
+ @Autowired(required = false) // 由于 yudao.ai.web-search.enable 配置项,可以关闭 AiWebSearchClient 的功能,所以这里只能不强制注入
|
|
|
+ private AiWebSearchClient webSearchClient;
|
|
|
+
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
|
|
// 1.1 校验对话存在
|
|
|
@@ -115,30 +132,35 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
AiModelDO model = modalService.validateModel(conversation.getModelId());
|
|
|
ChatModel chatModel = modalService.getChatModel(model.getId());
|
|
|
|
|
|
- // 2. 知识库找回
|
|
|
+ // 2.1 知识库召回
|
|
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
|
|
|
sendReqVO.getContent(), conversation);
|
|
|
|
|
|
+ // 2.2 联网搜索
|
|
|
+ AiWebSearchResponse webSearchResponse = Boolean.TRUE.equals(sendReqVO.getUseSearch()) && webSearchClient != null ?
|
|
|
+ webSearchClient.search(new AiWebSearchRequest().setQuery(sendReqVO.getContent())
|
|
|
+ .setSummary(true).setCount(WEB_SEARCH_COUNT)) : null;
|
|
|
+
|
|
|
// 3. 插入 user 发送消息
|
|
|
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
|
|
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
|
|
|
- null, sendReqVO.getAttachmentUrls());
|
|
|
+ null, sendReqVO.getAttachmentUrls(), null);
|
|
|
|
|
|
- // 3.1 插入 assistant 接收消息
|
|
|
+ // 4.1 插入 assistant 接收消息
|
|
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
|
|
|
- knowledgeSegments, null);
|
|
|
+ knowledgeSegments, null, webSearchResponse);
|
|
|
|
|
|
- // 3.2 创建 chat 需要的 Prompt
|
|
|
- Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
|
|
+ // 4.2 创建 chat 需要的 Prompt
|
|
|
+ Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, webSearchResponse, model, sendReqVO);
|
|
|
ChatResponse chatResponse = chatModel.call(prompt);
|
|
|
|
|
|
- // 3.3 更新响应内容
|
|
|
+ // 4.3 更新响应内容
|
|
|
String newContent = AiUtils.getChatResponseContent(chatResponse);
|
|
|
String newReasoningContent = AiUtils.getChatResponseReasoningContent(chatResponse);
|
|
|
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId())
|
|
|
.setContent(newContent).setReasoningContent(newReasoningContent));
|
|
|
- // 3.4 响应结果
|
|
|
+ // 4.4 响应结果
|
|
|
Map<Long, AiKnowledgeDocumentDO> documentMap = knowledgeDocumentService.getKnowledgeDocumentMap(
|
|
|
convertSet(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getDocumentId));
|
|
|
List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments,
|
|
|
@@ -149,7 +171,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
return new AiChatMessageSendRespVO()
|
|
|
.setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
|
|
|
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
|
|
|
- .setContent(newContent).setSegments(segments));
|
|
|
+ .setContent(newContent).setSegments(segments)
|
|
|
+ .setWebSearchPages(webSearchResponse != null ? webSearchResponse.getLists() : null));
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
@@ -166,30 +189,36 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
AiModelDO model = modalService.validateModel(conversation.getModelId());
|
|
|
StreamingChatModel chatModel = modalService.getChatModel(model.getId());
|
|
|
|
|
|
- // 2. 知识库找回
|
|
|
+ // 2.1 知识库找回
|
|
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(
|
|
|
sendReqVO.getContent(), conversation);
|
|
|
|
|
|
+ // 2.2 联网搜索
|
|
|
+ AiWebSearchResponse webSearchResponse = Boolean.TRUE.equals(sendReqVO.getUseSearch()) && webSearchClient != null ?
|
|
|
+ webSearchClient.search(new AiWebSearchRequest().setQuery(sendReqVO.getContent())
|
|
|
+ .setSummary(true).setCount(WEB_SEARCH_COUNT)) : null;
|
|
|
+
|
|
|
// 3. 插入 user 发送消息
|
|
|
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
|
|
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
|
|
|
- null, sendReqVO.getAttachmentUrls());
|
|
|
+ null, sendReqVO.getAttachmentUrls(), null);
|
|
|
|
|
|
// 4.1 插入 assistant 接收消息
|
|
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
|
|
|
- knowledgeSegments, null);
|
|
|
+ knowledgeSegments, null, webSearchResponse);
|
|
|
|
|
|
// 4.2 构建 Prompt,并进行调用
|
|
|
- Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
|
|
|
+ Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, webSearchResponse, model, sendReqVO);
|
|
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
|
|
|
|
|
// 4.3 流式返回
|
|
|
StringBuffer contentBuffer = new StringBuffer();
|
|
|
StringBuffer reasoningContentBuffer = new StringBuffer();
|
|
|
return streamResponse.map(chunk -> {
|
|
|
- // 处理知识库的返回,只有首次才有
|
|
|
+ // 仅首次:返回知识库、联网搜索
|
|
|
List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
|
|
|
+ List<AiWebSearchResponse.WebPage> webSearchPages = null;
|
|
|
if (StrUtil.isEmpty(contentBuffer)) {
|
|
|
Map<Long, AiKnowledgeDocumentDO> documentMap = TenantUtils.executeIgnore(() ->
|
|
|
knowledgeDocumentService.getKnowledgeDocumentMap(
|
|
|
@@ -198,6 +227,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
|
|
|
segment.setDocumentName(document != null ? document.getName() : null);
|
|
|
});
|
|
|
+ if (webSearchResponse != null) {
|
|
|
+ webSearchPages = webSearchResponse.getLists();
|
|
|
+ }
|
|
|
}
|
|
|
// 响应结果
|
|
|
String newContent = AiUtils.getChatResponseContent(chunk);
|
|
|
@@ -213,7 +245,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
.setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
|
|
|
.setContent(StrUtil.nullToDefault(newContent, "")) // 避免 null 的 情况
|
|
|
.setReasoningContent(StrUtil.nullToDefault(newReasoningContent, "")) // 避免 null 的 情况
|
|
|
- .setSegments(segments))); // 知识库返回
|
|
|
+ .setSegments(segments).setWebSearchPages(webSearchPages))); // 知识库 + 联网搜索
|
|
|
}).doOnComplete(() -> {
|
|
|
// 忽略租户,因为 Flux 异步无法透传租户
|
|
|
TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
|
|
|
@@ -239,7 +271,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
return Collections.emptyList();
|
|
|
}
|
|
|
|
|
|
- // 2. 遍历找回
|
|
|
+ // 2. 遍历召回
|
|
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = new ArrayList<>();
|
|
|
for (Long knowledgeId : role.getKnowledgeIds()) {
|
|
|
knowledgeSegments.addAll(knowledgeSegmentService.searchKnowledgeSegment(new AiKnowledgeSegmentSearchReqBO()
|
|
|
@@ -250,6 +282,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
|
|
|
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
|
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
|
|
|
+ AiWebSearchResponse webSearchResponse,
|
|
|
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
|
|
List<Message> chatMessages = new ArrayList<>();
|
|
|
// 1.1 System Context 角色设定
|
|
|
@@ -265,6 +298,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
if (attachmentUserMessage != null) {
|
|
|
chatMessages.add(attachmentUserMessage);
|
|
|
}
|
|
|
+ // TODO @芋艿:历史的知识库;历史的搜索,要不要拼接?
|
|
|
});
|
|
|
|
|
|
// 1.3 当前 user message 新发送消息
|
|
|
@@ -278,7 +312,20 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
|
|
|
}
|
|
|
|
|
|
- // 1.5 附件,通过 UserMessage 实现
|
|
|
+ // 1.5 联网搜索,通过 UserMessage 实现
|
|
|
+ if (webSearchResponse != null && CollUtil.isNotEmpty(webSearchResponse.getLists())) {
|
|
|
+ String webSearch = webSearchResponse.getLists().stream()
|
|
|
+ .map(page -> {
|
|
|
+ String summary = StrUtil.isNotEmpty(page.getSummary()) ?
|
|
|
+ "\nSummary: " + page.getSummary() : "";
|
|
|
+ return "<WebSearch title=\"" + page.getTitle() + "\" url=\"" + page.getUrl() + "\">"
|
|
|
+ + StrUtil.blankToDefault(page.getSummary(), page.getSnippet()) + "</WebSearch>";
|
|
|
+ })
|
|
|
+ .collect(Collectors.joining("\n\n"));
|
|
|
+ chatMessages.add(new UserMessage(String.format(WEB_SEARCH_USER_MESSAGE_TEMPLATE, webSearch)));
|
|
|
+ }
|
|
|
+
|
|
|
+ // 1.6 附件,通过 UserMessage 实现
|
|
|
if (CollUtil.isNotEmpty(sendReqVO.getAttachmentUrls())) {
|
|
|
UserMessage attachmentUserMessage = buildAttachmentUserMessage(sendReqVO.getAttachmentUrls());
|
|
|
if (attachmentUserMessage != null) {
|
|
|
@@ -383,12 +430,16 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
AiModelDO model, Long userId, Long roleId,
|
|
|
MessageType messageType, String content, Boolean useContext,
|
|
|
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
|
|
|
- List<String> attachmentUrls) {
|
|
|
+ List<String> attachmentUrls,
|
|
|
+ AiWebSearchResponse webSearchResponse) {
|
|
|
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
|
|
|
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
|
|
|
.setType(messageType.getValue()).setContent(content).setUseContext(useContext)
|
|
|
.setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId))
|
|
|
.setAttachmentUrls(attachmentUrls);
|
|
|
+ if (webSearchResponse != null) {
|
|
|
+ message.setWebSearchPages(webSearchResponse.getLists());
|
|
|
+ }
|
|
|
message.setCreateTime(LocalDateTime.now());
|
|
|
chatMessageMapper.insert(message);
|
|
|
return message;
|