Преглед изворни кода

feat:【ai】增加智能文档切片策略,支持自动识别 Markdown QA 和语义化切分「代码优化」

YunaiV пре 6 месеци
родитељ
комит
ed136ff022

+ 1 - 1
pom.xml

@@ -23,7 +23,7 @@
 <!--        <module>yudao-module-mall</module>-->
 <!--        <module>yudao-module-crm</module>-->
 <!--        <module>yudao-module-erp</module>-->
-        <module>yudao-module-ai</module>
+<!--        <module>yudao-module-ai</module>-->
 <!--        <module>yudao-module-iot</module>-->
     </modules>
 

+ 0 - 12
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/enums/AiDocumentSplitStrategyEnum.java

@@ -50,16 +50,4 @@ public enum AiDocumentSplitStrategyEnum {
      */
     private final String name;
 
-    /**
-     * 根据代码获取枚举
-     */
-    public static AiDocumentSplitStrategyEnum fromCode(String code) {
-        for (AiDocumentSplitStrategyEnum strategy : values()) {
-            if (strategy.getCode().equals(code)) {
-                return strategy;
-            }
-        }
-        return AUTO; // 默认返回自动识别
-    }
-
 }

+ 11 - 19
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java

@@ -107,11 +107,8 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
             if (StrUtil.isEmpty(segment.getText())) {
                 return null;
             }
-            return new AiKnowledgeSegmentDO()
-                    .setKnowledgeId(documentDO.getKnowledgeId())
-                    .setDocumentId(documentId)
-                    .setContent(segment.getText())
-                    .setContentLength(segment.getText().length())
+            return new AiKnowledgeSegmentDO().setKnowledgeId(documentDO.getKnowledgeId()).setDocumentId(documentId)
+                    .setContent(segment.getText()).setContentLength(segment.getText().length())
                     .setVectorId(AiKnowledgeSegmentDO.VECTOR_ID_EMPTY)
                     .setTokens(tokenCountEstimator.estimate(segment.getText()))
                     .setStatus(CommonStatusEnum.ENABLE.getStatus());
@@ -302,13 +299,12 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
         // 1. 读取 URL 内容
         String content = knowledgeDocumentService.readUrl(url);
 
-        // 2. 自动检测文档类型并选择策略
+        // 2.1 自动检测文档类型并选择策略
         AiDocumentSplitStrategyEnum strategy = detectDocumentStrategy(content, url);
-
-        // 3. 文档切片
+        // 2.2 文档切片
         List<Document> documentSegments = splitContentByStrategy(content, segmentMaxTokens, strategy, url);
 
-        // 4. 转换为段落对象
+        // 3. 转换为段落对象
         return convertList(documentSegments, segment -> {
             if (StrUtil.isEmpty(segment.getText())) {
                 return null;
@@ -352,6 +348,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
      * @param url 文档 URL(用于自动检测文件类型)
      * @return 切片后的文档列表
      */
+    @SuppressWarnings("EnhancedSwitchMigration")
     private List<Document> splitContentByStrategy(String content, Integer segmentMaxTokens,
                                                   AiDocumentSplitStrategyEnum strategy, String url) {
         // 自动检测策略
@@ -359,7 +356,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
             strategy = detectDocumentStrategy(content, url);
             log.info("[splitContentByStrategy][自动检测到文档策略: {}]", strategy.getName());
         }
-
+        // 根据策略切分
         TextSplitter textSplitter;
         switch (strategy) {
             case MARKDOWN_QA:
@@ -376,7 +373,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
                 textSplitter = buildTokenTextSplitter(segmentMaxTokens);
                 break;
         }
-
+        // 执行切分
         return textSplitter.apply(Collections.singletonList(new Document(content)));
     }
 
@@ -391,17 +388,14 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
         if (StrUtil.isEmpty(content)) {
             return AiDocumentSplitStrategyEnum.TOKEN;
         }
-
         // 1. 检测 Markdown QA 格式
         if (isMarkdownQaFormat(content, url)) {
             return AiDocumentSplitStrategyEnum.MARKDOWN_QA;
         }
-
         // 2. 检测普通 Markdown 文档
         if (isMarkdownDocument(url)) {
             return AiDocumentSplitStrategyEnum.SEMANTIC;
         }
-
         // 3. 默认使用语义切分(比 Token 切分更智能)
         return AiDocumentSplitStrategyEnum.SEMANTIC;
     }
@@ -421,16 +415,14 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
                 .filter(line -> line.trim().startsWith("## "))
                 .count();
 
-        // 至少包含 2 个二级标题才认为是 QA 格式
+        // 要求一:至少包含 2 个二级标题才认为是 QA 格式
         if (h2Count < 2) {
             return false;
         }
 
-        // 检查标题占比(QA 文档标题行数相对较多)
+        // 要求二:检查标题占比(QA 文档标题行数相对较多),如果二级标题占比超过 10%,认为是 QA 格式
         long totalLines = content.lines().count();
         double h2Ratio = (double) h2Count / totalLines;
-
-        // 如果二级标题占比超过 10%,认为是 QA 格式
         return h2Ratio > 0.1;
     }
 
@@ -438,7 +430,7 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
      * 检测是否为 Markdown 文档
      */
     private boolean isMarkdownDocument(String url) {
-        return StrUtil.isNotEmpty(url) && url.toLowerCase().endsWith(".md");
+        return StrUtil.endWithAnyIgnoreCase(url, ".md", ".markdown");
     }
 
     /**

+ 45 - 52
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/MarkdownQaSplitter.java

@@ -1,6 +1,8 @@
 package cn.iocoder.yudao.module.ai.service.knowledge.splitter;
 
+import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.StrUtil;
+import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.transformer.splitter.TextSplitter;
 
@@ -24,6 +26,7 @@ import java.util.regex.Pattern;
  * @author runzhen
  */
 @Slf4j
+@SuppressWarnings("SizeReplaceableByIsEmpty")
 public class MarkdownQaSplitter extends TextSplitter {
 
     /**
@@ -62,41 +65,38 @@ public class MarkdownQaSplitter extends TextSplitter {
             return Collections.emptyList();
         }
 
-        List<String> result = new ArrayList<>();
-
         // 解析 QA 对
         List<QaPair> qaPairs = parseQaPairs(text);
-
-        if (qaPairs.isEmpty()) {
+        if (CollUtil.isEmpty(qaPairs)) {
             // 如果没有识别到 QA 格式,按段落切分
             return fallbackSplit(text);
         }
 
         // 处理每个 QA 对
+        List<String> result = new ArrayList<>();
         for (QaPair qaPair : qaPairs) {
             result.addAll(splitQaPair(qaPair));
         }
-
         return result;
     }
 
     /**
      * 解析 Markdown QA 对
+     *
+     * @param content 文本内容
+     * @return QA 对列表
      */
     private List<QaPair> parseQaPairs(String content) {
+        // 找到所有二级标题位置
         List<QaPair> qaPairs = new ArrayList<>();
-        Matcher matcher = H2_PATTERN.matcher(content);
-
         List<Integer> headingPositions = new ArrayList<>();
         List<String> questions = new ArrayList<>();
-
-        // 找到所有二级标题位置
+        Matcher matcher = H2_PATTERN.matcher(content);
         while (matcher.find()) {
             headingPositions.add(matcher.start());
             questions.add(matcher.group(1).trim());
         }
-
-        if (headingPositions.isEmpty()) {
+        if (CollUtil.isEmpty(headingPositions)) {
             return qaPairs;
         }
 
@@ -106,55 +106,51 @@ public class MarkdownQaSplitter extends TextSplitter {
             int end = (i + 1 < headingPositions.size())
                     ? headingPositions.get(i + 1)
                     : content.length();
-
             String qaText = content.substring(start, end).trim();
             String question = questions.get(i);
-
             // 提取答案部分(去掉问题标题)
             String answer = qaText.substring(qaText.indexOf('\n') + 1).trim();
-
             qaPairs.add(new QaPair(question, answer, qaText));
         }
-
         return qaPairs;
     }
 
     /**
      * 切分单个 QA 对
+     *
+     * @param qaPair QA 对
+     * @return 切分后的文本片段列表
      */
     private List<String> splitQaPair(QaPair qaPair) {
+        // 如果整个 QA 对不超过限制,保持完整
         List<String> chunks = new ArrayList<>();
-
         String fullQa = qaPair.fullText;
         int qaTokens = tokenEstimator.estimate(fullQa);
-
-        // 如果整个 QA 对不超过限制,保持完整
         if (qaTokens <= chunkSize) {
             chunks.add(fullQa);
             return chunks;
         }
 
         // 长答案需要切分
-        log.debug("QA 对超过 Token 限制 ({} > {}),开始智能切分: {}",
-                qaTokens, chunkSize, qaPair.question);
-
+        log.debug("QA 对超过 Token 限制 ({} > {}),开始智能切分: {}", qaTokens, chunkSize, qaPair.question);
         List<String> answerChunks = splitLongAnswer(qaPair.answer, qaPair.question);
-
         for (String answerChunk : answerChunks) {
             // 每个片段都包含完整问题
             String chunkText = "## " + qaPair.question + "\n" + answerChunk;
             chunks.add(chunkText);
         }
-
         return chunks;
     }
 
     /**
      * 切分长答案
+     *
+     * @param answer 答案文本
+     * @param question 问题文本
+     * @return 切分后的答案片段列表
      */
     private List<String> splitLongAnswer(String answer, String question) {
         List<String> chunks = new ArrayList<>();
-
         // 预留问题的 Token 空间
         String questionHeader = "## " + question + "\n";
         int questionTokens = tokenEstimator.estimate(questionHeader);
@@ -162,17 +158,13 @@ public class MarkdownQaSplitter extends TextSplitter {
 
         // 先按段落切分
         String[] paragraphs = answer.split(PARAGRAPH_SEPARATOR);
-
         StringBuilder currentChunk = new StringBuilder();
         int currentTokens = 0;
-
         for (String paragraph : paragraphs) {
             if (StrUtil.isEmpty(paragraph)) {
                 continue;
             }
-
             int paragraphTokens = tokenEstimator.estimate(paragraph);
-
             // 如果单个段落就超过限制,需要按句子切分
             if (paragraphTokens > availableTokens) {
                 // 先保存当前块
@@ -181,22 +173,20 @@ public class MarkdownQaSplitter extends TextSplitter {
                     currentChunk = new StringBuilder();
                     currentTokens = 0;
                 }
-
                 // 按句子切分长段落
                 chunks.addAll(splitLongParagraph(paragraph, availableTokens));
                 continue;
             }
-
             // 如果加上这个段落会超过限制
             if (currentTokens + paragraphTokens > availableTokens && currentChunk.length() > 0) {
                 chunks.add(currentChunk.toString().trim());
                 currentChunk = new StringBuilder();
                 currentTokens = 0;
             }
-
             if (currentChunk.length() > 0) {
                 currentChunk.append("\n\n");
             }
+            // 添加段落
             currentChunk.append(paragraph);
             currentTokens += paragraphTokens;
         }
@@ -205,27 +195,29 @@ public class MarkdownQaSplitter extends TextSplitter {
         if (currentChunk.length() > 0) {
             chunks.add(currentChunk.toString().trim());
         }
-
-        return chunks.isEmpty() ? Collections.singletonList(answer) : chunks;
+        return CollUtil.isEmpty(chunks) ? Collections.singletonList(answer) : chunks;
     }
 
     /**
      * 切分长段落(按句子)
+     *
+     * @param paragraph 段落文本
+     * @param availableTokens 可用的 Token 数
+     * @return 切分后的文本片段列表
      */
     private List<String> splitLongParagraph(String paragraph, int availableTokens) {
+        // 按句子切分
         List<String> chunks = new ArrayList<>();
         String[] sentences = SENTENCE_PATTERN.split(paragraph);
 
+        // 按句子累积切分
         StringBuilder currentChunk = new StringBuilder();
         int currentTokens = 0;
-
         for (String sentence : sentences) {
             if (StrUtil.isEmpty(sentence)) {
                 continue;
             }
-
             int sentenceTokens = tokenEstimator.estimate(sentence);
-
             // 如果单个句子就超过限制,强制切分
             if (sentenceTokens > availableTokens) {
                 if (currentChunk.length() > 0) {
@@ -236,47 +228,50 @@ public class MarkdownQaSplitter extends TextSplitter {
                 chunks.add(sentence.trim());
                 continue;
             }
-
+            // 如果加上这个句子会超过限制
             if (currentTokens + sentenceTokens > availableTokens && currentChunk.length() > 0) {
                 chunks.add(currentChunk.toString().trim());
                 currentChunk = new StringBuilder();
                 currentTokens = 0;
             }
-
+            // 添加句子
             currentChunk.append(sentence);
             currentTokens += sentenceTokens;
         }
 
+        // 添加最后一块
         if (currentChunk.length() > 0) {
             chunks.add(currentChunk.toString().trim());
         }
-
         return chunks.isEmpty() ? Collections.singletonList(paragraph) : chunks;
     }
 
     /**
      * 降级切分策略(当未识别到 QA 格式时)
+     *
+     * @param content 文本内容
+     * @return 切分后的文本片段列表
      */
     private List<String> fallbackSplit(String content) {
+        // 按段落切分
         List<String> chunks = new ArrayList<>();
         String[] paragraphs = content.split(PARAGRAPH_SEPARATOR);
 
+        // 按段落累积切分
         StringBuilder currentChunk = new StringBuilder();
         int currentTokens = 0;
-
         for (String paragraph : paragraphs) {
             if (StrUtil.isEmpty(paragraph)) {
                 continue;
             }
-
             int paragraphTokens = tokenEstimator.estimate(paragraph);
-
+            // 如果加上这个段落会超过限制
             if (currentTokens + paragraphTokens > chunkSize && currentChunk.length() > 0) {
                 chunks.add(currentChunk.toString().trim());
                 currentChunk = new StringBuilder();
                 currentTokens = 0;
             }
-
+            // 添加段落
             if (currentChunk.length() > 0) {
                 currentChunk.append("\n\n");
             }
@@ -284,33 +279,32 @@ public class MarkdownQaSplitter extends TextSplitter {
             currentTokens += paragraphTokens;
         }
 
+        // 添加最后一块
         if (currentChunk.length() > 0) {
             chunks.add(currentChunk.toString().trim());
         }
-
         return chunks.isEmpty() ? Collections.singletonList(content) : chunks;
     }
 
     /**
      * QA 对数据结构
      */
+    @AllArgsConstructor
     private static class QaPair {
+
         String question;
         String answer;
         String fullText;
 
-        QaPair(String question, String answer, String fullText) {
-            this.question = question;
-            this.answer = answer;
-            this.fullText = fullText;
-        }
     }
 
     /**
      * Token 估算器接口
      */
     public interface TokenEstimator {
+
         int estimate(String text);
+
     }
 
     /**
@@ -319,6 +313,7 @@ public class MarkdownQaSplitter extends TextSplitter {
      * 英文:1 单词 ≈ 1.3 Token
      */
     private static class SimpleTokenEstimator implements TokenEstimator {
+
         @Override
         public int estimate(String text) {
             if (StrUtil.isEmpty(text)) {
@@ -327,14 +322,12 @@ public class MarkdownQaSplitter extends TextSplitter {
 
             int chineseChars = 0;
             int englishWords = 0;
-
             // 简单统计中英文
             for (char c : text.toCharArray()) {
                 if (c >= 0x4E00 && c <= 0x9FA5) {
                     chineseChars++;
                 }
             }
-
             // 英文单词估算
             String[] words = text.split("\\s+");
             for (String word : words) {
@@ -342,8 +335,8 @@ public class MarkdownQaSplitter extends TextSplitter {
                     englishWords++;
                 }
             }
-
             return chineseChars + (int) (englishWords * 1.3);
         }
     }
+
 }

+ 31 - 23
yudao-module-ai/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/splitter/SemanticTextSplitter.java

@@ -8,6 +8,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
 /**
@@ -72,12 +73,14 @@ public class SemanticTextSplitter extends TextSplitter {
         if (StrUtil.isEmpty(text)) {
             return Collections.emptyList();
         }
-
         return splitTextRecursive(text);
     }
 
     /**
      * 切分文本(递归策略)
+     *
+     * @param text 待切分文本
+     * @return 切分后的文本块列表
      */
     private List<String> splitTextRecursive(String text) {
         List<String> chunks = new ArrayList<>();
@@ -92,7 +95,6 @@ public class SemanticTextSplitter extends TextSplitter {
         // 尝试按不同分隔符切分
         List<String> splits = null;
         String usedSeparator = null;
-
         for (String separator : PARAGRAPH_SEPARATORS) {
             if (text.contains(separator)) {
                 splits = Arrays.asList(text.split(Pattern.quote(separator)));
@@ -109,18 +111,20 @@ public class SemanticTextSplitter extends TextSplitter {
 
         // 合并小片段
         chunks = mergeSplits(splits, usedSeparator);
-
         return chunks;
     }
 
     /**
      * 按句子切分
+     *
+     * @param text 待切分文本
+     * @return 句子列表
      */
     private List<String> splitBySentences(String text) {
+        // 使用正则表达式匹配句子结束位置
         List<String> sentences = new ArrayList<>();
         int lastEnd = 0;
-
-        java.util.regex.Matcher matcher = SENTENCE_END_PATTERN.matcher(text);
+        Matcher matcher = SENTENCE_END_PATTERN.matcher(text);
         while (matcher.find()) {
             String sentence = text.substring(lastEnd, matcher.end()).trim();
             if (StrUtil.isNotEmpty(sentence)) {
@@ -136,12 +140,15 @@ public class SemanticTextSplitter extends TextSplitter {
                 sentences.add(remaining);
             }
         }
-
         return sentences.isEmpty() ? Collections.singletonList(text) : sentences;
     }
 
     /**
      * 合并切分后的小片段
+     *
+     * @param splits 切分后的片段列表
+     * @param separator 片段间的分隔符
+     * @return 合并后的文本块列表
      */
     private List<String> mergeSplits(List<String> splits, String separator) {
         List<String> chunks = new ArrayList<>();
@@ -152,9 +159,7 @@ public class SemanticTextSplitter extends TextSplitter {
             if (StrUtil.isEmpty(split)) {
                 continue;
             }
-
             int splitTokens = tokenEstimator.estimate(split);
-
             // 如果单个片段就超过限制,进一步递归切分
             if (splitTokens > chunkSize) {
                 // 先保存当前累积的块
@@ -164,7 +169,6 @@ public class SemanticTextSplitter extends TextSplitter {
                     currentChunks.clear();
                     currentLength = 0;
                 }
-
                 // 递归切分大片段
                 if (!separator.isEmpty()) {
                     // 如果是段落分隔符,尝试按句子切分
@@ -175,10 +179,8 @@ public class SemanticTextSplitter extends TextSplitter {
                 }
                 continue;
             }
-
             // 计算加上分隔符的 Token 数
             int separatorTokens = StrUtil.isEmpty(separator) ? 0 : tokenEstimator.estimate(separator);
-
             // 如果加上这个片段会超过限制
             if (!currentChunks.isEmpty() && currentLength + splitTokens + separatorTokens > chunkSize) {
                 // 保存当前块
@@ -189,7 +191,7 @@ public class SemanticTextSplitter extends TextSplitter {
                 currentChunks = getOverlappingChunks(currentChunks, separator);
                 currentLength = estimateTokens(currentChunks, separator);
             }
-
+            // 添加当前片段
             currentChunks.add(split);
             currentLength += splitTokens + separatorTokens;
         }
@@ -199,39 +201,43 @@ public class SemanticTextSplitter extends TextSplitter {
             String chunkText = String.join(separator, currentChunks);
             chunks.add(chunkText.trim());
         }
-
         return chunks;
     }
 
     /**
      * 获取重叠的片段(用于保持上下文)
+     *
+     * @param chunks 当前片段列表
+     * @param separator 片段间的分隔符
+     * @return 重叠的片段列表
      */
     private List<String> getOverlappingChunks(List<String> chunks, String separator) {
         if (chunkOverlap == 0 || chunks.isEmpty()) {
             return new ArrayList<>();
         }
 
+        // 从后往前取片段,直到达到重叠大小
         List<String> overlapping = new ArrayList<>();
         int tokens = 0;
-
-        // 从后往前取片段,直到达到重叠大小
         for (int i = chunks.size() - 1; i >= 0; i--) {
             String chunk = chunks.get(i);
             int chunkTokens = tokenEstimator.estimate(chunk);
-
             if (tokens + chunkTokens > chunkOverlap) {
                 break;
             }
-
+            // 添加到重叠列表前端
             overlapping.add(0, chunk);
             tokens += chunkTokens + (StrUtil.isEmpty(separator) ? 0 : tokenEstimator.estimate(separator));
         }
-
         return overlapping;
     }
 
     /**
      * 估算片段列表的总 Token 数
+     *
+     * @param chunks 片段列表
+     * @param separator 片段间的分隔符
+     * @return 总 Token 数
      */
     private int estimateTokens(List<String> chunks, String separator) {
         int total = 0;
@@ -246,17 +252,18 @@ public class SemanticTextSplitter extends TextSplitter {
 
     /**
      * 强制切分长文本(当语义切分失败时)
+     *
+     * @param text 待切分文本
+     * @return 切分后的文本块列表
      */
     private List<String> forceSplitLongText(String text) {
         List<String> chunks = new ArrayList<>();
         int charsPerChunk = (int) (chunkSize * 0.8); // 保守估计
-
         for (int i = 0; i < text.length(); i += charsPerChunk) {
             int end = Math.min(i + charsPerChunk, text.length());
             String chunk = text.substring(i, end);
             chunks.add(chunk.trim());
         }
-
         log.warn("文本过长,已强制按字符切分,可能影响语义完整性");
         return chunks;
     }
@@ -265,6 +272,7 @@ public class SemanticTextSplitter extends TextSplitter {
      * 简单的 Token 估算器实现
      */
     private static class SimpleTokenEstimator implements MarkdownQaSplitter.TokenEstimator {
+
         @Override
         public int estimate(String text) {
             if (StrUtil.isEmpty(text)) {
@@ -273,21 +281,21 @@ public class SemanticTextSplitter extends TextSplitter {
 
             int chineseChars = 0;
             int englishWords = 0;
-
+            // 简单统计中英文
             for (char c : text.toCharArray()) {
                 if (c >= 0x4E00 && c <= 0x9FA5) {
                     chineseChars++;
                 }
             }
-
+            // 英文单词估算
             String[] words = text.split("\\s+");
             for (String word : words) {
                 if (word.matches(".*[a-zA-Z].*")) {
                     englishWords++;
                 }
             }
-
             return chineseChars + (int) (englishWords * 1.3);
         }
     }
+
 }