Building Enterprise-Grade RAG Text Vector Retrieval: Elasticsearch + Embeddings in Practice
Hi everyone. Today we'll look at how to build an enterprise-grade RAG (Retrieval-Augmented Generation) text vector retrieval system, with a hands-on implementation combining Elasticsearch and embedding models. RAG is a powerful technique that pairs the strengths of retrieval models with those of generative models, so the system can better understand a user's question and produce relevant, accurate answers.
1. RAG Architecture Overview
A RAG architecture typically consists of the following core components:
- Document Indexing: convert raw text into vector representations and store them in a vector database for fast retrieval.
- Retrieval: given a user query, fetch the most relevant document fragments from the vector database.
- Generation: feed the retrieved fragments as context into a generative model (such as a large language model, LLM) to produce the final answer.
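In code, the whole flow reduces to retrieve-then-generate. A minimal sketch with hypothetical interfaces (these names are illustrative only; the concrete Elasticsearch and LLM implementations are built in the sections below):

```java
import java.util.List;

// Hypothetical interfaces sketching the RAG flow.
interface Retriever {
    List<String> retrieve(String query, int topK);
}

interface Generator {
    String generate(String query, List<String> context);
}

public class RagFlow {
    public static String answer(Retriever retriever, Generator generator, String query) {
        // 1. Retrieve the most relevant document fragments for the query
        List<String> context = retriever.retrieve(query, 3);
        // 2. Let the generator produce an answer grounded in that context
        return generator.generate(query, context);
    }

    public static void main(String[] args) {
        // Stub implementations just to show the data flow
        Retriever r = (q, k) -> List.of("RAG combines retrieval and generation.");
        Generator g = (q, ctx) -> "Based on: " + String.join(" ", ctx);
        System.out.println(answer(r, g, "What is RAG?"));
    }
}
```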
2. Technology Choices
- Elasticsearch: serves as the vector database, storing and retrieving text vectors. It offers powerful search capabilities, scalability, and a mature ecosystem.
- Embedding model: converts text into vector representations. Common options include Sentence Transformers, OpenAI Embeddings, and Hugging Face Transformers models.
- LLM: generates the final answer. You can use OpenAI's GPT family, Google's Gemini models, or an open-source LLM such as Llama 2.
- Programming language: Java for the backend, given its strength in enterprise application development.
3. Environment Setup
First, we need the following:
- Java: JDK 17 or later (the sample code uses newer language and library features such as Stream.toList()).
- Maven: for managing project dependencies.
- Elasticsearch: an installed, running Elasticsearch cluster.
- Elasticsearch Java High Level REST Client: for interacting with the cluster.
- An embedding model library: a simple community word-embedding library is shown first, followed by DJL's Hugging Face integration.
- An LLM API client: for example a community OpenAI Java client.
4. Data Preparation
We need some text data to serve as the RAG system's knowledge base, for example internal company documents, FAQs, or technical blog posts. For the demo, let's create a simple sample dataset:
import java.util.ArrayList;
import java.util.List;

public class SampleData {

    public static List<String> getDocuments() {
        List<String> documents = new ArrayList<>();
        documents.add("Elasticsearch is a distributed, RESTful search and analytics engine.");
        documents.add("RAG combines retrieval and generation to improve answer quality.");
        documents.add("Java is a popular programming language for enterprise applications.");
        documents.add("Sentence Transformers are used for generating sentence embeddings.");
        documents.add("Large language models (LLMs) can generate human-quality text.");
        return documents;
    }

    public static void main(String[] args) {
        List<String> docs = getDocuments();
        for (String doc : docs) {
            System.out.println(doc);
        }
    }
}
5. Embedding Model Integration
We'll first use a simple word-vector approach to turn text into vectors, via the SentenceSimilarity library. (Note: this is a small community library, not an official Sentence Transformers port; verify its coordinates and API against its own documentation.) Add the dependency to the Maven project:
<dependency>
    <groupId>com.github.shijiebei</groupId>
    <artifactId>SentenceSimilarity</artifactId>
    <version>1.0.2</version>
</dependency>
Then, write code to convert text into vectors:
import com.github.shijiebei.similarity.SentenceSimilarity;
import com.github.shijiebei.similarity.WordEmbedding;
import java.io.IOException;

public class EmbeddingService {

    private final SentenceSimilarity sentenceSimilarity;

    public EmbeddingService() throws IOException {
        // Initialize SentenceSimilarity with a pre-trained word embedding model.
        // A simple pre-trained Word2Vec model is used here; a stronger
        // Hugging Face based alternative follows below.
        WordEmbedding wordEmbedding = new WordEmbedding("/path/to/your/word2vec.bin"); // Replace with your actual path
        sentenceSimilarity = new SentenceSimilarity(wordEmbedding);
    }

    public double[] getEmbedding(String text) {
        try {
            return sentenceSimilarity.sentence2Vec(text);
        } catch (Exception e) {
            System.err.println("Error creating embedding: " + e.getMessage());
            return null;
        }
    }

    public static void main(String[] args) throws IOException {
        EmbeddingService embeddingService = new EmbeddingService();
        String text = "This is a sample sentence.";
        double[] embedding = embeddingService.getEmbedding(text);
        if (embedding != null) {
            System.out.println("Embedding for: " + text);
            for (double v : embedding) {
                System.out.print(v + " ");
            }
            System.out.println();
        }
    }
}
Note: in the code above, replace /path/to/your/word2vec.bin with the path to your actual Word2Vec model file. SentenceSimilarity depends on a pre-trained Word2Vec model, which you can download separately (for example, the Google News vectors, roughly 1.5 GB). If your task needs higher accuracy, consider proper Sentence Transformers models such as all-mpnet-base-v2 or all-MiniLM-L6-v2; these are loaded through Hugging Face tooling, which this simple library does not integrate with.
Because SentenceSimilarity is fairly limited and requires a separately downloaded Word2Vec model, for real enterprise applications it is better to load Hugging Face models in Java through DJL (Deep Java Library) and its Hugging Face model zoo. Here's how:
First, add the dependencies:
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.24.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.huggingface</groupId>
    <artifactId>tokenizers</artifactId>
    <version>0.24.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.24.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-simple</artifactId>
    <version>2.0.9</version>
</dependency>
Then, write the code to generate embeddings:
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class HuggingFaceEmbeddingService implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(HuggingFaceEmbeddingService.class);
    private final ZooModel<String, float[]> model;
    private final Predictor<String, float[]> predictor;

    public HuggingFaceEmbeddingService(String modelName) throws IOException {
        try {
            // The djl:// URL points at DJL's Hugging Face model zoo, which downloads
            // and converts the model on first use. The bundled translator handles
            // tokenization and mean pooling, so the model maps String -> float[].
            Criteria<String, float[]> criteria = Criteria.builder()
                    .setTypes(String.class, float[].class)
                    .optModelUrls("djl://ai.djl.huggingface.pytorch/" + modelName)
                    .optEngine("PyTorch")
                    .build();
            model = criteria.loadModel();
            predictor = model.newPredictor();
        } catch (ModelException e) {
            logger.error("Failed to initialize Hugging Face model: {}", e.getMessage(), e);
            throw new IOException("Failed to initialize Hugging Face model", e);
        }
    }

    public float[] getEmbedding(String text) {
        try {
            return predictor.predict(text);
        } catch (TranslateException e) {
            logger.error("Failed to generate embedding: {}", e.getMessage(), e);
            return null;
        }
    }

    @Override
    public void close() {
        predictor.close();
        model.close();
    }

    public static void main(String[] args) throws IOException {
        // You can choose different models from the Hugging Face Model Hub.
        // Recommended: "sentence-transformers/all-mpnet-base-v2" (higher quality)
        // or "sentence-transformers/all-MiniLM-L6-v2" (smaller and faster)
        String modelName = "sentence-transformers/all-MiniLM-L6-v2";
        HuggingFaceEmbeddingService embeddingService = new HuggingFaceEmbeddingService(modelName);
        String text = "This is a sample sentence for embedding generation.";
        float[] embedding = embeddingService.getEmbedding(text);
        if (embedding != null) {
            System.out.println("Embedding for: " + text);
            System.out.println("Embedding Dimension: " + embedding.length);
        }
        embeddingService.close();
    }
}
Key points:
- Model selection: sentence-transformers/all-mpnet-base-v2 and sentence-transformers/all-MiniLM-L6-v2 are popular Sentence Transformers models; the former is more accurate, the latter is faster and smaller (384-dimensional output).
- Model loading: the djl:// URL tells DJL to fetch the model from its Hugging Face model zoo and convert it automatically.
- Tokenization and pooling: the bundled translator tokenizes the input (adding special tokens such as [CLS] and [SEP]) and mean-pools the per-token embeddings into a single sentence vector, so you don't have to implement this by hand.
- Resource management: ZooModel and Predictor hold native resources; close them when done (or use try-with-resources) to avoid leaks.
Note: the first time you run this code, DJL downloads the model files, which can take a while. Make sure your machine can reach the Hugging Face model repository.
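For reference, the mean pooling that the translator performs internally is simple to state on its own. A library-independent sketch of masked mean pooling over per-token vectors:

```java
public class MeanPooling {
    /**
     * Masked mean pooling: average the per-token embeddings, counting only
     * positions where attentionMask == 1 (real tokens, not padding).
     */
    public static float[] pool(float[][] tokenEmbeddings, int[] attentionMask) {
        int dim = tokenEmbeddings[0].length;
        float[] pooled = new float[dim];
        int count = 0;
        for (int t = 0; t < tokenEmbeddings.length; t++) {
            if (attentionMask[t] == 0) continue; // skip padding tokens
            count++;
            for (int d = 0; d < dim; d++) {
                pooled[d] += tokenEmbeddings[t][d];
            }
        }
        for (int d = 0; d < dim; d++) {
            pooled[d] /= Math.max(count, 1); // guard against an all-padding input
        }
        return pooled;
    }

    public static void main(String[] args) {
        float[][] tokens = { {1f, 2f}, {3f, 4f}, {100f, 100f} };
        int[] mask = {1, 1, 0}; // third token is padding and must be ignored
        float[] pooled = pool(tokens, mask);
        System.out.println(pooled[0] + " " + pooled[1]); // 2.0 3.0
    }
}
```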
6. Elasticsearch Integration
We now need to store the text vectors in Elasticsearch. First, add the Elasticsearch Java High Level REST Client dependencies to the Maven project:
<dependency>
    <groupId>org.elasticsearch.client</groupId>
    <artifactId>elasticsearch-rest-high-level-client</artifactId>
    <version>7.17.17</version>
</dependency>
<dependency>
    <groupId>org.elasticsearch</groupId>
    <artifactId>elasticsearch</artifactId>
    <version>7.17.17</version>
</dependency>
<dependency>
    <groupId>com.fasterxml.jackson.core</groupId>
    <artifactId>jackson-databind</artifactId>
    <version>2.13.0</version>
</dependency>
Then, write code to create the Elasticsearch index and store the text vectors in it:
import org.apache.http.HttpHost;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.CreateIndexResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.XContentType;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ElasticsearchService {

    private final RestHighLevelClient client;
    private final String indexName;

    public ElasticsearchService(String indexName) {
        this.indexName = indexName;
        client = new RestHighLevelClient(
                RestClient.builder(
                        new HttpHost("localhost", 9200, "http"))); // Replace with your Elasticsearch host and port
    }

    public void createIndex(int embeddingDimension) throws IOException {
        CreateIndexRequest request = new CreateIndexRequest(indexName);
        request.settings(Settings.builder()
                .put("index.number_of_shards", 1)
                .put("index.number_of_replicas", 0));
        // Define the mapping for the vector field.
        // Note: on Elasticsearch 7.x, dense_vector only accepts "dims";
        // the "index" and "similarity" options for native kNN require 8.x.
        Map<String, Object> vectorField = new HashMap<>();
        vectorField.put("type", "dense_vector");
        vectorField.put("dims", embeddingDimension);
        Map<String, Object> properties = new HashMap<>();
        properties.put("embedding", vectorField);
        properties.put("text", Map.of("type", "text")); // A text field for the original text
        Map<String, Object> mapping = new HashMap<>();
        mapping.put("properties", properties);
        request.mapping(mapping);
        CreateIndexResponse createIndexResponse = client.indices().create(request, RequestOptions.DEFAULT);
        System.out.println("Index creation: " + createIndexResponse.isAcknowledged());
    }

    public void indexDocument(String text, float[] embedding) throws IOException {
        Map<String, Object> document = new HashMap<>();
        document.put("text", text);
        document.put("embedding", embedding);
        IndexRequest request = new IndexRequest(indexName)
                .source(document, XContentType.JSON);
        IndexResponse indexResponse = client.index(request, RequestOptions.DEFAULT);
        System.out.println("Document indexed with id: " + indexResponse.getId());
    }

    public void close() throws IOException {
        client.close();
    }

    public static void main(String[] args) throws IOException {
        String indexName = "rag_index";
        int embeddingDimension = 384; // Adjust to match your embedding model
        ElasticsearchService esService = new ElasticsearchService(indexName);
        // Create the index
        esService.createIndex(embeddingDimension);
        // Index some sample documents
        HuggingFaceEmbeddingService embeddingService =
                new HuggingFaceEmbeddingService("sentence-transformers/all-MiniLM-L6-v2");
        List<String> documents = SampleData.getDocuments();
        for (String document : documents) {
            float[] embedding = embeddingService.getEmbedding(document);
            if (embedding != null) {
                esService.indexDocument(document, embedding);
            }
        }
        // Close the client
        esService.close();
    }
}
Key points:
- dense_vector type: introduced in Elasticsearch 7.x for storing vector data; the dims parameter specifies the vector dimension.
- index and similarity options: only available from Elasticsearch 8.x, where dense_vector fields can be indexed for native kNN search with a chosen similarity function (cosine, l2_norm, etc.). On 7.x, similarity is computed at query time instead.
- CreateIndexRequest creates the index; IndexRequest adds documents to it.
- Vector normalization: when using cosine similarity, it is advisable to L2-normalize vectors before indexing, which keeps scores consistent and comparable.
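L2 normalization is a one-liner worth having as a utility. A minimal sketch (the class name is our own):

```java
public class VectorUtils {
    /** L2-normalize a vector so that cosine similarity behaves consistently. */
    public static float[] l2Normalize(float[] v) {
        double norm = 0.0;
        for (float x : v) {
            norm += x * x;
        }
        norm = Math.sqrt(norm);
        if (norm == 0.0) {
            return v; // avoid division by zero for the zero vector
        }
        float[] out = new float[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = (float) (v[i] / norm);
        }
        return out;
    }

    public static void main(String[] args) {
        float[] v = l2Normalize(new float[] {3f, 4f});
        System.out.println(v[0] + " " + v[1]); // 0.6 0.8
    }
}
```

Call this on each embedding before passing it to indexDocument, and on the query vector before searching.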
7. Vector Retrieval
Now write the code that, given a user query, retrieves the most relevant documents from Elasticsearch:
import org.apache.http.HttpHost;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class VectorSearchService {

    private final RestHighLevelClient client;
    private final String indexName;

    public VectorSearchService(String indexName) {
        this.indexName = indexName;
        client = new RestHighLevelClient(
                RestClient.builder(
                        new HttpHost("localhost", 9200, "http"))); // Replace with your Elasticsearch host and port
    }

    public List<SearchResult> search(String query, int topK, float[] queryVector) throws IOException {
        // The 7.x High Level REST Client has no kNN DSL, so we score documents with a
        // script_score query using the painless cosineSimilarity function.
        // (The query string stays in the signature for a future hybrid keyword search.)
        // Painless script parameters must be JSON-serializable, so convert float[] to List<Double>.
        List<Double> vector = new ArrayList<>(queryVector.length);
        for (float v : queryVector) {
            vector.add((double) v);
        }
        Map<String, Object> params = new HashMap<>();
        params.put("query_vector", vector);
        // cosineSimilarity returns values in [-1, 1]; add 1.0 because scores must be non-negative
        Script script = new Script(ScriptType.INLINE, "painless",
                "cosineSimilarity(params.query_vector, 'embedding') + 1.0", params);
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
                .query(QueryBuilders.scriptScoreQuery(QueryBuilders.matchAllQuery(), script))
                .size(topK)
                .fetchSource(true); // Retrieve the _source fields
        SearchRequest searchRequest = new SearchRequest(indexName).source(searchSourceBuilder);
        SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
        List<SearchResult> results = new ArrayList<>();
        for (SearchHit hit : searchResponse.getHits().getHits()) {
            Map<String, Object> source = hit.getSourceAsMap();
            String text = (String) source.get("text");
            float score = hit.getScore();
            results.add(new SearchResult(text, score));
        }
        return results;
    }

    public void close() throws IOException {
        client.close();
    }

    public static void main(String[] args) throws IOException {
        String indexName = "rag_index";
        VectorSearchService searchService = new VectorSearchService(indexName);
        HuggingFaceEmbeddingService embeddingService =
                new HuggingFaceEmbeddingService("sentence-transformers/all-MiniLM-L6-v2");
        String query = "What is RAG?";
        float[] queryVector = embeddingService.getEmbedding(query);
        int topK = 3;
        List<SearchResult> results = searchService.search(query, topK, queryVector);
        System.out.println("Results for query: " + query);
        for (SearchResult result : results) {
            System.out.println("Text: " + result.getText() + ", Score: " + result.getScore());
        }
        searchService.close();
    }

    static class SearchResult {
        private final String text;
        private final float score;

        public SearchResult(String text, float score) {
            this.text = text;
            this.score = score;
        }

        public String getText() {
            return text;
        }

        public float getScore() {
            return score;
        }
    }
}
Key points:
- script_score query: on Elasticsearch 7.x, vector similarity search is expressed as a script_score query over a base query (here match_all), with the painless cosineSimilarity function comparing the query vector against the stored embedding field.
- Non-negative scores: cosineSimilarity returns values in [-1, 1] and Elasticsearch requires non-negative scores, hence the + 1.0.
- Performance: script_score computes similarity for every matching document, which is exact but does not scale to very large indexes. Elasticsearch 8.x adds native approximate kNN search (a knn section with k and num_candidates; a larger num_candidates, typically 10x k or more, raises recall at the cost of latency), accessed through the newer Java API Client.
- fetchSource(true): ensures the hits include the original document fields.
8. LLM Integration
Feed the retrieved document fragments as context into an LLM to generate the final answer. Here is an example using OpenAI:
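The example below assumes the community openai-gpt3-java client; the coordinates and version are an assumption, so check Maven Central for the current release:

```xml
<dependency>
    <groupId>com.theokanning.openai-gpt3-java</groupId>
    <artifactId>service</artifactId>
    <version>0.18.2</version>
</dependency>
```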
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;
import java.util.List;

public class LLMService {

    private final OpenAiService service;

    public LLMService(String apiKey) {
        this.service = new OpenAiService(apiKey);
    }

    public String generateAnswer(String query, List<String> context) {
        String prompt = "Answer the following question based on the context provided:\n"
                + "Question: " + query + "\n"
                + "Context:\n" + String.join("\n", context) + "\n"
                + "Answer:";
        ChatCompletionRequest completionRequest = ChatCompletionRequest.builder()
                .model("gpt-3.5-turbo") // Or any other suitable chat model
                .messages(List.of(new ChatMessage(ChatMessageRole.USER.value(), prompt)))
                .maxTokens(200)
                .temperature(0.7)
                .build();
        try {
            return service.createChatCompletion(completionRequest)
                    .getChoices().get(0).getMessage().getContent();
        } catch (OpenAiHttpException e) {
            System.err.println("Error calling OpenAI API: " + e.getMessage());
            return "Error generating answer.";
        }
    }

    public static void main(String[] args) {
        String apiKey = System.getenv("OPENAI_API_KEY"); // Set this environment variable beforehand
        if (apiKey == null || apiKey.isEmpty()) {
            System.err.println("Please set the OPENAI_API_KEY environment variable.");
            return;
        }
        LLMService llmService = new LLMService(apiKey);
        String query = "What is RAG?";
        // Mock context (replace with results from Elasticsearch)
        List<String> context = List.of(
                "RAG combines retrieval and generation to improve answer quality.");
        String answer = llmService.generateAnswer(query, context);
        System.out.println("Question: " + query);
        System.out.println("Answer: " + answer);
    }
}
Key points:
- OpenAiService wraps the OpenAI API; ChatCompletionRequest carries the prompt, the model name, and the generation parameters.
- Prompt engineering: prompt design matters; tell the LLM clearly what to do and supply the relevant context.
- Model selection: choose a suitable chat model, such as gpt-3.5-turbo or gpt-4; models differ in capability and price. (The older text-davinci-003 completion model has been deprecated by OpenAI.)
- Generation parameters: tune maxTokens, temperature, topP, and so on to control the quality and diversity of the generated answer.
9. Putting the RAG System Together
Combine the embedding model, Elasticsearch, and the LLM into a complete RAG system:
import java.io.IOException;
import java.util.List;

public class RAGService {

    private final HuggingFaceEmbeddingService embeddingService;
    private final VectorSearchService searchService;
    private final LLMService llmService;

    public RAGService(String embeddingModelName, String elasticsearchIndexName, String openAIApiKey) throws IOException {
        this.embeddingService = new HuggingFaceEmbeddingService(embeddingModelName);
        this.searchService = new VectorSearchService(elasticsearchIndexName);
        this.llmService = new LLMService(openAIApiKey);
    }

    public String answerQuestion(String query, int topK) throws IOException {
        // 1. Generate embedding for the query
        float[] queryVector = embeddingService.getEmbedding(query);
        if (queryVector == null) {
            return "Could not generate embedding for the query.";
        }
        // 2. Search Elasticsearch for relevant documents
        List<VectorSearchService.SearchResult> searchResults = searchService.search(query, topK, queryVector);
        if (searchResults.isEmpty()) {
            return "No relevant documents found.";
        }
        // 3. Extract context from search results
        List<String> context = searchResults.stream()
                .map(VectorSearchService.SearchResult::getText)
                .toList();
        // 4. Generate answer using LLM
        return llmService.generateAnswer(query, context);
    }

    public void close() throws IOException {
        searchService.close();
    }

    public static void main(String[] args) throws IOException {
        String embeddingModelName = "sentence-transformers/all-MiniLM-L6-v2";
        String elasticsearchIndexName = "rag_index";
        String openAIApiKey = System.getenv("OPENAI_API_KEY");
        if (openAIApiKey == null || openAIApiKey.isEmpty()) {
            System.err.println("Please set the OPENAI_API_KEY environment variable.");
            return;
        }
        RAGService ragService = new RAGService(embeddingModelName, elasticsearchIndexName, openAIApiKey);
        String query = "What are large language models?";
        int topK = 3;
        String answer = ragService.answerQuestion(query, topK);
        System.out.println("Question: " + query);
        System.out.println("Answer: " + answer);
        ragService.close();
    }
}
10. Performance Optimization
- Vector index tuning: choose a suitable approximate nearest neighbor algorithm such as HNSW (Hierarchical Navigable Small World) to speed up retrieval; Elasticsearch 8.x supports HNSW-backed kNN search.
- Caching: cache embedding and LLM results to avoid recomputation.
- Asynchronous processing: use asynchronous processing to improve concurrency.
- Hardware acceleration: use GPUs to accelerate embedding and LLM computation.
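For the caching point, an in-memory map keyed by the input text already goes a long way, since the same documents and queries recur. A minimal sketch (class name and design are our own; a production system would add eviction, e.g. via Caffeine):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public class EmbeddingCache {
    private final Map<String, float[]> cache = new ConcurrentHashMap<>();
    private final Function<String, float[]> embedder;

    public EmbeddingCache(Function<String, float[]> embedder) {
        this.embedder = embedder;
    }

    /** Compute the embedding only on a cache miss; repeated texts are served from memory. */
    public float[] get(String text) {
        return cache.computeIfAbsent(text, embedder);
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        // Stub embedder standing in for a real model call
        EmbeddingCache cache = new EmbeddingCache(t -> {
            calls.incrementAndGet();
            return new float[] { t.length() };
        });
        cache.get("hello");
        cache.get("hello"); // served from cache, no second model call
        System.out.println("Model calls: " + calls.get()); // prints: Model calls: 1
    }
}
```

Wrapping the real service is then just `new EmbeddingCache(embeddingService::getEmbedding)`.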
11. Summary
We built a RAG system on top of Elasticsearch and embedding models, and walked through text vector retrieval and answer generation. By combining the strengths of retrieval models with those of generative models, the system can better understand users' questions and produce relevant, accurate answers.
12. Directions for Improving a RAG System
Optimizing a RAG system is an ongoing process. Some directions for improvement:
- Data cleaning and preprocessing: cleaning and preprocessing the raw text (including sensible chunking) improves embedding quality.
- Embedding model choice: pick a model better suited to the task, for example a domain-specific embedding model.
- Prompt engineering: better-designed prompts yield higher-quality LLM answers.
- Hard negative mining: mining negative samples improves the retrieval model's precision.
- Continual learning: continual learning keeps the RAG system adapted to new data and user needs.
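On the preprocessing point, long documents should be split into overlapping chunks before embedding, so no single vector has to represent too much text and facts near boundaries survive. A minimal character-based sketch (real systems usually split on sentences or tokens instead):

```java
import java.util.ArrayList;
import java.util.List;

public class TextChunker {
    /**
     * Split text into fixed-size character chunks with overlap, so that content
     * falling on a chunk boundary still appears intact in at least one chunk.
     */
    public static List<String> chunk(String text, int chunkSize, int overlap) {
        if (chunkSize <= overlap) {
            throw new IllegalArgumentException("chunkSize must be greater than overlap");
        }
        List<String> chunks = new ArrayList<>();
        int step = chunkSize - overlap;
        for (int start = 0; start < text.length(); start += step) {
            chunks.add(text.substring(start, Math.min(start + chunkSize, text.length())));
            if (start + chunkSize >= text.length()) {
                break; // last chunk reached the end of the text
            }
        }
        return chunks;
    }

    public static void main(String[] args) {
        List<String> chunks = chunk("abcdefghij", 4, 1);
        System.out.println(chunks); // [abcd, defg, ghij]
    }
}
```

Each chunk (rather than the whole document) is then embedded and indexed as its own Elasticsearch document.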