基于JAVA构建RAG检索链路提升向量召回精确度的优化工程实战

基于JAVA构建RAG检索链路提升向量召回精确度的优化工程实战

大家好,今天我们来聊聊如何使用JAVA构建一个RAG (Retrieval-Augmented Generation) 检索链路,并且着重讨论如何优化向量召回的精确度。RAG 架构的核心在于先从海量数据中检索出相关的知识,再将这些知识融入到生成模型的输入中,从而提升生成内容的质量和准确性。 向量召回作为 RAG 的第一步,其精确度直接影响着整个系统的效果。

一、RAG架构概览与JAVA选型

RAG 的基本流程可以概括为:

  1. 索引构建 (Indexing): 将知识库文档进行预处理,例如分块、清洗、转换等,然后使用 Embedding 模型将文档块转换为向量表示,并存储到向量数据库中。
  2. 检索 (Retrieval): 接收用户查询,同样使用 Embedding 模型将查询转换为向量表示,然后在向量数据库中进行相似性搜索,找到与查询最相关的文档块。
  3. 生成 (Generation): 将检索到的文档块与用户查询一起作为输入,送入大型语言模型 (LLM),生成最终的答案或内容。

为什么选择 JAVA? JAVA 在企业级应用中拥有广泛的应用,生态成熟,拥有丰富的库和框架,便于构建稳定可靠的 RAG 系统。 特别是在高并发、低延迟的检索场景下,JAVA 提供的并发模型和性能优化工具可以发挥重要作用。

二、环境搭建与依赖引入

首先,我们需要搭建 JAVA 开发环境,并引入必要的依赖。 这里我们使用 Maven 作为项目管理工具。

<dependencies>
    <!--  向量数据库,例如: Milvus  -->
    <dependency>
        <groupId>io.milvus</groupId>
        <artifactId>milvus-sdk-java</artifactId>
        <version>2.3.0</version>  <!-- 请替换为最新版本 -->
    </dependency>

    <!--  Embedding 模型,例如: Sentence Transformers  -->
    <dependency>
        <groupId>co.elastic.clients</groupId>
        <artifactId>elasticsearch-java</artifactId>
        <version>8.11.3</version>
    </dependency>
    <dependency>
        <groupId>com.fasterxml.jackson.core</groupId>
        <artifactId>jackson-databind</artifactId>
        <version>2.16.1</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-api</artifactId>
        <version>2.0.9</version>
    </dependency>
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-simple</artifactId>
        <version>2.0.9</version>
    </dependency>

    <!--  JSON 处理  -->
    <dependency>
        <groupId>com.google.code.gson</groupId>
        <artifactId>gson</artifactId>
        <version>2.10.1</version>
    </dependency>
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
        <version>3.12.0</version>
    </dependency>

    <!-- 文档解析,例如: PDFBox -->
    <dependency>
        <groupId>org.apache.pdfbox</groupId>
        <artifactId>pdfbox</artifactId>
        <version>2.0.30</version>
    </dependency>

</dependencies>

这里选择了 Milvus 作为向量数据库,Elasticsearch 作为 Embedding 模型载体(也可以选择直接调用Hugging Face API或者其他本地部署的 Embedding 模型服务),Jackson和Gson作为JSON处理库, PDFBox 用于处理PDF文档。 请根据实际情况选择合适的库和版本。

三、索引构建:文档预处理与向量化

索引构建是 RAG 的基础,其质量直接影响后续的检索效果。

  1. 文档加载与分块:

    首先,我们需要将知识库文档加载到内存中。 对于不同类型的文档,需要使用不同的解析器。 例如,对于 PDF 文档,可以使用 PDFBox;对于文本文件,可以直接读取。

    import org.apache.pdfbox.pdmodel.PDDocument;
    import org.apache.pdfbox.text.PDFTextStripper;
    
    import java.io.File;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;
    
    public class DocumentLoader {
    
       public static String loadPdf(String filePath) throws IOException {
           PDDocument document = PDDocument.load(new File(filePath));
           PDFTextStripper stripper = new PDFTextStripper();
           String text = stripper.getText(document);
           document.close();
           return text;
       }
    
       // 其他文档类型的加载方法...
    
       public static List<String> chunkText(String text, int chunkSize, int overlapSize) {
           List<String> chunks = new ArrayList<>();
           String[] sentences = text.split("(?<=[.?!])\s+"); // 按句子分割
           StringBuilder currentChunk = new StringBuilder();
           int currentLength = 0;
    
           for (int i = 0; i < sentences.length; i++) {
               String sentence = sentences[i];
               int sentenceLength = sentence.length();
    
               if (currentLength + sentenceLength + 1 <= chunkSize) {
                   currentChunk.append(sentence).append(" ");
                   currentLength += sentenceLength + 1;
               } else {
                   if (currentChunk.length() > 0) {
                       chunks.add(currentChunk.toString().trim());
                   }
                   // 开始新的 chunk,考虑 overlap
                   currentChunk = new StringBuilder();
                   if (overlapSize > 0 && chunks.size() > 0) {
                       // 从上一个 chunk 提取 overlap 部分
                       String lastChunk = chunks.get(chunks.size() - 1);
                       int overlapStart = Math.max(0, lastChunk.length() - overlapSize);
                       currentChunk.append(lastChunk.substring(overlapStart)).append(" ");
                   }
                   currentChunk.append(sentence).append(" ");
                   currentLength = currentChunk.length();
               }
           }
    
           // 添加最后一个 chunk
           if (currentChunk.length() > 0) {
               chunks.add(currentChunk.toString().trim());
           }
           return chunks;
       }
    
       public static void main(String[] args) throws IOException {
           String filePath = "path/to/your/document.pdf";
           String text = loadPdf(filePath);
           List<String> chunks = chunkText(text, 256, 50); // 示例:chunkSize=256, overlapSize=50
           for (String chunk : chunks) {
               System.out.println(chunk);
           }
       }
    }

    加载文档后,我们需要将文档分割成更小的块 (chunks)。Chunk 的大小会影响检索的精确度和效率。 过大的 chunk 可能包含不相关的信息,降低精确度;过小的 chunk 可能丢失上下文信息。 一种常用的策略是按句子分割,并设置最大 chunk 大小和 overlap 大小。 chunkSize 定义了每个 chunk 的最大长度,overlapSize 定义了相邻 chunk 之间的重叠长度,用于保留上下文信息。

  2. Embedding 向量化:

    将文档块转换为向量表示是核心步骤。 我们需要选择合适的 Embedding 模型,例如 Sentence Transformers、OpenAI Embeddings 等。 这里以 Elasticsearch 为例,演示如何使用 Elasticsearch 的 _vectors API 进行向量化。

    import co.elastic.clients.elasticsearch.ElasticsearchClient;
    import co.elastic.clients.elasticsearch.core.IndexRequest;
    import co.elastic.clients.elasticsearch.core.search.Hit;
    import co.elastic.clients.elasticsearch.indices.CreateIndexRequest;
    import co.elastic.clients.json.JsonData;
    import co.elastic.clients.json.jackson.JacksonJsonpMapper;
    import co.elastic.clients.transport.ElasticsearchTransport;
    import co.elastic.clients.transport.rest_client.RestClientTransport;
    import com.fasterxml.jackson.databind.JsonNode;
    import com.fasterxml.jackson.databind.ObjectMapper;
    import org.apache.http.HttpHost;
    import org.elasticsearch.client.RestClient;
    
    import java.io.IOException;
    import java.io.StringReader;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;
    
    public class EmbeddingService {
    
       private final ElasticsearchClient client;
    
       public EmbeddingService(String host, int port) {
           RestClient restClient = RestClient.builder(new HttpHost(host, port)).build();
    
           ElasticsearchTransport transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
           this.client = new ElasticsearchClient(transport);
       }
    
       public void createIndex(String indexName, int dimension) throws IOException {
           client.indices().create(
                   new CreateIndexRequest.Builder()
                           .index(indexName)
                           .mappings(m -> m
                                   .properties("text", p -> p
                                           .text(t -> t))
                                   .properties("embedding", p -> p
                                           .denseVector(v -> v.dims(dimension)))
                           )
                           .build()
           );
       }
    
       public List<Float> generateEmbedding(String text) throws IOException {
           // 使用 Elasticsearch 的 _vectors API 获取 embedding
           Map<String, String> body = new HashMap<>();
           body.put("model_id", "sentence-transformers__all-minilm-l6-v3");
           body.put("model_text", text);
    
           JsonData jsonData = JsonData.fromJson(new ObjectMapper().writeValueAsString(body));
    
           JsonNode response = client.post(
                   "_ml/inference/sentence-transformers__all-minilm-l6-v3/_explain",
                   null,
                   jsonData,
                   JsonNode.class
           );
    
           List<Float> embedding = new ArrayList<>();
           response.get("top_classes").forEach(jsonNode -> {
               float value = jsonNode.get("value").floatValue();
               embedding.add(value);
           });
    
           return embedding;
       }
    
       public void indexDocument(String indexName, String id, String text, List<Float> embedding) throws IOException {
           Map<String, Object> document = new HashMap<>();
           document.put("text", text);
           document.put("embedding", embedding);
    
           IndexRequest<Map<String, Object>> request = new IndexRequest.Builder<Map<String, Object>>()
                   .index(indexName)
                   .id(id)
                   .document(document)
                   .build();
    
           client.index(request);
       }
    
       public static void main(String[] args) throws IOException {
           EmbeddingService embeddingService = new EmbeddingService("localhost", 9200);
           String indexName = "my_index";
           int dimension = 384; // 根据 embedding 模型确定维度
    
           // 创建 index
           embeddingService.createIndex(indexName, dimension);
    
           // 示例文档
           String text1 = "This is the first document.";
           String text2 = "This is the second document.";
    
           // 获取 embedding
           List<Float> embedding1 = embeddingService.generateEmbedding(text1);
           List<Float> embedding2 = embeddingService.generateEmbedding(text2);
    
           // 索引文档
           embeddingService.indexDocument(indexName, "1", text1, embedding1);
           embeddingService.indexDocument(indexName, "2", text2, embedding2);
    
           System.out.println("Documents indexed successfully.");
       }
    }

    这段代码演示了如何连接到 Elasticsearch,创建一个包含 textembedding 字段的索引,使用 _vectors API 生成 embedding,并将文档索引到 Elasticsearch 中。 dimension 需要根据所使用的 Embedding 模型来确定。 这里使用了 sentence-transformers__all-minilm-l6-v3 模型,其维度为 384。

  3. 向量数据库存储:

    将文档块的向量表示存储到向量数据库中,例如 Milvus。

    import io.milvus.client.MilvusClient;
    import io.milvus.client.MilvusServiceClient;
    import io.milvus.grpc.DataType;
    import io.milvus.param.ConnectParam;
    import io.milvus.param.IndexType;
    import io.milvus.param.MetricType;
    import io.milvus.param.collection.CreateCollectionParam;
    import io.milvus.param.collection.FieldType;
    import io.milvus.param.collection.InsertParam;
    import io.milvus.param.collection.LoadCollectionParam;
    import io.milvus.param.index.CreateIndexParam;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    public class VectorDatabaseService {
    
       private final MilvusClient milvusClient;
    
       public VectorDatabaseService(String host, int port) {
           ConnectParam connectParam = new ConnectParam.Builder()
                   .withHost(host)
                   .withPort(port)
                   .build();
           this.milvusClient = new MilvusServiceClient(connectParam);
       }
    
       public void createCollection(String collectionName, int dimension) {
           FieldType idField = FieldType.newBuilder()
                   .withName("id")
                   .withDataType(DataType.INT64)
                   .withPrimaryKey(true)
                   .withAutoID(false)
                   .build();
    
           FieldType vectorField = FieldType.newBuilder()
                   .withName("embedding")
                   .withDataType(DataType.FLOAT_VECTOR)
                   .withDimension(dimension)
                   .build();
    
           CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
                   .withCollectionName(collectionName)
                   .withFields(Arrays.asList(idField, vectorField))
                   .build();
    
           milvusClient.createCollection(createCollectionParam);
       }
    
       public void createIndex(String collectionName) {
           CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
                   .withCollectionName(collectionName)
                   .withFieldName("embedding")
                   .withIndexType(IndexType.IVF_FLAT)
                   .withMetricType(MetricType.L2)
                   .withParam(new io.milvus.param.IndexParams.Builder().withNlist(1024).build())
                   .withSyncMode(Boolean.FALSE)
                   .build();
    
           milvusClient.createIndex(createIndexParam);
       }
    
       public void loadCollection(String collectionName) {
           LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
                   .withCollectionName(collectionName)
                   .build();
    
           milvusClient.loadCollection(loadCollectionParam);
       }
    
       public void insertVectors(String collectionName, List<Long> ids, List<List<Float>> vectors) {
           List<String> fieldsName = Arrays.asList("id", "embedding");
           List<List<?>> rows = new ArrayList<>();
           rows.add(ids);
           rows.add(vectors);
    
           InsertParam insertParam = InsertParam.newBuilder()
                   .withCollectionName(collectionName)
                   .withFieldsName(fieldsName)
                   .withRows(rows)
                   .build();
    
           milvusClient.insert(insertParam);
           milvusClient.flush(collectionName);
       }
    
       public static void main(String[] args) {
           VectorDatabaseService vectorDatabaseService = new VectorDatabaseService("localhost", 19530);
           String collectionName = "my_collection";
           int dimension = 384; // 根据 embedding 模型确定维度
    
           // 创建 collection
           vectorDatabaseService.createCollection(collectionName, dimension);
    
           // 创建 index
           vectorDatabaseService.createIndex(collectionName);
    
           // 加载 collection
           vectorDatabaseService.loadCollection(collectionName);
    
           // 示例数据
           List<Long> ids = Arrays.asList(1L, 2L);
           List<List<Float>> vectors = new ArrayList<>();
           vectors.add(Arrays.asList(new Float[dimension])); // 替换为实际的 embedding 向量
           vectors.add(Arrays.asList(new Float[dimension])); // 替换为实际的 embedding 向量
    
           // 插入向量
           vectorDatabaseService.insertVectors(collectionName, ids, vectors);
    
           System.out.println("Vectors inserted successfully.");
       }
    }

    这段代码演示了如何连接到 Milvus,创建一个 collection,定义 idembedding 字段,创建 IVF_FLAT 索引,加载 collection 到内存,并将向量数据插入到 Milvus 中。

四、检索优化:提升向量召回精确度

向量召回的精确度直接影响 RAG 系统的效果。 以下是一些常用的优化策略:

  1. 选择合适的 Embedding 模型:

    不同的 Embedding 模型在不同的领域和任务上表现不同。 需要根据具体的应用场景选择合适的模型。 例如,对于文本相似度任务,Sentence Transformers 通常表现良好;对于代码搜索任务,CodeBERT 或 CodeT5 可能更适合。 选择模型时,需要考虑模型的性能、大小、训练数据等因素。

    Embedding 模型 优点 缺点 适用场景
    Sentence Transformers 性能好,通用性强,支持多种语言 对于特定领域可能需要微调 文本相似度,文本检索,问答系统
    OpenAI Embeddings 效果好,易于使用,无需自己部署 成本较高,依赖网络连接 文本相似度,文本检索,问答系统
    CodeBERT/CodeT5 针对代码进行了优化 只适用于代码相关的任务 代码搜索,代码补全,代码生成
    BGE (BAAI General Embedding) 中文支持好,效果较好,有多种尺寸的模型可选 需要一定的部署和维护成本 中文文本相似度,中文文本检索,问答系统
  2. Chunk 大小优化:

    Chunk 的大小会影响检索的精确度和效率。 过大的 chunk 可能包含不相关的信息,降低精确度;过小的 chunk 可能丢失上下文信息。 可以通过实验找到最佳的 chunk 大小。 一种常用的策略是按句子分割,并设置最大 chunk 大小和 overlap 大小。 还可以根据文档的结构和内容动态调整 chunk 大小。

    • 固定大小分块: 简单直接,但可能在语义上不完整。适用于文档结构不明确的情况。
    • 滑动窗口分块: 通过设置窗口大小和步长,生成重叠的 chunk,可以保留上下文信息,但会增加计算量。
    • 语义分块: 利用 NLP 技术(例如句子分割、段落分割)将文档分割成语义完整的块,可以提高检索的精确度。
    • 递归分块: 先将文档分割成较大的块,然后递归地将每个块分割成更小的块,直到满足大小限制。 适用于长文档。
  3. 查询扩展 (Query Expansion):

    查询扩展是指在原始查询的基础上,添加相关的词语或短语,以扩大搜索范围,提高召回率。 常用的查询扩展方法包括:

    • 同义词扩展: 使用同义词词典或词向量模型,找到与查询词语相似的词语,添加到查询中。
    • 相关词扩展: 使用相关词词典或知识图谱,找到与查询词语相关的词语,添加到查询中。
    • 查询重写: 使用语言模型或规则,将查询改写成更清晰、更明确的表达。
    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;
    
    public class QueryExpansion {
    
       private static final Set<String> SYNONYMS_MAP = new HashSet<>(Arrays.asList("big", "large", "huge")); // 示例
    
       public static String expandQuery(String query) {
           String[] words = query.split("\s+");
           StringBuilder expandedQuery = new StringBuilder(query);
    
           for (String word : words) {
               if (SYNONYMS_MAP.contains(word)) {
                   for (String synonym : SYNONYMS_MAP) {
                       if (!synonym.equals(word)) {
                           expandedQuery.append(" ").append(synonym);
                       }
                   }
               }
           }
    
           return expandedQuery.toString();
       }
    
       public static void main(String[] args) {
           String query = "This is a big house";
           String expandedQuery = expandQuery(query);
           System.out.println("Original query: " + query);
           System.out.println("Expanded query: " + expandedQuery);
       }
    }

    这个例子展示了如何使用一个简单的同义词词典进行查询扩展。 实际应用中,可以使用更复杂的同义词词典或词向量模型。

  4. 重排序 (Re-ranking):

    向量召回的结果可能包含一些不相关的信息。 可以使用重排序模型对召回的结果进行排序,将最相关的文档排在前面。 常用的重排序模型包括:

    • BM25: 一种经典的文本检索模型,考虑了词频、文档长度等因素。
    • Cross-Encoder: 一种基于 Transformer 的模型,将查询和文档一起输入到模型中,预测它们的相关性。
    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;
    import java.util.Map;
    import java.util.HashMap;
    
    public class ReRanking {
    
       public static List<Map<String, Object>> reRank(String query, List<Map<String, Object>> results) {
           // 模拟 BM25 得分计算
           results.forEach(result -> {
               String text = (String) result.get("text");
               double score = calculateBM25Score(query, text);
               result.put("score", score);
           });
    
           // 按照得分降序排序
           results.sort(Comparator.comparingDouble(r -> (double) r.get("score")).reversed());
    
           return results;
       }
    
       private static double calculateBM25Score(String query, String text) {
           // 简化版 BM25 算法,仅作演示
           String[] queryTerms = query.split("\s+");
           String[] documentTerms = text.split("\s+");
    
           double score = 0;
           for (String term : queryTerms) {
               long termFrequency = Arrays.stream(documentTerms).filter(term::equals).count();
               score += termFrequency / (termFrequency + 1); // 简化版公式
           }
    
           return score;
       }
    
       public static void main(String[] args) {
           String query = "This is a test query";
           List<Map<String, Object>> results = new ArrayList<>();
    
           Map<String, Object> result1 = new HashMap<>();
           result1.put("text", "This document is about the test query.");
           results.add(result1);
    
           Map<String, Object> result2 = new HashMap<>();
           result2.put("text", "This is a completely unrelated document.");
           results.add(result2);
    
           List<Map<String, Object>> reRankedResults = reRank(query, results);
    
           System.out.println("Re-ranked results:");
           reRankedResults.forEach(System.out::println);
       }
    }

    这个例子演示了如何使用一个简化的 BM25 算法对检索结果进行重排序。 实际应用中,可以使用更复杂的 BM25 算法或 Cross-Encoder 模型。

  5. 元数据过滤 (Metadata Filtering):

    在向量检索的基础上,可以根据文档的元数据进行过滤,例如日期、作者、类别等,以缩小搜索范围,提高精确度。 需要在索引构建时,将元数据与文档块一起存储到向量数据库中。

    import io.milvus.client.MilvusClient;
    import io.milvus.client.MilvusServiceClient;
    import io.milvus.grpc.DataType;
    import io.milvus.param.ConnectParam;
    import io.milvus.param.IndexType;
    import io.milvus.param.MetricType;
    import io.milvus.param.collection.CreateCollectionParam;
    import io.milvus.param.collection.FieldType;
    import io.milvus.param.collection.InsertParam;
    import io.milvus.param.collection.LoadCollectionParam;
    import io.milvus.param.index.CreateIndexParam;
    import io.milvus.param.dml.SearchParam;
    import io.milvus.param.R;
    import io.milvus.response.SearchResults;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import java.util.HashMap;
    import java.util.concurrent.TimeUnit;
    
    public class MetadataFiltering {
    
       private final MilvusClient milvusClient;
    
       public MetadataFiltering(String host, int port) {
           ConnectParam connectParam = new ConnectParam.Builder()
                   .withHost(host)
                   .withPort(port)
                   .build();
           this.milvusClient = new MilvusServiceClient(connectParam);
       }
    
       public void createCollection(String collectionName, int dimension) {
           FieldType idField = FieldType.newBuilder()
                   .withName("id")
                   .withDataType(DataType.INT64)
                   .withPrimaryKey(true)
                   .withAutoID(false)
                   .build();
    
           FieldType vectorField = FieldType.newBuilder()
                   .withName("embedding")
                   .withDataType(DataType.FLOAT_VECTOR)
                   .withDimension(dimension)
                   .build();
    
           FieldType categoryField = FieldType.newBuilder()
                   .withName("category")
                   .withDataType(DataType.VARCHAR)
                   .withMaxLength(256)
                   .build();
    
           CreateCollectionParam createCollectionParam = CreateCollectionParam.newBuilder()
                   .withCollectionName(collectionName)
                   .withFields(Arrays.asList(idField, vectorField, categoryField))
                   .build();
    
           milvusClient.createCollection(createCollectionParam);
       }
    
       public void createIndex(String collectionName) {
           CreateIndexParam createIndexParam = CreateIndexParam.newBuilder()
                   .withCollectionName(collectionName)
                   .withFieldName("embedding")
                   .withIndexType(IndexType.IVF_FLAT)
                   .withMetricType(MetricType.L2)
                   .withParam(new io.milvus.param.IndexParams.Builder().withNlist(1024).build())
                   .withSyncMode(Boolean.FALSE)
                   .build();
    
           milvusClient.createIndex(createIndexParam);
       }
    
       public void loadCollection(String collectionName) {
           LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
                   .withCollectionName(collectionName)
                   .build();
    
           milvusClient.loadCollection(loadCollectionParam);
       }
    
       public void insertVectors(String collectionName, List<Long> ids, List<List<Float>> vectors, List<String> categories) {
           List<String> fieldsName = Arrays.asList("id", "embedding", "category");
           List<List<?>> rows = new ArrayList<>();
           rows.add(ids);
           rows.add(vectors);
           rows.add(categories);
    
           InsertParam insertParam = InsertParam.newBuilder()
                   .withCollectionName(collectionName)
                   .withFieldsName(fieldsName)
                   .withRows(rows)
                   .build();
    
           milvusClient.insert(insertParam);
           milvusClient.flush(collectionName);
       }
    
       public SearchResults search(String collectionName, List<Float> queryVector, String category) {
           List<String> searchOutputFields = Arrays.asList("id", "category");
           String expression = "category == "" + category + """;
    
           SearchParam searchParam = new SearchParam.Builder()
                   .withCollectionName(collectionName)
                   .withVectors(Arrays.asList(queryVector))
                   .withExpr(expression)
                   .withTopK(10)
                   .with आउटपुटFields(searchOutputFields)
                   .build();
    
           R<SearchResults> searchResults = milvusClient.search(searchParam);
    
           return searchResults.getData();
       }
    
       public static void main(String[] args) throws InterruptedException {
           MetadataFiltering metadataFiltering = new MetadataFiltering("localhost", 19530);
           String collectionName = "my_collection";
           int dimension = 384; // 根据 embedding 模型确定维度
    
           // 创建 collection
           metadataFiltering.createCollection(collectionName, dimension);
    
           // 创建 index
           metadataFiltering.createIndex(collectionName);
    
           // 加载 collection
           metadataFiltering.loadCollection(collectionName);
    
           // 示例数据
           List<Long> ids = Arrays.asList(1L, 2L);
           List<List<Float>> vectors = new ArrayList<>();
           vectors.add(Arrays.asList(new Float[dimension])); // 替换为实际的 embedding 向量
           vectors.add(Arrays.asList(new Float[dimension])); // 替换为实际的 embedding 向量
           List<String> categories = Arrays.asList("category1", "category2");
    
           // 插入向量
           metadataFiltering.insertVectors(collectionName, ids, vectors, categories);
    
           TimeUnit.SECONDS.sleep(5);
    
           // 搜索
           List<Float> queryVector = Arrays.asList(new Float[dimension]); // 替换为实际的查询向量
           String category = "category1";
           SearchResults searchResults = metadataFiltering.search(collectionName, queryVector, category);
    
           System.out.println("Search results:");
           searchResults.getFieldData("id").forEach(System.out::println);
           searchResults.getFieldData("category").forEach(System.out::println);
       }
    }

    这个例子演示了如何在 Milvus 中使用元数据过滤。 在创建 collection 时,定义了 category 字段,并在插入数据时,将 category 信息与向量一起存储。 在搜索时,可以使用 expression 参数指定过滤条件。

  6. 评估指标与迭代优化:

    需要使用合适的评估指标来衡量向量召回的精确度,例如 Precision@K、Recall@K、NDCG@K 等。 根据评估结果,不断调整优化策略,例如调整 chunk 大小、选择不同的 Embedding 模型、使用不同的查询扩展方法等。 这是一个迭代的过程,需要不断尝试和改进。

    评估指标 描述 优点 缺点
    Precision@K 在返回的前 K 个结果中,有多少个是相关的。 简单易懂,计算方便 只考虑了前 K 个结果,忽略了排序顺序,对于 K 值敏感
    Recall@K 在所有相关的文档中,有多少个被返回到前 K 个结果中。 衡量了召回率,避免只关注精确度 只考虑了前 K 个结果,忽略了排序顺序,对于 K 值敏感
    NDCG@K 归一化折损累计增益 (Normalized Discounted Cumulative Gain)。 考虑了结果的排序顺序,越相关的文档排在前面,得分越高。 考虑了排序顺序,更全面地衡量了检索效果 计算复杂度较高,需要人工标注相关性等级
    Mean Average Precision (MAP) 对多次查询的 Average Precision (AP) 取平均。 AP 是对单次查询的 Precision 积分,考虑了所有返回结果的精确度和召回率。 综合考虑了 Precision 和 Recall,能够更全面地衡量检索效果 计算复杂度较高,对于不同的查询,AP 值可能差异很大

五、RAG流程整合与服务部署

将上述各个模块整合起来,构建完整的 RAG 流程。

  1. API 封装: 将索引构建、检索、生成等功能封装成 API 接口,方便外部调用。 可以使用 Spring Boot、Micronaut 等框架来构建 API 服务。
  2. 流程编排: 使用流程编排工具 (例如 Apache NiFi、Apache Airflow) 将各个 API 接口串联起来,实现自动化流程。
  3. 服务部署: 将 API 服务和流程编排服务部署到生产环境,例如 Kubernetes、Docker Swarm 等。
  4. 监控与告警: 监控RAG系统的各个环节,包括向量数据库的性能、API服务的响应时间、LLM的生成质量等。设置告警机制,及时发现和解决问题。

六、其他优化方向:

  • Prompt 工程: 设计合适的 Prompt,引导 LLM 生成更准确、更流畅的内容。
  • 知识图谱融合: 将知识图谱融入到 RAG 流程中,提供更丰富的知识来源。
  • 多模态 RAG: 支持图像、音频、视频等多种模态的数据,扩展 RAG 的应用范围。
  • 持续学习: 不断收集用户反馈,优化 RAG 系统的性能。

精准向量召回是基础,构建可维护和可扩展的RAG系统是目标。

七、持续优化是关键

选择合适的 Embedding 模型、调整 Chunk 大小、使用查询扩展、重排序、元数据过滤等策略,并使用合适的评估指标进行迭代优化,最终才能构建一个高效的 RAG 系统。RAG 系统的效果取决于多个因素,需要根据具体的应用场景进行调整和优化。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注