瀏覽代碼

feat:【ai 大模型】RAG 增加 rerank 模型

YunaiV 10 月之前
父節點
當前提交
c31b66b6cc

+ 55 - 17
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java

@@ -18,6 +18,10 @@ import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
 import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
 import cn.iocoder.yudao.module.ai.service.model.AiModelService;
+import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions;
+import com.alibaba.cloud.ai.model.RerankModel;
+import com.alibaba.cloud.ai.model.RerankRequest;
+import com.alibaba.cloud.ai.model.RerankResponse;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.document.Document;
@@ -27,6 +31,7 @@ import org.springframework.ai.transformer.splitter.TokenTextSplitter;
 import org.springframework.ai.vectorstore.SearchRequest;
 import org.springframework.ai.vectorstore.VectorStore;
 import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Lazy;
 import org.springframework.stereotype.Service;
 
@@ -36,6 +41,7 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
 import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
 import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
+import static org.springframework.ai.vectorstore.SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL;
 
 /**
  * AI 知识库分片 Service 实现类
@@ -55,6 +61,11 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
             VECTOR_STORE_METADATA_DOCUMENT_ID, String.class,
             VECTOR_STORE_METADATA_SEGMENT_ID, String.class);
 
+    /**
+     * When reranking is enabled, the vector search retrieves topK * this factor candidates, with the aim of improving the rerank result
+     */
+    private static final Integer RERANK_RETRIEVAL_FACTOR = 4;
+
     @Resource
     private AiKnowledgeSegmentMapper segmentMapper;
 
@@ -69,6 +80,9 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
     @Resource
     private TokenCountEstimator tokenCountEstimator;
 
+    @Autowired(required = false) // The spring.ai.model.rerank property can switch the RerankModel feature off, so injection cannot be mandatory here
+    private RerankModel rerankModel;
+
     @Override
     public PageResult<AiKnowledgeSegmentDO> getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
         return segmentMapper.selectPage(pageReqVO);
@@ -211,28 +225,16 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
         // 1. 校验
         AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
 
-        // 2.1 向量检索
-        VectorStore vectorStore = getVectorStoreById(knowledge);
-        List<Document> documents = vectorStore.similaritySearch(SearchRequest.builder()
-                .query(reqBO.getContent())
-                .topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
-                .similarityThreshold(
-                        ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold()))
-                .filterExpression(new FilterExpressionBuilder()
-                        .eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId().toString())
-                        .build())
-                .build());
-        if (CollUtil.isEmpty(documents)) {
-            return ListUtil.empty();
-        }
-        // 2.2 段落召回
+        // 2. 检索
+        List<Document> documents = searchDocument(knowledge, reqBO);
+
+        // 3.1 段落召回
         List<AiKnowledgeSegmentDO> segments = segmentMapper
                 .selectListByVectorIds(convertList(documents, Document::getId));
         if (CollUtil.isEmpty(segments)) {
             return ListUtil.empty();
         }
-
-        // 3. 增加召回次数
+        // 3.2 增加召回次数
         segmentMapper.updateRetrievalCountIncrByIds(convertList(segments, AiKnowledgeSegmentDO::getId));
 
         // 4. 构建结果
@@ -249,6 +251,42 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
         return result;
     }
 
+    /**
+     * Retrieves documents from the knowledge base via embedding search, followed by an optional rerank pass.
+     *
+     * @param knowledge knowledge base to search; supplies the default topK and similarity threshold
+     * @param reqBO search request; its non-null topK / similarityThreshold take precedence
+     * @return documents found; the raw vector search result when no rerank model is injected or nothing matched
+     */
+    private List<Document> searchDocument(AiKnowledgeDO knowledge, AiKnowledgeSegmentSearchReqBO reqBO) {
+        VectorStore vectorStore = getVectorStoreById(knowledge);
+        Integer topK = ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK());
+        Double similarityThreshold = ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold());
+        boolean rerankEnabled = rerankModel != null;
+
+        // 1. Vector search. With rerank enabled, fetch RERANK_RETRIEVAL_FACTOR times more candidates and
+        //    defer the similarity threshold until after reranking.
+        SearchRequest searchRequest = SearchRequest.builder()
+                .query(reqBO.getContent())
+                .topK(rerankEnabled ? topK * RERANK_RETRIEVAL_FACTOR : topK)
+                .similarityThreshold(rerankEnabled ? SIMILARITY_THRESHOLD_ACCEPT_ALL : similarityThreshold)
+                .filterExpression(new FilterExpressionBuilder()
+                        .eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId().toString()).build())
+                .build();
+        List<Document> candidates = vectorStore.similaritySearch(searchRequest);
+        if (!rerankEnabled || CollUtil.isEmpty(candidates)) {
+            return candidates;
+        }
+
+        // 2. Rerank, asking for at most topK results
+        RerankResponse rerankResponse = rerankModel.call(new RerankRequest(reqBO.getContent(), candidates,
+                DashScopeRerankOptions.builder().withTopN(topK).build()));
+        // Results scoring below the threshold are mapped to null
+        // NOTE(review): presumably convertList drops null elements - confirm against CollectionUtils
+        return convertList(rerankResponse.getResults(),
+                scored -> scored.getScore() >= similarityThreshold ? scored.getOutput() : null);
+    }
+
     @Override
     public List<AiKnowledgeSegmentDO> splitContent(String url, Integer segmentMaxTokens) {
         // 1. 读取 URL 内容

+ 37 - 0
yudao-module-ai/src/test/java/cn/iocoder/yudao/module/ai/framework/ai/core/model/chat/TongYiChatModelTests.java

@@ -1,8 +1,15 @@
 package cn.iocoder.yudao.module.ai.framework.ai.core.model.chat;
 
+import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
 import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
 import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
 import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
+import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankModel;
+import com.alibaba.cloud.ai.dashscope.rerank.DashScopeRerankOptions;
+import com.alibaba.cloud.ai.model.RerankModel;
+import com.alibaba.cloud.ai.model.RerankOptions;
+import com.alibaba.cloud.ai.model.RerankRequest;
+import com.alibaba.cloud.ai.model.RerankResponse;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.springframework.ai.chat.messages.Message;
@@ -10,11 +17,14 @@ import org.springframework.ai.chat.messages.SystemMessage;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.document.Document;
 import reactor.core.publisher.Flux;
 
 import java.util.ArrayList;
 import java.util.List;
 
+import static java.util.Arrays.asList;
+
 /**
  * {@link DashScopeChatModel} 集成测试类
  *
@@ -89,4 +99,31 @@ public class TongYiChatModelTests {
         }).then().block();
     }
 
+    /**
+     * Manual check of the DashScope rerank model; the API key is read from the DASHSCOPE_API_KEY environment variable.
+     */
+    @Test
+    @Disabled
+    public void testRerank() {
+        // Set up the model; a real API key must never be committed, so it is read from the environment
+        RerankModel rerankModel = new DashScopeRerankModel(
+                DashScopeApi.builder()
+                        .apiKey(System.getenv("DASHSCOPE_API_KEY"))
+                        .build());
+        // Prepare the parameters
+        String query = "spring";
+        Document document01 = new Document("abc");
+        Document document02 = new Document("sapring");
+        RerankOptions options = DashScopeRerankOptions.builder()
+                .withTopN(1)
+                .withModel("gte-rerank-v2")
+                .build();
+        RerankRequest rerankRequest = new RerankRequest(query, asList(document01, document02), options);
+
+        // Invoke the model
+        RerankResponse call = rerankModel.call(rerankRequest);
+        // Print the result
+        System.out.println(JsonUtils.toJsonPrettyString(call));
+    }
+
 }

+ 3 - 1
yudao-server/src/main/resources/application.yaml

@@ -184,7 +184,7 @@ spring:
     stabilityai:
       api-key: sk-e53UqbboF8QJCscYvzJscJxJXoFcFg4iJjl1oqgE7baJETmx
     dashscope: # 通义千问
-      api-key: sk-71800982914041848008480000000000
+      api-key: sk-47aa124781be4bfb95244cc62f6xxxx
     minimax: # Minimax:https://www.minimaxi.com/
       api-key: xxxx
     moonshot: # 月之暗面(KIMI)
@@ -194,6 +194,8 @@ spring:
       chat:
         options:
           model: deepseek-chat
+    model:
+      rerank: dashscope # 是否开启“通义千问”的 Rerank 模型,填写 dashscope 开启
 
 yudao:
   ai: