Building Enterprise-Grade RAG Text Vector Retrieval in Java: Elasticsearch + Embedding in Practice

Hello everyone. Today we'll look at how to build an enterprise-grade RAG (Retrieval-Augmented Generation) text vector retrieval system, with a hands-on walkthrough using Elasticsearch and embedding models. RAG is a powerful technique that combines the strengths of retrieval models with those of generative models, so the system can better understand a user's question and generate relevant, accurate answers.

1. RAG Architecture Overview

A RAG architecture typically consists of the following core components:

  • Document Indexing: convert raw text into vector representations and store them in a vector database for fast retrieval.
  • Retrieval: given a user query, retrieve the most relevant document fragments from the vector database.
  • Generation: feed the retrieved fragments as context into a generative model (for example a large language model, LLM) to produce the final answer.
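These three stages can be sketched end to end with a toy in-memory example, where a bag-of-words vector stands in for a real embedding model and the retrieved text itself stands in for a generated answer (ToyRag is purely illustrative):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ToyRag {

    // Stage 1: "index" - embed text with a toy bag-of-words vectorizer.
    static Map<String, Double> embed(String text) {
        Map<String, Double> vec = new HashMap<>();
        for (String token : text.toLowerCase().split("\\W+")) {
            vec.merge(token, 1.0, Double::sum);
        }
        return vec;
    }

    // Cosine similarity between two sparse term-count vectors.
    static double cosine(Map<String, Double> a, Map<String, Double> b) {
        double dot = 0, na = 0, nb = 0;
        for (Map.Entry<String, Double> e : a.entrySet()) {
            dot += e.getValue() * b.getOrDefault(e.getKey(), 0.0);
            na += e.getValue() * e.getValue();
        }
        for (double v : b.values()) {
            nb += v * v;
        }
        return dot == 0 ? 0 : dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // Stage 2: "retrieve" - return the document most similar to the query.
    static String retrieve(List<String> docs, String query) {
        String best = docs.get(0);
        double bestScore = -1;
        for (String doc : docs) {
            double score = cosine(embed(doc), embed(query));
            if (score > bestScore) {
                bestScore = score;
                best = doc;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> docs = List.of(
                "RAG combines retrieval and generation.",
                "Java is a programming language.");
        String context = retrieve(docs, "What does RAG combine?");
        // Stage 3: "generate" - a real system would now prompt an LLM with this context.
        System.out.println("Context: " + context);
    }
}
```

The rest of this article replaces each toy component with a production counterpart: a Sentence Transformers model, Elasticsearch, and an LLM.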

2. Technology Choices

  • Elasticsearch: serves as the vector database, storing and retrieving text vectors. Elasticsearch offers powerful search capabilities, scalability, and a mature ecosystem.
  • Embedding model: converts text into vector representations. Common choices include Sentence Transformers, OpenAI Embeddings, and models from Hugging Face.
  • LLM: generates the final answer. Options include OpenAI's GPT series, Google's Gemini models, or open-source LLMs such as Llama 2.
  • Programming language: Java as the backend language, with its strong enterprise application development ecosystem.

3. Environment Setup

First, we need the following environment:

  • Java: JDK 17 or later (the examples below use Map.of and Stream.toList, which require at least Java 9 and Java 16 respectively).
  • Maven: for managing project dependencies.
  • Elasticsearch: an installed and running Elasticsearch cluster.
  • Elasticsearch Java High Level REST Client: for interacting with the Elasticsearch cluster.
  • Embedding model library: for example the SentenceSimilarity Java library (a simple Word2Vec-based option); later we also use DJL to load Hugging Face models.
  • LLM API client: for example an OpenAI Java SDK.

4. Data Preparation

We need some text data as the RAG system's knowledge base, for example internal company documents, FAQs, or technical blog posts. For the demo, let's create a simple sample dataset:

import java.util.ArrayList;
import java.util.List;

public class SampleData {

    public static List<String> getDocuments() {
        List<String> documents = new ArrayList<>();
        documents.add("Elasticsearch is a distributed, RESTful search and analytics engine.");
        documents.add("RAG combines retrieval and generation to improve answer quality.");
        documents.add("Java is a popular programming language for enterprise applications.");
        documents.add("Sentence Transformers are used for generating sentence embeddings.");
        documents.add("Large language models (LLMs) can generate human-quality text.");
        return documents;
    }

    public static void main(String[] args) {
        List<String> docs = getDocuments();
        for (String doc : docs) {
            System.out.println(doc);
        }
    }
}

5. Embedding Model Integration

As a first, lightweight option we'll use the SentenceSimilarity library (Word2Vec-based) to convert text into vectors. Add the SentenceSimilarity dependency to the Maven project:

<dependency>
    <groupId>com.github.shijiebei</groupId>
    <artifactId>SentenceSimilarity</artifactId>
    <version>1.0.2</version>
</dependency>

Then, write the code to convert text into a vector:

import com.github.shijiebei.similarity.SentenceSimilarity;
import com.github.shijiebei.similarity.WordEmbedding;
import java.io.IOException;

public class EmbeddingService {

    private SentenceSimilarity sentenceSimilarity;

    public EmbeddingService() throws IOException {
        // Initialize SentenceSimilarity with a pre-trained word embedding model
        // Here we're using a simple pre-trained model, but you can use more advanced models
        // such as those from Hugging Face. This is a local model path example.
        WordEmbedding wordEmbedding = new WordEmbedding("/path/to/your/word2vec.bin"); // Replace with your actual path
        sentenceSimilarity = new SentenceSimilarity(wordEmbedding);

        // Optional: You can set similarity metrics here
        // sentenceSimilarity.setSimilarityMetric(SimilarityMetric.COSINE); // Example
    }

    public double[] getEmbedding(String text) {
        try {
            return sentenceSimilarity.sentence2Vec(text);
        } catch (Exception e) {
            System.err.println("Error creating embedding: " + e.getMessage());
            return null;
        }
    }

    public static void main(String[] args) throws IOException {
        EmbeddingService embeddingService = new EmbeddingService();
        String text = "This is a sample sentence.";
        double[] embedding = embeddingService.getEmbedding(text);

        if (embedding != null) {
            System.out.println("Embedding for: " + text);
            for (int i = 0; i < embedding.length; i++) {
                System.out.print(embedding[i] + " ");
            }
            System.out.println();
        }
    }
}

Note: in the code above, replace /path/to/your/word2vec.bin with the actual path to your Word2Vec model file. SentenceSimilarity depends on a pre-trained Word2Vec model; you can download one, for example the Google News vectors (about 1.5 GB). If your task needs higher accuracy, consider Sentence Transformers models such as all-mpnet-base-v2 or all-MiniLM-L6-v2; these are loaded through the Hugging Face ecosystem, and the SentenceSimilarity library itself may need adjustments to work with their output.

Because the SentenceSimilarity library is relatively simple and requires a pre-downloaded Word2Vec model, real enterprise applications are better served by DJL (Deep Java Library) and its Hugging Face support (the ai.djl.huggingface modules). Usage is as follows:

First, add the dependencies:

        <dependency>
            <groupId>ai.djl.huggingface</groupId>
            <artifactId>tokenizers</artifactId>
            <version>0.24.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.24.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.24.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>2.0.9</version>
        </dependency>

Then, write the code to obtain embeddings:

import ai.djl.ModelException;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HuggingFaceEmbeddingService implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(HuggingFaceEmbeddingService.class);
    private final ZooModel<String, float[]> model;
    private final Predictor<String, float[]> predictor;

    public HuggingFaceEmbeddingService(String modelName) throws IOException {
        try {
            // Load the model from the DJL Hugging Face model zoo.
            // TextEmbeddingTranslatorFactory handles tokenization (including
            // special tokens) and mean pooling internally, so the predictor
            // maps a String directly to a sentence embedding.
            Criteria<String, float[]> criteria = Criteria.builder()
                    .setTypes(String.class, float[].class)
                    .optModelUrls("djl://ai.djl.huggingface.pytorch/" + modelName)
                    .optEngine("PyTorch")
                    .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                    .build();
            model = criteria.loadModel();
            predictor = model.newPredictor();
        } catch (ModelException e) {
            logger.error("Failed to initialize Hugging Face model: {}", e.getMessage(), e);
            throw new IOException("Failed to initialize Hugging Face model", e);
        }
    }

    public float[] getEmbedding(String text) {
        try {
            return predictor.predict(text);
        } catch (TranslateException e) {
            logger.error("Failed to generate embedding: {}", e.getMessage(), e);
            return null;
        }
    }

    @Override
    public void close() {
        predictor.close();
        model.close();
    }

    public static void main(String[] args) throws IOException {
        // You can choose different models from the Hugging Face Model Hub.
        // Recommended: "sentence-transformers/all-mpnet-base-v2" or "sentence-transformers/all-MiniLM-L6-v2"
        String modelName = "sentence-transformers/all-MiniLM-L6-v2";
        HuggingFaceEmbeddingService embeddingService = new HuggingFaceEmbeddingService(modelName);
        String text = "This is a sample sentence for embedding generation.";
        float[] embedding = embeddingService.getEmbedding(text);

        if (embedding != null) {
            System.out.println("Embedding for: " + text);
            System.out.println("Embedding Dimension: " + embedding.length);
        }
        embeddingService.close();
    }
}

Key points:

  • Model selection: sentence-transformers/all-mpnet-base-v2 and sentence-transformers/all-MiniLM-L6-v2 are commonly used Sentence Transformers models; the former is more accurate, the latter is faster and smaller (384-dimensional output).
  • Tokenization: the input text is converted into token IDs, with special tokens such as [CLS] and [SEP] added automatically.
  • Mean pooling: the per-token embeddings are averaged (weighted by the attention mask) to produce a single sentence-level embedding.
  • Resource management: DJL models and predictors hold native resources; close them when done (explicitly or with try-with-resources) to avoid memory leaks.

Note: the first time this code runs, DJL downloads the model files, which can take a while. Make sure your machine can reach the Hugging Face model repository.
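The mean-pooling step can be illustrated in isolation with plain arrays (a minimal sketch; tokenEmbeddings and attentionMask are hypothetical stand-ins for real model outputs):

```java
public class MeanPooling {

    /**
     * Average the token embeddings, counting only positions where the
     * attention mask is 1 (i.e., ignoring padding tokens).
     */
    static float[] meanPool(float[][] tokenEmbeddings, long[] attentionMask) {
        int dim = tokenEmbeddings[0].length;
        float[] pooled = new float[dim];
        long count = 0;
        for (int t = 0; t < tokenEmbeddings.length; t++) {
            if (attentionMask[t] == 1) {
                count++;
                for (int d = 0; d < dim; d++) {
                    pooled[d] += tokenEmbeddings[t][d];
                }
            }
        }
        for (int d = 0; d < dim; d++) {
            pooled[d] /= count;
        }
        return pooled;
    }

    public static void main(String[] args) {
        float[][] tokens = { {1f, 2f}, {3f, 4f}, {0f, 0f} };
        long[] mask = {1, 1, 0}; // third position is padding
        float[] pooled = meanPool(tokens, mask);
        System.out.println(pooled[0] + " " + pooled[1]); // 2.0 3.0
    }
}
```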

6. Elasticsearch Integration

We need to store the text vectors in Elasticsearch. First, add the Elasticsearch Java High Level REST Client dependencies to the Maven project:

        <dependency>
            <groupId>org.elasticsearch.client</groupId>
            <artifactId>elasticsearch-rest-high-level-client</artifactId>
            <version>7.17.17</version>
        </dependency>
        <dependency>
            <groupId>org.elasticsearch</groupId>
            <artifactId>elasticsearch</artifactId>
            <version>7.17.17</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>2.13.0</version>
        </dependency>

Then, write the code to create the Elasticsearch index and store text vectors in it:

import org.apache.http.HttpHost;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.CreateIndexResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentType;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ElasticsearchService {

    private RestHighLevelClient client;
    private String indexName;

    public ElasticsearchService(String indexName) {
        this.indexName = indexName;
        client = new RestHighLevelClient(
                RestClient.builder(
                        new HttpHost("localhost", 9200, "http"))); // Replace with your Elasticsearch host and port
    }

    public void createIndex(int embeddingDimension) throws IOException {
        CreateIndexRequest request = new CreateIndexRequest(indexName);
        request.settings(Settings.builder()
                .put("index.number_of_shards", 1)
                .put("index.number_of_replicas", 0));

        // Define the mapping for the vector field.
        // Note: on Elasticsearch 7.x, dense_vector only supports "dims";
        // the "index" and "similarity" options (for approximate kNN) require 8.x.
        Map<String, Object> mapping = new HashMap<>();
        Map<String, Object> properties = new HashMap<>();
        Map<String, Object> vectorField = new HashMap<>();
        vectorField.put("type", "dense_vector");
        vectorField.put("dims", embeddingDimension);

        properties.put("embedding", vectorField);
        properties.put("text", Map.of("type", "text"));  // Also store the original text

        mapping.put("properties", properties);
        request.mapping(mapping);

        CreateIndexResponse createIndexResponse = client.indices().create(request, RequestOptions.DEFAULT);
        System.out.println("Index creation: " + createIndexResponse.isAcknowledged());
    }

    public void indexDocument(String text, float[] embedding) throws IOException {
        Map<String, Object> document = new HashMap<>();
        document.put("text", text);
        document.put("embedding", embedding);

        IndexRequest request = new IndexRequest(indexName)
                .source(document, XContentType.JSON);

        IndexResponse indexResponse = client.index(request, RequestOptions.DEFAULT);
        System.out.println("Document indexed with id: " + indexResponse.getId());
    }

    public void close() throws IOException {
        client.close();
    }

    public static void main(String[] args) throws IOException {
        String indexName = "rag_index";
        int embeddingDimension = 384; // Adjust based on your embedding model
        ElasticsearchService esService = new ElasticsearchService(indexName);

        // Create the index
        esService.createIndex(embeddingDimension);

        // Index some sample documents
        HuggingFaceEmbeddingService embeddingService = new HuggingFaceEmbeddingService("sentence-transformers/all-MiniLM-L6-v2"); // Use HuggingFaceEmbeddingService
        List<String> documents = SampleData.getDocuments(); // Your sample documents
        for (String document : documents) {
            float[] embedding = embeddingService.getEmbedding(document);
            if (embedding != null) {
                esService.indexDocument(document, embedding);
            }
        }

        // Close the client
        esService.close();
    }
}

Key points:

  • dense_vector type: Elasticsearch 7.x introduced the dense_vector type for storing vector data.
  • dims parameter: specifies the vector dimension.
  • index and similarity parameters: on Elasticsearch 8.x, setting index to true and choosing a similarity function (for example cosine or l2_norm) enables approximate kNN search on the field; these mapping options are not supported by the 7.x cluster used here.
  • CreateIndexRequest: creates the Elasticsearch index.
  • IndexRequest: adds a document to the index.
  • Vector normalization: when using cosine similarity, normalizing vectors beforehand is recommended and can improve retrieval quality.
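The normalization suggested above takes only a few lines of plain Java (a sketch; VectorUtil is a hypothetical helper, and the same normalization should be applied to both document and query vectors):

```java
public class VectorUtil {

    /** Scale a vector to unit length; returns it unchanged if its norm is 0. */
    static float[] l2Normalize(float[] vector) {
        double sumOfSquares = 0;
        for (float v : vector) {
            sumOfSquares += (double) v * v;
        }
        double norm = Math.sqrt(sumOfSquares);
        if (norm == 0) {
            return vector;
        }
        float[] normalized = new float[vector.length];
        for (int i = 0; i < vector.length; i++) {
            normalized[i] = (float) (vector[i] / norm);
        }
        return normalized;
    }

    public static void main(String[] args) {
        float[] v = l2Normalize(new float[] {3f, 4f});
        System.out.println(v[0] + " " + v[1]); // 0.6 0.8
    }
}
```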

7. Vector Retrieval

Write code that retrieves the most relevant documents from Elasticsearch for a user query:

import org.apache.http.HttpHost;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class VectorSearchService {

    private RestHighLevelClient client;
    private String indexName;

    public VectorSearchService(String indexName) {
        this.indexName = indexName;
        client = new RestHighLevelClient(
                RestClient.builder(
                        new HttpHost("localhost", 9200, "http"))); // Replace with your Elasticsearch host and port
    }

    public List<SearchResult> search(String query, int topK, float[] queryVector) throws IOException {
        SearchRequest searchRequest = new SearchRequest(indexName);
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();

        // Exact vector search via script_score, which works on Elasticsearch 7.x.
        // cosineSimilarity returns values in [-1, 1]; adding 1.0 keeps scores non-negative.
        List<Float> vector = new ArrayList<>(queryVector.length);
        for (float v : queryVector) {
            vector.add(v);
        }
        Script script = new Script(ScriptType.INLINE, "painless",
                "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                Map.of("query_vector", vector));

        searchSourceBuilder.query(QueryBuilders.scriptScoreQuery(QueryBuilders.matchAllQuery(), script));
        searchSourceBuilder.size(topK);
        searchSourceBuilder.fetchSource(true);  // Retrieve the _source fields

        searchRequest.source(searchSourceBuilder);

        SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);

        List<SearchResult> results = new ArrayList<>();
        for (SearchHit hit : searchResponse.getHits().getHits()) {
            Map<String, Object> source = hit.getSourceAsMap();
            String text = (String) source.get("text");
            float score = hit.getScore();
            results.add(new SearchResult(text, score));
        }

        return results;
    }

    public void close() throws IOException {
        client.close();
    }

    public static void main(String[] args) throws IOException {
        String indexName = "rag_index";
        VectorSearchService searchService = new VectorSearchService(indexName);
        HuggingFaceEmbeddingService embeddingService = new HuggingFaceEmbeddingService("sentence-transformers/all-MiniLM-L6-v2");

        String query = "What is RAG?";
        float[] queryVector = embeddingService.getEmbedding(query);
        int topK = 3;

        List<SearchResult> results = searchService.search(query, topK, queryVector);

        System.out.println("Results for query: " + query);
        for (SearchResult result : results) {
            System.out.println("Text: " + result.getText() + ", Score: " + result.getScore());
        }

        searchService.close();
    }

    static class SearchResult {
        private String text;
        private float score;

        public SearchResult(String text, float score) {
            this.text = text;
            this.score = score;
        }

        public String getText() {
            return text;
        }

        public float getScore() {
            return score;
        }
    }
}

Key points:

  • script_score query: on Elasticsearch 7.x, exact vector search is done with a script_score query that calls the cosineSimilarity Painless function over the dense_vector field.
  • Score shift: cosineSimilarity returns values in [-1, 1]; adding 1.0 keeps scores non-negative, as Elasticsearch requires.
  • Approximate kNN: Elasticsearch 8.x adds native kNN search with k and num_candidates parameters; num_candidates is typically set to 10x the desired top K or higher, trading retrieval time for recall.
  • fetchSource(true): ensures the search results include the documents' original fields.
  • Scaling: exact script_score search scans all matching documents, which is fine for small to medium corpora; for large corpora, prefer the 8.x approximate kNN support.
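To make the scoring concrete, here is cosine similarity in plain Java, equivalent to what Elasticsearch's cosine scoring evaluates per document (a minimal sketch; CosineSimilarity is a hypothetical helper):

```java
public class CosineSimilarity {

    /** Cosine similarity between two equal-length vectors, in [-1, 1]. */
    static double cosine(float[] a, float[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        System.out.println(cosine(new float[] {1f, 2f}, new float[] {1f, 2f})); // 1.0 (same direction)
        System.out.println(cosine(new float[] {1f, 0f}, new float[] {0f, 1f})); // 0.0 (orthogonal)
    }
}
```

Note that for vectors normalized to unit length, the cosine is simply the dot product, which is one reason normalization is recommended before indexing.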

8. LLM Integration

Feed the retrieved document fragments as context into the LLM to generate the final answer. Here we use OpenAI as an example:

import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.service.OpenAiService;
import java.util.List;

public class LLMService {

    private OpenAiService service;

    public LLMService(String apiKey) {
        this.service = new OpenAiService(apiKey);
    }

    public String generateAnswer(String query, List<String> context) {
        String prompt = "Answer the following question based on the context provided:\n" +
                "Question: " + query + "\n" +
                "Context:\n" + String.join("\n", context) + "\n" +
                "Answer:";

        CompletionRequest completionRequest = CompletionRequest.builder()
                .prompt(prompt)
                .model("text-davinci-003") // Or any other suitable model
                .maxTokens(200)
                .temperature(0.7)
                .topP(1.0)
                .frequencyPenalty(0.0)
                .presencePenalty(0.0)
                .build();

        try {
            return service.createCompletion(completionRequest).getChoices().get(0).getText();
        } catch (OpenAiHttpException e) {
            System.err.println("Error calling OpenAI API: " + e.getMessage());
            return "Error generating answer.";
        }
    }

    public static void main(String[] args) {
        String apiKey = System.getenv("OPENAI_API_KEY"); // Read the API key from the environment

        if (apiKey == null || apiKey.isEmpty()) {
            System.err.println("Please set the OPENAI_API_KEY environment variable.");
            return;
        }

        LLMService llmService = new LLMService(apiKey);
        String query = "What is RAG?";

        // Mock context (replace with results from Elasticsearch)
        List<String> context = List.of(
                "RAG combines retrieval and generation to improve answer quality."
        );

        String answer = llmService.generateAnswer(query, context);
        System.out.println("Question: " + query);
        System.out.println("Answer: " + answer);
    }
}

Key points:

  • OpenAiService: handles interaction with the OpenAI API.
  • CompletionRequest: specifies the LLM's input prompt, model, and generation parameters.
  • Prompt engineering: prompt design matters a great deal; tell the LLM clearly what to do and supply the relevant context.
  • Model selection: choose a suitable LLM, for example text-davinci-003 (a legacy completion model) or gpt-3.5-turbo; models differ in capability and price.
  • Generation parameters: tune maxTokens, temperature, topP, and so on to control the quality and diversity of the generated answers.

9. Putting the RAG System Together

Integrate the embedding model, Elasticsearch, and the LLM into a complete RAG system:

import java.io.IOException;
import java.util.List;

public class RAGService {

    private HuggingFaceEmbeddingService embeddingService;
    private VectorSearchService searchService;
    private LLMService llmService;

    public RAGService(String embeddingModelName, String elasticsearchIndexName, String openAIApiKey) throws IOException {
        this.embeddingService = new HuggingFaceEmbeddingService(embeddingModelName);
        this.searchService = new VectorSearchService(elasticsearchIndexName);
        this.llmService = new LLMService(openAIApiKey);
    }

    public String answerQuestion(String query, int topK) throws IOException {
        // 1. Generate embedding for the query
        float[] queryVector = embeddingService.getEmbedding(query);

        if (queryVector == null) {
            return "Could not generate embedding for the query.";
        }

        // 2. Search Elasticsearch for relevant documents
        List<VectorSearchService.SearchResult> searchResults = searchService.search(query, topK, queryVector);

        if (searchResults.isEmpty()) {
            return "No relevant documents found.";
        }

        // 3. Extract context from search results
        List<String> context = searchResults.stream()
                .map(VectorSearchService.SearchResult::getText)
                .toList();

        // 4. Generate answer using LLM
        String answer = llmService.generateAnswer(query, context);
        return answer;
    }

    public void close() throws IOException {
        searchService.close();
    }

    public static void main(String[] args) throws IOException {
        String embeddingModelName = "sentence-transformers/all-MiniLM-L6-v2";
        String elasticsearchIndexName = "rag_index";
        String openAIApiKey = System.getenv("OPENAI_API_KEY");

        if (openAIApiKey == null || openAIApiKey.isEmpty()) {
            System.err.println("Please set the OPENAI_API_KEY environment variable.");
            return;
        }

        RAGService ragService = new RAGService(embeddingModelName, elasticsearchIndexName, openAIApiKey);

        String query = "What are large language models?";
        int topK = 3;

        String answer = ragService.answerQuestion(query, topK);
        System.out.println("Question: " + query);
        System.out.println("Answer: " + answer);

        ragService.close();
    }
}

10. Performance Optimization

  • Vector index optimization: use an approximate nearest-neighbor index where possible; Elasticsearch 8.x implements approximate kNN with HNSW (Hierarchical Navigable Small World), which greatly speeds up retrieval on large corpora.
  • Caching: cache embedding and LLM results to avoid repeated computation.
  • Asynchronous processing: use async processing to improve the system's concurrency.
  • Hardware acceleration: use GPUs to accelerate embedding and LLM computation.
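As a sketch of the caching idea, query embeddings can be memoized with a ConcurrentHashMap (CachingEmbedder is a hypothetical wrapper; a production system would add eviction, for example with a library like Caffeine):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public class CachingEmbedder {

    private final Map<String, float[]> cache = new ConcurrentHashMap<>();
    private final Function<String, float[]> embedder;

    public CachingEmbedder(Function<String, float[]> embedder) {
        this.embedder = embedder;
    }

    /** Compute the embedding only on a cache miss. */
    public float[] getEmbedding(String text) {
        return cache.computeIfAbsent(text, embedder);
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        CachingEmbedder embedder = new CachingEmbedder(text -> {
            calls.incrementAndGet();              // count real model invocations
            return new float[] {text.length()};   // stand-in for a real model
        });

        embedder.getEmbedding("hello");
        embedder.getEmbedding("hello"); // served from cache
        System.out.println("Model calls: " + calls.get()); // 1
    }
}
```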

11. Summary

We built a RAG system based on Elasticsearch and embeddings, and demonstrated text vector retrieval and answer generation. By combining the strengths of retrieval models with those of generative models, such a system can better understand users' questions and generate relevant, accurate answers.

12. Directions for Improving the RAG System

Optimizing a RAG system is an ongoing process. It can be improved along several dimensions:

  • Data cleaning and preprocessing: cleaning and preprocessing the raw text improves embedding quality.
  • Embedding model selection: choose an embedding model better suited to the task, for example a domain-specific model.
  • Prompt engineering: better prompts yield better LLM answers.
  • Hard negative mining: mining hard negatives can improve the retrieval model's accuracy.
  • Continual learning: continual learning lets the RAG system keep adapting to new data and user needs.
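As an illustration of the preprocessing point, a minimal cleaner and chunker might normalize whitespace and split text into fixed-size, overlapping chunks (TextChunker is a hypothetical sketch; production pipelines usually split on sentence or paragraph boundaries instead of raw character counts):

```java
import java.util.ArrayList;
import java.util.List;

public class TextChunker {

    /** Collapse runs of whitespace and trim the text. */
    static String clean(String text) {
        return text.replaceAll("\\s+", " ").trim();
    }

    /** Split cleaned text into chunks of at most chunkSize chars, overlapping by overlap chars. */
    static List<String> chunk(String text, int chunkSize, int overlap) {
        String cleaned = clean(text);
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap;
        for (int start = 0; start < cleaned.length(); start += step) {
            int end = Math.min(start + chunkSize, cleaned.length());
            chunks.add(cleaned.substring(start, end));
            if (end == cleaned.length()) {
                break;
            }
        }
        return chunks;
    }

    public static void main(String[] args) {
        List<String> chunks = chunk("a".repeat(10), 4, 1);
        System.out.println(chunks); // [aaaa, aaaa, aaaa]
    }
}
```

Each chunk would then be embedded and indexed as its own Elasticsearch document, so retrieval returns passages rather than whole files.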
