企业知识更新快?JAVA RAG 实现动态调权召回策略确保答案时效性

JAVA RAG 实现动态调权召回策略确保答案时效性

各位朋友,大家好!今天我们来聊聊如何利用 JAVA 和 RAG (Retrieval-Augmented Generation) 技术,实现一个能够动态调整召回策略,以保证答案时效性的知识库系统。在企业知识更新速度飞快的今天,保证知识库的时效性至关重要。传统的知识库往往难以应对快速变化的信息,导致用户获取的答案过时甚至错误。RAG 架构通过在生成答案前,先从外部知识库中检索相关信息,可以有效提升答案的准确性和时效性。而通过动态调权召回策略,我们可以进一步优化 RAG 系统的性能,使其能够更好地适应不断变化的知识环境。

一、RAG 架构回顾与时效性挑战

RAG 架构的核心思想是将检索 (Retrieval) 和生成 (Generation) 两个阶段结合起来。

  1. 检索阶段 (Retrieval): 根据用户的问题,从外部知识库中检索出相关的文档或知识片段。这部分通常涉及向量数据库、相似度计算等技术。
  2. 生成阶段 (Generation): 利用检索到的知识片段,结合用户的问题,生成最终的答案。这部分通常使用预训练的语言模型 (LLM)。

RAG 架构在提升答案准确性方面表现出色,但同时也面临时效性挑战。如果知识库中的信息没有及时更新,或者检索策略无法区分新旧信息,那么 RAG 系统仍然可能生成过时的答案。

二、动态调权召回策略的设计思路

为了解决时效性问题,我们需要设计一种动态调权召回策略,其核心思路是:

  • 引入时间衰减因子: 对知识库中的文档或知识片段,根据其发布或更新时间,赋予不同的权重。越新的信息,权重越高。
  • 结合多种检索策略: 除了传统的基于语义相似度的检索外,还可以引入基于时间信息的检索,或者基于其他元数据的检索。
  • 动态调整权重组合: 根据用户的问题和知识库的实际情况,动态调整不同检索策略的权重,以获得最佳的召回效果。

三、JAVA RAG 框架搭建

我们使用 Spring Boot 框架搭建一个 JAVA RAG 系统,并集成向量数据库和 LLM。

  1. 项目初始化: 创建一个 Spring Boot 项目,并添加必要的依赖,例如:

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>io.milvus</groupId>
            <artifactId>milvus-sdk-java</artifactId>
            <version>2.3.0</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

    这里我们使用 Milvus 作为向量数据库,并使用 Jackson 进行 JSON 序列化/反序列化。

  2. 向量数据库配置: 配置 Milvus 连接信息。

    @Configuration
    public class MilvusConfig {
    
        @Value("${milvus.host}")
        private String host;
    
        @Value("${milvus.port}")
        private int port;
    
        @Bean
        public MilvusClient milvusClient() {
            ConnectParam connectParam = new ConnectParam.Builder()
                    .withHost(host)
                    .withPort(port)
                    .build();
            return new MilvusClient(connectParam);
        }
    }
  3. 知识库数据模型: 定义知识库文档的数据模型。

    @Data
    public class Document {
        private String id;
        private String content;
        private Date publishDate;
        private float[] embedding; // 文档向量表示
        private String source; //文档来源,例如"公司公告","内部文档"等
    }
  4. 向量化服务: 实现一个向量化服务,用于将文档内容转换为向量表示。 这里使用一个简单的示例,实际应用中可以使用更强大的 Embedding 模型。

    @Service
    public class EmbeddingService {
    
        public float[] embed(String text) {
            // TODO: 使用 Embedding 模型将文本转换为向量
            // 这里只是一个示例,实际需要调用 LLM 的 API
            float[] embedding = new float[128];
            Random random = new Random();
            for (int i = 0; i < embedding.length; i++) {
                embedding[i] = random.nextFloat();
            }
            return embedding;
        }
    }
  5. 知识库管理服务: 实现一个知识库管理服务,用于向 Milvus 插入、更新和删除文档。

    @Service
    public class KnowledgeBaseService {
    
        @Autowired
        private MilvusClient milvusClient;
    
        @Autowired
        private EmbeddingService embeddingService;
    
        @Value("${milvus.collectionName}")
        private String collectionName;
    
        public void createCollection() {
            FieldType id = FieldType.newBuilder()
                    .withName("id")
                    .withDataType(DataType.VARCHAR)
                    .withMaxLength(256)
                    .withPrimaryKey(true)
                    .build();
    
            FieldType content = FieldType.newBuilder()
                    .withName("content")
                    .withDataType(DataType.VARCHAR)
                    .withMaxLength(65535)
                    .build();
    
            FieldType publishDate = FieldType.newBuilder()
                    .withName("publish_date")
                    .withDataType(DataType.INT64)
                    .build(); //Store publish date as epoch milliseconds
    
            FieldType embedding = FieldType.newBuilder()
                    .withName("embedding")
                    .withDataType(DataType.FLOAT_VECTOR)
                    .withDimension(128)
                    .build();
    
            CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
                    .withCollectionName(collectionName)
                    .withFieldTypes(Arrays.asList(id, content, publishDate, embedding))
                    .build();
    
            milvusClient.createCollection(createCollectionReq);
            Index index = Index.newBuilder()
                    .withFieldName("embedding")
                    .withIndexType(IndexType.IVF_FLAT)
                    .withMetricType(MetricType.L2)
                    .withParam(new IndexParam("nlist", 128))
                    .build();
            milvusClient.createIndex(CreateIndexParam.newBuilder()
                    .withCollectionName(collectionName)
                    .withIndex(index)
                    .withSyncMode(Boolean.TRUE)
                    .build());
        }
    
        public void insertDocument(Document document) {
            List<String> ids = new ArrayList<>();
            List<String> contents = new ArrayList<>();
            List<Long> publishDates = new ArrayList<>();
            List<float[]> embeddings = new ArrayList<>();
    
            ids.add(document.getId());
            contents.add(document.getContent());
            publishDates.add(document.getPublishDate().getTime());
            embeddings.add(document.getEmbedding());
    
            List<InsertParam.Field> fields = Arrays.asList(
                    new InsertParam.Field("id", ids),
                    new InsertParam.Field("content", contents),
                    new InsertParam.Field("publish_date", publishDates),
                    new InsertParam.Field("embedding", embeddings)
            );
    
            InsertParam insertParam = InsertParam.newBuilder()
                    .withCollectionName(collectionName)
                    .withFields(fields)
                    .build();
    
            milvusClient.insert(insertParam);
            milvusClient.flush(FlushParam.newBuilder().withCollectionNames(Collections.singletonList(collectionName)).build());
        }
    
        public void updateDocument(Document document) {
            // TODO: 实现文档更新逻辑
        }
    
        public void deleteDocument(String id) {
            // TODO: 实现文档删除逻辑
        }
    }
  6. 召回服务: 实现召回服务,根据用户问题从知识库中检索相关文档。

    @Service
    public class RetrievalService {
    
        @Autowired
        private MilvusClient milvusClient;
    
        @Autowired
        private EmbeddingService embeddingService;
    
        @Value("${milvus.collectionName}")
        private String collectionName;
    
        @Value("${retrieval.topK}")
        private int topK;
    
        public List<Document> retrieve(String query, Map<String, Double> weights) {
            // 1. 将用户问题向量化
            float[] queryEmbedding = embeddingService.embed(query);
    
            // 2. 构建搜索参数
            SearchParam searchParam = SearchParam.newBuilder()
                    .withCollectionName(collectionName)
                    .withVectors(Collections.singletonList(queryEmbedding))
                    .withTopK(topK)
                    .withMetricType(MetricType.L2)
                    .withParams(new SearchParam.KeyValuePair("ef", "64")) // Adjust ef parameter as needed
                    .build();
    
            // 3. 执行搜索
            List<SearchResult> searchResults = milvusClient.search(searchParam);
    
            // 4. 解析搜索结果
            List<Document> results = new ArrayList<>();
            for (SearchResult searchResult : searchResults) {
                for (int i = 0; i < searchResult.getResults().size(); i++) {
                    SearchResult.Result result = searchResult.getResults().get(i);
                    String id = (String) result.get("id");
                    String content = (String) result.get("content");
                    Long publishDateMillis = (Long) result.get("publish_date");
                    Date publishDate = new Date(publishDateMillis);
    
                    Document document = new Document();
                    document.setId(id);
                    document.setContent(content);
                    document.setPublishDate(publishDate);
                    document.setEmbedding(queryEmbedding); //Just return query embedding for now;
                    // TODO: 考虑在这里添加评分计算,例如结合时间衰减因子
    
                    results.add(document);
                }
            }
    
            return results;
        }
    }
  7. 生成服务: 对接 LLM,根据检索到的文档生成答案。 这部分需要调用 LLM 的 API,例如 OpenAI 或其他开源 LLM。

    @Service
    public class GenerationService {
    
        public String generate(String query, List<Document> context) {
            // TODO: 调用 LLM API,结合 query 和 context 生成答案
            // 这里只是一个示例,实际需要调用 LLM 的 API
            StringBuilder sb = new StringBuilder();
            sb.append("根据以下信息,回答问题:").append(query).append("n");
            for (Document doc : context) {
                sb.append(doc.getContent()).append("n");
            }
            return "这是一个根据检索结果生成的示例答案。";
        }
    }
  8. API 接口: 创建一个 API 接口,接收用户问题,调用召回服务和生成服务,返回答案。

    @RestController
    public class KnowledgeController {
    
        @Autowired
        private RetrievalService retrievalService;
    
        @Autowired
        private GenerationService generationService;
    
        @PostMapping("/answer")
        public String answer(@RequestBody QuestionRequest request) {
            String query = request.getQuestion();
            Map<String, Double> weights = request.getWeights(); // 接收权重参数
    
            List<Document> documents = retrievalService.retrieve(query, weights);
            String answer = generationService.generate(query, documents);
    
            return answer;
        }
    }
    
    @Data
    static class QuestionRequest {
        private String question;
        private Map<String, Double> weights; // 允许用户传入权重参数
    }

四、动态调权召回策略的具体实现

  1. 时间衰减因子: 在召回服务中,引入时间衰减因子,对文档的评分进行调整。

    // 在 RetrievalService 的 retrieve 方法中添加
    public List<Document> retrieve(String query, Map<String, Double> weights) {
        // ... 之前的代码
    
        // 5. 计算评分,并应用时间衰减因子
        for (SearchResult searchResult : searchResults) {
            for (int i = 0; i < searchResult.getResults().size(); i++) {
                 SearchResult.Result result = searchResult.getResults().get(i);
                String id = (String) result.get("id");
                String content = (String) result.get("content");
                Long publishDateMillis = (Long) result.get("publish_date");
                Date publishDate = new Date(publishDateMillis);
                double similarityScore = searchResult.getResults().get(i).getDistance();
    
                Document document = new Document();
                document.setId(id);
                document.setContent(content);
                document.setPublishDate(publishDate);
                document.setEmbedding(queryEmbedding); //Just return query embedding for now;
    
                // 计算时间衰减因子
                double timeDecayFactor = calculateTimeDecayFactor(publishDate);
    
                // 计算最终得分
                double finalScore = similarityScore * timeDecayFactor; // 简化版
    
                // TODO: 添加到文档对象中,用于后续排序或筛选
                document.setScore(finalScore);
                results.add(document);
            }
        }
    
        // 根据得分排序
        results.sort(Comparator.comparingDouble(Document::getScore).reversed());
    
        return results;
    }
    
    private double calculateTimeDecayFactor(Date publishDate) {
        // 定义时间衰减参数
        long now = System.currentTimeMillis();
        long publishTime = publishDate.getTime();
        long timeDiff = now - publishTime;
    
        // 可以使用不同的时间衰减函数,例如指数衰减
        double halfLife = 30 * 24 * 60 * 60 * 1000.0; // 30 天半衰期
        return Math.exp(-timeDiff / halfLife);
    }

    这里 calculateTimeDecayFactor 方法根据文档的发布时间计算时间衰减因子。半衰期 (halfLife) 可以根据实际情况调整。

  2. 结合多种检索策略: 除了基于语义相似度的检索外,还可以引入基于时间信息的检索。例如,可以先检索最近一段时间内的文档,然后再进行语义相似度检索。

    public List<Document> retrieve(String query, Map<String, Double> weights) {
        // ... 之前的代码
    
        // 1. 基于时间范围的检索 (可选)
        List<Document> recentDocuments = retrieveRecentDocuments(query, 7); // 检索最近 7 天的文档
    
        // 2. 基于语义相似度的检索
        List<Document> semanticDocuments = retrieveSemanticDocuments(query);
    
        // 3. 合并结果,并应用权重
        List<Document> combinedResults = new ArrayList<>();
        combinedResults.addAll(recentDocuments);
        combinedResults.addAll(semanticDocuments);
    
        // 4. 应用时间衰减因子和权重
        for (Document document : combinedResults) {
            double timeDecayFactor = calculateTimeDecayFactor(document.getPublishDate());
            double semanticScore = document.getScore(); // 假设语义相似度得分已经计算好
            double finalScore = weights.get("time") * timeDecayFactor + weights.get("semantic") * semanticScore; // 应用权重
            document.setScore(finalScore);
        }
    
        // 5. 排序
        combinedResults.sort(Comparator.comparingDouble(Document::getScore).reversed());
    
        return combinedResults.subList(0, Math.min(topK, combinedResults.size())); // 返回 topK 个结果
    }
    
    private List<Document> retrieveRecentDocuments(String query, int days) {
        // TODO: 实现基于时间范围的检索
        // 可以使用 Milvus 的 range query 功能
        //  构建 range query 的条件,例如 publish_date > now - days * 24 * 60 * 60 * 1000
        //  然后执行搜索
        return new ArrayList<>(); // 示例:返回空列表
    }
    
    private List<Document> retrieveSemanticDocuments(String query) {
        // TODO: 实现基于语义相似度的检索
        // 使用 Milvus 的向量搜索功能
        return new ArrayList<>(); // 示例:返回空列表
    }

    这里,我们先使用 retrieveRecentDocuments 检索最近一段时间内的文档,然后使用 retrieveSemanticDocuments 进行语义相似度检索。最后,将两部分结果合并,并根据权重计算最终得分。

  3. 动态调整权重组合: 可以根据用户的问题和知识库的实际情况,动态调整不同检索策略的权重。例如,如果用户的问题涉及到最新的政策法规,那么可以增加时间信息的权重。

    • 规则引擎: 可以使用规则引擎,例如 Drools,根据用户问题和知识库的元数据,动态调整权重。
    • 机器学习模型: 可以使用机器学习模型,例如 LightGBM,根据用户问题和历史数据,预测最佳的权重组合。
    // 在 RetrievalService 的 retrieve 方法中
    public List<Document> retrieve(String query, Map<String, Double> weights) {
        // 1. 动态调整权重 (示例)
        Map<String, Double> adjustedWeights = adjustWeights(query, weights);
    
        // 2. 使用调整后的权重进行检索
        // ...
    }
    
    private Map<String, Double> adjustWeights(String query, Map<String, Double> weights) {
        // TODO: 根据 query 和知识库元数据,动态调整权重
        // 这里只是一个示例
        if (query.contains("最新")) {
            weights.put("time", weights.get("time") * 1.5); // 增加时间权重
        }
        return weights;
    }

五、评估与优化

  • 评估指标: 使用 MRR (Mean Reciprocal Rank)、Recall 等指标评估 RAG 系统的性能。
  • A/B 测试: 使用 A/B 测试比较不同召回策略的效果。
  • 持续优化: 根据评估结果和用户反馈,持续优化 RAG 系统的各个环节,包括知识库更新、Embedding 模型、召回策略和生成模型。

六、代码示例:Document类添加Score字段

Document 类中,我们需要添加一个score字段,用于存储文档的最终得分:

@Data
public class Document {
    private String id;
    private String content;
    private Date publishDate;
    private float[] embedding; // 文档向量表示
    private String source; //文档来源,例如"公司公告","内部文档"等
    private double score; //文档得分
}

七、代码示例:添加文档得分,修改排序方式

RetrievalService 类中,添加文档得分,并根据得分进行排序。

    @Service
    public class RetrievalService {

        @Autowired
        private MilvusClient milvusClient;

        @Autowired
        private EmbeddingService embeddingService;

        @Value("${milvus.collectionName}")
        private String collectionName;

        @Value("${retrieval.topK}")
        private int topK;

        public List<Document> retrieve(String query, Map<String, Double> weights) {
            // 1. 将用户问题向量化
            float[] queryEmbedding = embeddingService.embed(query);

            // 2. 构建搜索参数
            SearchParam searchParam = SearchParam.newBuilder()
                    .withCollectionName(collectionName)
                    .withVectors(Collections.singletonList(queryEmbedding))
                    .withTopK(topK)
                    .withMetricType(MetricType.L2)
                    .withParams(new SearchParam.KeyValuePair("ef", "64")) // Adjust ef parameter as needed
                    .build();

            // 3. 执行搜索
            List<SearchResult> searchResults = milvusClient.search(searchParam);

            // 4. 解析搜索结果
            List<Document> results = new ArrayList<>();
            for (SearchResult searchResult : searchResults) {
                for (int i = 0; i < searchResult.getResults().size(); i++) {
                    SearchResult.Result result = searchResult.getResults().get(i);
                    String id = (String) result.get("id");
                    String content = (String) result.get("content");
                    Long publishDateMillis = (Long) result.get("publish_date");
                    Date publishDate = new Date(publishDateMillis);
                    double similarityScore = result.getDistance();  // 获取相似度得分

                    Document document = new Document();
                    document.setId(id);
                    document.setContent(content);
                    document.setPublishDate(publishDate);
                    document.setEmbedding(queryEmbedding); //Just return query embedding for now;

                    // 计算时间衰减因子
                    double timeDecayFactor = calculateTimeDecayFactor(publishDate);

                    // 计算最终得分
                    double finalScore = similarityScore * timeDecayFactor; // 简化版

                    // 添加到文档对象中,用于后续排序或筛选
                    document.setScore(finalScore);
                    results.add(document);
                }
            }

            // 根据得分排序
            results.sort(Comparator.comparingDouble(Document::getScore).reversed());

            return results;
        }

        private double calculateTimeDecayFactor(Date publishDate) {
            // 定义时间衰减参数
            long now = System.currentTimeMillis();
            long publishTime = publishDate.getTime();
            long timeDiff = now - publishTime;

            // 可以使用不同的时间衰减函数,例如指数衰减
            double halfLife = 30 * 24 * 60 * 60 * 1000.0; // 30 天半衰期
            return Math.exp(-timeDiff / halfLife);
        }
    }

八、配置参数示例

application.propertiesapplication.yml 中配置相关参数。

milvus.host=localhost
milvus.port=19530
milvus.collectionName=knowledge_base

retrieval.topK=10

九、表格:不同检索策略的权重示例

检索策略 权重 说明
语义相似度 0.7 基于语义相似度的检索
时间信息 0.3 基于文档发布时间的检索
来源信息 0.1 基于文档来源的检索

十、表格:时间衰减函数示例

函数类型 公式 说明
指数衰减 exp(-t / halfLife) t: 时间差,halfLife: 半衰期
线性衰减 max(0, 1 – t / maxAge) t: 时间差,maxAge: 最大有效期

十一、保证答案时效性的关键点

通过上述步骤,我们构建了一个基于 JAVA 和 RAG 架构的知识库系统,并实现了动态调权召回策略。为了保证答案的时效性,还需要注意以下几点:

  • 知识库及时更新: 建立完善的知识库更新机制,确保信息能够及时同步到知识库中。
  • 元数据管理: 完善知识库的元数据管理,例如文档发布时间、来源等,以便进行更精细的检索和调权。
  • 监控与反馈: 建立完善的监控和反馈机制,及时发现和解决 RAG 系统中存在的问题。

十二、系统架构及代码逻辑

整体架构如下:

  1. Controller Layer: 接收用户请求,并将请求转发到 Service Layer。
  2. Service Layer: 包含RetrievalServiceGenerationService,负责检索相关文档和生成答案。RetrievalService 通过 EmbeddingService 将用户问题转化为向量,然后从 Milvus 向量数据库中检索相关文档。
  3. Data Access Layer: 通过 Milvus SDK 与 Milvus 向量数据库交互,存储和检索文档向量。
  4. Embedding Model: 用于将文本转化为向量表示,可以选择预训练的 Embedding 模型,例如 OpenAI 的 Embedding API。

核心代码逻辑:

  1. RetrievalService 接收用户问题,并使用 EmbeddingService 将问题转化为向量。
  2. RetrievalService 构建 Milvus 搜索参数,包括向量、TopK、MetricType 等。
  3. RetrievalService 调用 Milvus SDK 执行搜索,并解析搜索结果。
  4. RetrievalService 根据文档的发布时间计算时间衰减因子,并结合相似度得分计算最终得分。
  5. RetrievalService 根据得分对文档进行排序,并返回 TopK 个文档。
  6. GenerationService 接收用户问题和检索到的文档,并调用 LLM 生成答案。
  7. Controller 将生成的答案返回给用户。

十三、多种策略的组合使用

介绍了多种检索策略,包括基于语义相似度的检索、基于时间信息的检索和基于其他元数据的检索。在实际应用中,可以将这些策略组合使用,以获得更好的召回效果。例如,可以先使用基于时间信息的检索过滤掉过时的文档,然后再使用基于语义相似度的检索找到与用户问题最相关的文档。

十四、灵活的权重调整方式

动态调权召回策略的核心在于能够根据用户的问题和知识库的实际情况,动态调整不同检索策略的权重。介绍了多种权重调整方式,包括规则引擎和机器学习模型。可以根据实际情况选择合适的权重调整方式,以获得最佳的召回效果。

十五、保证时效性需要关注的点

总结一下,保证知识库答案的时效性需要关注以下几个关键点:及时更新知识库、完善元数据管理、动态调整召回策略、监控与反馈。只有做好这些方面,才能确保用户能够获得准确、及时的答案。

希望今天的分享对大家有所帮助! 谢谢!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注