Java RAG 召回结果不相关?Embedding 模型选择对比与优化
大家好,今天我们来深入探讨一个在Java RAG(Retrieval-Augmented Generation,检索增强生成)应用中经常遇到的问题:召回结果不相关。RAG的核心在于从知识库中检索相关文档,并将其作为上下文提供给生成模型,以提高生成结果的准确性和相关性。如果召回阶段出了问题,后续的生成质量自然会受到影响。
本次讲座将围绕以下几个方面展开:
- RAG 流程回顾与问题诊断: 简要回顾RAG流程,并详细分析召回结果不相关的常见原因。
- Embedding 模型选择: 对比几种常用的Embedding模型,包括其原理、优缺点以及适用场景,并通过代码示例展示如何在Java RAG应用中使用它们。
- Embedding 模型优化: 探讨优化Embedding模型效果的各种策略,包括数据预处理、微调技术以及向量索引的选择。
- 代码实战:Java RAG 示例: 提供一个基于Java的RAG示例,并演示如何通过调整Embedding模型来改善召回结果。
- 评估指标与监控: 介绍评估召回效果的常用指标,并讨论如何在生产环境中监控RAG系统的性能。
1. RAG 流程回顾与问题诊断
RAG流程通常包含以下几个步骤:
- 数据准备: 将知识库中的文档分割成块(chunks)。Chunk的大小会影响召回效果,过小可能缺乏上下文信息,过大则可能包含无关信息。
- Embedding 生成: 使用Embedding模型将每个Chunk转换为向量表示。这些向量将被存储在向量数据库中。
- 检索: 接收用户查询,将其转换为向量,并在向量数据库中搜索与查询向量最相似的Chunk向量。
- 生成: 将检索到的Chunk作为上下文提供给生成模型(例如,LLM),生成最终的答案。
召回结果不相关的原因有很多,常见的包括:
- Chunk质量差: Chunk分割方式不合理,导致Chunk包含过多噪声或缺乏关键信息。
- Embedding模型不适合: 选择的Embedding模型无法准确捕捉文本的语义信息,导致相似的文本向量距离较远,不相似的文本向量距离较近。
- 向量索引不佳: 向量索引的选择会影响检索效率和准确性。不合适的索引可能导致无法召回最相关的Chunk。
- 查询理解偏差: 查询本身存在歧义或表达不清,导致Embedding模型无法准确理解查询意图。
- 数据偏差: 知识库中存在偏差,导致模型学习到的向量表示无法泛化到新的查询。
2. Embedding 模型选择
选择合适的Embedding模型是提高RAG召回效果的关键。以下是一些常用的Embedding模型:
| 模型名称 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| Word2Vec/GloVe | 训练速度快,计算资源消耗低,易于部署。 | 无法捕捉上下文信息,对一词多义处理效果差,无法处理OOV(Out-of-Vocabulary)问题。 | 适用于对性能要求较高,数据量较小,对语义理解要求不高的场景。例如,简单的文本分类任务。 |
| FastText | 速度快,能够处理OOV问题,对 morphologically rich 的语言有较好的效果。 | 无法捕捉上下文信息,对一词多义处理效果差。 | 适用于需要处理OOV问题,并且对速度有较高要求的场景。例如,文本分类,词性标注。 |
| Sentence-BERT | 能够生成句子的embedding,考虑了上下文信息,在语义相似度任务上表现优秀。 | 训练和推理速度相对较慢,资源消耗较高。 | 适用于对语义理解要求较高,需要生成句子级别embedding的场景。例如,语义搜索,文本聚类,问答系统。 |
| OpenAI Embeddings (e.g., text-embedding-ada-002) | 易于使用,性能良好,支持多种语言,可以处理长文本。 | 需要通过API调用,存在一定的成本,可能存在隐私问题。 | 适用于对性能要求较高,需要处理长文本,并且能够接受API调用的场景。例如,知识库问答,语义搜索。 |
| BGE (BAAI General Embedding) | 在多个benchmark上取得了领先的性能,支持对比学习,对中文文本有较好的效果。 | 模型较大,训练和推理速度相对较慢。 | 适用于对性能要求较高,需要处理中文文本,并且希望获得最佳性能的场景。例如,知识库问答,语义搜索。 |
2.1 Word2Vec 示例 (使用 Deeplearning4j)
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.io.IOException;
public class Word2VecExample {
public static void main(String[] args) throws Exception {
// 1. 加载文本数据
String filePath = "your_text_data.txt"; // 替换成你的文本文件路径
SentenceIterator iter = new BasicLineIterator(filePath);
// 2. 配置 Tokenizer
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
// 3. 构建 Word2Vec 模型
Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(5) // 设置最小词频
.iterations(1) // 设置迭代次数
.layerSize(100) // 设置词向量维度
.seed(42) // 设置随机种子
.windowSize(5) // 设置窗口大小
.iterate(iter) // 设置 SentenceIterator
.tokenizerFactory(t) // 设置 TokenizerFactory
.build();
// 4. 训练模型
System.out.println("Building model....");
vec.fit();
// 5. 保存模型
WordVectorSerializer.writeWord2VecModel(vec, "word2vec.model");
// 6. 使用模型
INDArray wordVector = vec.getWordVectorMatrix("example"); // 获取 "example" 的词向量
System.out.println("Word vector for 'example': " + wordVector);
double similarity = vec.similarity("example", "sample"); // 计算 "example" 和 "sample" 的相似度
System.out.println("Similarity between 'example' and 'sample': " + similarity);
// 7. 计算句子向量(简单平均)
String sentence = "This is an example sentence.";
INDArray sentenceVector = getSentenceVector(sentence, vec, t);
System.out.println("Sentence vector: " + sentenceVector);
}
// 计算句子向量(简单平均)
public static INDArray getSentenceVector(String sentence, Word2Vec vec, TokenizerFactory t) {
t.create(sentence).getTokens().stream().filter(vec::hasWord).toList();
List<String> tokens = t.create(sentence).getTokens().stream().filter(vec::hasWord).toList();
if (tokens.isEmpty()) {
return Nd4j.zeros(vec.getLayerSize());
}
INDArray sentenceVector = Nd4j.zeros(vec.getLayerSize());
for (String token : tokens) {
sentenceVector.addi(vec.getWordVectorMatrix(token));
}
return sentenceVector.divi(tokens.size());
}
}
代码解释:
- 数据加载: 使用
BasicLineIterator从文本文件中逐行读取句子。 - Tokenizer配置: 使用
DefaultTokenizerFactory进行分词,并使用CommonPreprocessor进行预处理(例如,转换为小写,去除标点符号)。 - Word2Vec模型构建: 使用
Word2Vec.Builder配置 Word2Vec 模型,包括最小词频、迭代次数、词向量维度、随机种子和窗口大小。 - 模型训练: 使用
vec.fit()训练 Word2Vec 模型。 - 模型保存: 使用
WordVectorSerializer.writeWord2VecModel()将训练好的模型保存到文件中。 - 模型使用:
vec.getWordVectorMatrix("example"): 获取单词 "example" 的词向量。vec.similarity("example", "sample"): 计算单词 "example" 和 "sample" 之间的相似度。getSentenceVector(): 计算句子的向量表示(简单平均所有单词的向量)。
2.2 Sentence-BERT 示例 (使用 Hugging Face Transformers Java API)
import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
public class SentenceBertExample {
public static void main(String[] args) throws IOException, ModelException, TranslateException {
String modelName = "sentence-transformers/all-mpnet-base-v2"; // 选择 Sentence-BERT 模型
// 1. 构建 Criteria
Criteria<String, float[]> criteria = Criteria.builder()
.setTypes(String.class, float[].class)
.optModelPath(Paths.get("models")) // 可选:指定模型下载路径
.optModelName(modelName)
.optTranslator(new SentenceTranslator(modelName))
.optEngine("PyTorch") // 确保安装了 PyTorch 引擎
.build();
// 2. 加载模型
try (ZooModel<String, float[]> model = criteria.loadModel()) {
// 3. 创建 Predictor
try (Predictor<String, float[]> predictor = model.newPredictor()) {
// 4. 生成句子 Embedding
String sentence1 = "This is a sentence.";
String sentence2 = "This is another sentence.";
float[] embedding1 = predictor.predict(sentence1);
float[] embedding2 = predictor.predict(sentence2);
System.out.println("Embedding for sentence 1: " + Arrays.toString(embedding1));
System.out.println("Embedding for sentence 2: " + Arrays.toString(embedding2));
// 5. 计算相似度 (Cosine Similarity)
double similarity = cosineSimilarity(embedding1, embedding2);
System.out.println("Similarity between sentence 1 and sentence 2: " + similarity);
}
}
}
// Cosine Similarity 计算
private static double cosineSimilarity(float[] vectorA, float[] vectorB) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += Math.pow(vectorA[i], 2);
normB += Math.pow(vectorB[i], 2);
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
// 自定义 Translator
private static class SentenceTranslator implements Translator<String, float[]> {
private final String modelName;
private HuggingFaceTokenizer tokenizer;
public SentenceTranslator(String modelName) {
this.modelName = modelName;
}
@Override
public void prepare(TranslatorContext ctx) throws IOException {
tokenizer = HuggingFaceTokenizer.newInstance(modelName);
}
@Override
public NDArray processInput(TranslatorContext ctx, String input) {
NDManager manager = ctx.getNDManager();
Encoding encoding = tokenizer.encode(input);
long[] indices = encoding.getIds();
NDArray inputIds = manager.create(indices);
inputIds.setName("input_ids");
return inputIds.expandDims(0);
}
@Override
public float[] processOutput(TranslatorContext ctx, NDArray output) {
// Pooling layer (Mean Pooling)
NDArray inputMask = ctx.getNDManager().ones(new long[]{1, output.getShape()[1]});
NDArray expandedMask = inputMask.expandDims(-1);
NDArray maskedOutput = output.mul(expandedMask);
NDArray sumVectors = maskedOutput.sum(new int[]{1});
NDArray sumMask = inputMask.sum(new int[]{1}).expandDims(1);
NDArray pooled = sumVectors.div(sumMask);
return pooled.toFloatArray();
}
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
}
}
代码解释:
- 依赖: 确保项目中包含
ai.djl相关的依赖,包括ai.djl-api,ai.djl-model-zoo,ai.djl-huggingface,ai.djl-pytorch(或其他你选择的引擎). - Criteria构建: 使用
Criteria.builder()定义模型的加载标准,包括模型类型、模型路径、模型名称和自定义的Translator。 - 模型加载: 使用
criteria.loadModel()加载 Sentence-BERT 模型。 - Predictor创建: 使用
model.newPredictor()创建Predictor对象,用于执行推理。 - 句子 Embedding 生成: 使用
predictor.predict(sentence)生成句子的 Embedding 向量。 - 相似度计算: 使用
cosineSimilarity()函数计算两个 Embedding 向量之间的 Cosine 相似度。 - 自定义 Translator:
SentenceTranslator类实现了Translator接口,用于处理输入和输出。prepare()方法初始化HuggingFaceTokenizer。processInput()方法将输入句子转换为模型所需的NDArray格式(token IDs)。processOutput()方法将模型的输出转换为 Embedding 向量(使用 Mean Pooling)。getBatchifier()方法指定 Batchifier 类型。
注意事项:
- Hugging Face 模型: 代码中使用的是
sentence-transformers/all-mpnet-base-v2模型,你也可以选择其他 Sentence-BERT 模型。 - 引擎选择:
optEngine("PyTorch")指定使用 PyTorch 引擎。你需要确保已经安装了相应的引擎。 - 依赖管理: 确保你的项目中包含了所有必要的
ai.djl依赖。 - 模型下载: 模型会自动下载到
optModelPath指定的路径下。 首次运行可能需要一些时间下载模型。 - Pooling策略:
processOutput方法中使用了 Mean Pooling。 其他的Pooling策略(例如,CLS token embedding)也可以尝试。
2.3 OpenAI Embeddings 示例
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.List;
public class OpenAIEmbeddingsExample {
private static final String API_KEY = "YOUR_OPENAI_API_KEY"; // 替换成你的 OpenAI API Key
private static final String MODEL_NAME = "text-embedding-ada-002"; // 选择 Embedding 模型
public static void main(String[] args) throws IOException, InterruptedException {
String text1 = "This is the first text.";
String text2 = "This is the second text.";
List<Double> embedding1 = getEmbedding(text1);
List<Double> embedding2 = getEmbedding(text2);
System.out.println("Embedding for text 1: " + embedding1);
System.out.println("Embedding for text 2: " + embedding2);
// 计算相似度 (Cosine Similarity)
double similarity = cosineSimilarity(embedding1, embedding2);
System.out.println("Similarity between text 1 and text 2: " + similarity);
}
// 获取 OpenAI Embedding
public static List<Double> getEmbedding(String text) throws IOException, InterruptedException {
String apiUrl = "https://api.openai.com/v1/embeddings";
// 构建请求体
String requestBody = String.format("{"input": "%s", "model": "%s"}", text, MODEL_NAME);
// 创建 HttpClient
HttpClient client = HttpClient.newHttpClient();
// 创建 HttpRequest
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(apiUrl))
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + API_KEY)
.POST(HttpRequest.BodyPublishers.ofString(requestBody))
.build();
// 发送请求并获取响应
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
// 解析 JSON 响应
ObjectMapper mapper = new ObjectMapper();
JsonNode root = mapper.readTree(response.body());
// 提取 Embedding 向量
List<Double> embedding = new ArrayList<>();
JsonNode data = root.get("data").get(0);
JsonNode embeddingNode = data.get("embedding");
if (embeddingNode.isArray()) {
for (JsonNode element : embeddingNode) {
embedding.add(element.asDouble());
}
}
return embedding;
}
// Cosine Similarity 计算
private static double cosineSimilarity(List<Double> vectorA, List<Double> vectorB) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < vectorA.size(); i++) {
dotProduct += vectorA.get(i) * vectorB.get(i);
normA += Math.pow(vectorA.get(i), 2);
normB += Math.pow(vectorB.get(i), 2);
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
}
代码解释:
- API Key: 替换
YOUR_OPENAI_API_KEY为你的 OpenAI API Key。 - 依赖: 确保项目中包含
com.fasterxml.jackson.databind依赖 (Jackson JSON 库)。 - 构建请求: 使用
HttpClient构建 HTTP POST 请求,包含 OpenAI API 的 URL、请求头 (Content-Type, Authorization) 和请求体 (JSON 格式的输入文本和模型名称)。 - 发送请求: 使用
client.send()发送请求并获取响应。 - 解析响应: 使用
ObjectMapper解析 JSON 响应,提取 Embedding 向量。 - Cosine 相似度计算: 使用
cosineSimilarity()函数计算两个 Embedding 向量之间的 Cosine 相似度。 - 错误处理: 实际应用中,需要添加更完善的错误处理机制,例如处理 API 调用失败、JSON 解析错误等。
注意事项:
- OpenAI API Key: 你需要拥有一个有效的 OpenAI API Key 才能使用 OpenAI Embeddings API。
- 模型选择:
MODEL_NAME变量指定了使用的 OpenAI Embedding 模型。text-embedding-ada-002是一个常用的模型,你也可以选择其他模型。 - API 调用限制: OpenAI API 存在调用频率限制,需要注意控制调用频率,避免超出限制。
- 费用: 使用 OpenAI API 会产生费用,需要了解 OpenAI 的定价策略。
- 依赖管理: 确保你的项目中包含了 Jackson JSON 库的依赖。
3. Embedding 模型优化
仅仅选择一个合适的Embedding模型是不够的,还需要对其进行优化,以进一步提高召回效果。
- 数据预处理: 对原始文本进行清洗和预处理,例如去除HTML标签、特殊字符、停用词等。这可以减少噪声,提高Embedding模型的准确性。
- Chunk优化: 调整Chunk的大小和分割方式。可以尝试固定大小的Chunk,也可以使用基于语义的Chunk分割方法,例如将句子或段落作为Chunk。
- 微调 (Fine-tuning): 如果知识库的领域比较特殊,可以考虑使用知识库中的数据对Embedding模型进行微调。这可以使模型更好地适应特定领域的语义信息。
- 对比学习 (Contrastive Learning): 一种常用的微调方法,通过构造正负样本对,训练模型区分相似和不相似的文本。
- 向量索引选择: 选择合适的向量索引可以提高检索效率和准确性。常用的向量索引包括:
- Annoy: 适用于高维向量的近似最近邻搜索。
- HNSW: 一种基于图的索引,能够提供较高的检索精度和效率。
- Faiss: Facebook AI Similarity Search,提供了多种向量索引算法,适用于大规模向量搜索。
3.1 使用 BGE 模型进行对比学习微调示例 (伪代码)
由于篇幅限制,这里只提供伪代码,展示对比学习微调的思路。实际实现需要使用深度学习框架(例如,PyTorch)。
# 伪代码 (Python)
import torch
from transformers import AutoModel, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
# 1. 定义数据集
class ContrastiveDataset(Dataset):
def __init__(self, data):
self.data = data # data: list of tuples (anchor, positive, negative)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 2. 定义模型
class BGEModel(torch.nn.Module):
def __init__(self, model_name):
super().__init__()
self.model = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def forward(self, text):
encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
output = self.model(**encoded_input).pooler_output # 或者使用 mean pooling
return output
# 3. 定义损失函数 (InfoNCE Loss)
def info_nce_loss(anchor, positive, negative, temperature=0.05):
# 计算相似度
similarity_positive = torch.cosine_similarity(anchor, positive, dim=1)
similarity_negative = torch.cosine_similarity(anchor, negative, dim=1)
# 计算 logits
logits = torch.cat([similarity_positive, similarity_negative], dim=0)
logits = logits / temperature
# 创建 labels
labels = torch.tensor([0]).to(anchor.device)
# 计算损失
loss = torch.nn.CrossEntropyLoss()(logits.unsqueeze(0), labels)
return loss
# 4. 训练循环
def train(model, dataloader, optimizer, epochs=10):
model.train()
for epoch in range(epochs):
for anchor, positive, negative in dataloader:
# 将数据移动到 GPU (如果可用)
anchor = anchor.to(device)
positive = positive.to(device)
negative = negative.to(device)
# 计算 embeddings
anchor_embedding = model(anchor)
positive_embedding = model(positive)
negative_embedding = model(negative)
# 计算损失
loss = info_nce_loss(anchor_embedding, positive_embedding, negative_embedding)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch: {epoch}, Loss: {loss.item()}")
# 5. 准备数据
# data = [("anchor text", "positive text", "negative text"), ...] # 构建正负样本对
# 6. 初始化模型、优化器和数据加载器
model_name = "BAAI/bge-small-en-v1.5" # 或者其他 BGE 模型
model = BGEModel(model_name).to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)
dataset = ContrastiveDataset(data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 7. 开始训练
train(model, dataloader, optimizer)
# 8. 保存微调后的模型
model.model.save_pretrained("fine_tuned_bge_model")
model.tokenizer.save_pretrained("fine_tuned_bge_model")
代码解释:
- ContrastiveDataset: 自定义数据集,用于加载正负样本对。每个样本包含一个 anchor 文本、一个 positive 文本(与 anchor 相似)和一个 negative 文本(与 anchor 不相似)。
- BGEModel: 加载 BGE 模型和 tokenizer,并定义
forward方法,用于计算文本的 embedding。 - info_nce_loss: 计算 InfoNCE Loss,用于训练模型区分相似和不相似的文本。
- train: 训练循环,包括前向传播、损失计算、反向传播和优化。
- 数据准备: 需要根据你的知识库构建正负样本对。
- 模型初始化: 初始化 BGE 模型、AdamW 优化器和 DataLoader。
- 训练: 调用
train函数开始训练。 - 模型保存: 保存微调后的模型和 tokenizer。
注意事项:
- 正负样本构建: 正负样本的质量对微调效果至关重要。可以使用各种方法构建正负样本,例如:
- 同义词替换: 将 anchor 文本中的一些词替换成同义词,作为 positive 样本。
- 随机替换: 将 anchor 文本中的一些词随机替换成其他词,作为 negative 样本。
- BM25 检索: 使用 BM25 检索与 anchor 文本相关的文档,将检索到的文档作为 positive 样本,将其他文档作为 negative 样本。
- 超参数调整: 需要根据实际情况调整超参数,例如学习率、batch size、temperature 等。
- 评估: 使用评估指标(例如,Recall@K)评估微调后的模型效果。
- 硬件资源: 微调需要大量的计算资源,建议使用 GPU 进行训练。
4. 代码实战:Java RAG 示例
下面提供一个简化的Java RAG示例,演示如何使用Sentence-BERT模型进行召回。
import ai.djl.ModelException;
import ai.djl.TranslateException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class JavaRAGExample {
public static void main(String[] args) throws IOException, ModelException, TranslateException {
// 1. 准备知识库
List<String> knowledgeBase = new ArrayList<>();
knowledgeBase.add("Java is a high-level, class-based, object-oriented programming language.");
knowledgeBase.add("Java is designed to have as few implementation dependencies as possible.");
knowledgeBase.add("Java is intended to let application developers write once, run anywhere (WORA).");
knowledgeBase.add("Java is currently one of the most popular programming languages in use.");
// 2. 创建 SentenceBertEmbedder (使用 SentenceBertExample 中的代码)
SentenceBertExample.SentenceTranslator translator = new SentenceBertExample.SentenceTranslator("sentence-transformers/all-mpnet-base-v2");
// 3. 生成知识库的 Embedding
List<float[]> knowledgeBaseEmbeddings = new ArrayList<>();
for (String text : knowledgeBase) {
knowledgeBaseEmbeddings.add(getSentenceEmbedding(text, translator));
}
// 4. 用户查询
String query = "What are the key features of Java?";
// 5. 生成查询的 Embedding
float[] queryEmbedding = getSentenceEmbedding(query, translator);
// 6. 检索最相关的文档
int topK = 2;
List<Integer> relevantDocumentIndices = retrieveTopK(queryEmbedding, knowledgeBaseEmbeddings, topK);
// 7. 打印召回结果
System.out.println("Query: " + query);
System.out.println("Retrieved documents:");
for (int index : relevantDocumentIndices) {
System.out.println("- " + knowledgeBase.get(index));
}
}
// 获取句子 Embedding
private static float[] getSentenceEmbedding(String text, SentenceBertExample.SentenceTranslator translator) throws IOException, ModelException, TranslateException {
// 使用 SentenceBertExample 中的代码,简化起见,这里省略了 Criteria 和 ZooModel 的创建
// 实际应用中,需要创建 Criteria 和 ZooModel,并使用 Predictor 进行推理
// 这里直接调用 translator.processInput 和 translator.processOutput
// 注意: 这只是一个简化示例,实际需要使用 DJL 的完整流程
// 创建一个假的 NDManager,用于 translator
ai.djl.ndarray.NDManager manager = ai.djl.ndarray.NDManager.newBaseManager();
// 创建 TranslatorContext
ai.djl.translate.TranslatorContext ctx = new ai.djl.translate.NoBatchifyTranslatorContext(manager);
// 准备 Translator
translator.prepare(ctx);
// 处理输入
ai.djl.ndarray.NDArray input = translator.processInput(ctx, text);
// 模拟模型输出 (这里假设模型输出已经加载到 NDArray 中)
ai.djl.ndarray.NDArray output = ai.djl.ndarray.NDArrays.randn(manager, new long[]{1, 768}); // 768 是 all-mpnet-base-v2 的 embedding 维度
// 处理输出
float[] embedding = translator.processOutput(ctx, output);
return embedding;
}
// 检索最相关的文档
private static List<Integer> retrieveTopK(float[] queryEmbedding, List<float[]> knowledgeBaseEmbeddings, int topK) {
List<Integer> relevantDocumentIndices = new ArrayList<>();
List<Double> similarities = new ArrayList<>();
// 计算查询 Embedding 与知识库中每个文档 Embedding 的相似度
for (int i = 0; i < knowledgeBaseEmbeddings.size(); i++) {
float[] documentEmbedding = knowledgeBaseEmbeddings.get(i);
double similarity = cosineSimilarity(queryEmbedding, documentEmbedding);
similarities.add(similarity);
}
// 找到最相似的 topK 个文档
for (int i = 0; i < topK; i++) {
int maxIndex = 0;
double maxSimilarity = Double.MIN_VALUE;
for (int j = 0; j < similarities.size(); j++) {
if (similarities.get(j) > maxSimilarity) {
maxSimilarity = similarities.get(j);
maxIndex = j;
}
}
relevantDocumentIndices.add(maxIndex);
similarities.set(maxIndex, Double.MIN_VALUE); // 避免重复选择
}
return relevantDocumentIndices;
}
// Cosine Similarity 计算 (使用 SentenceBertExample 中的代码)
private static double cosineSimilarity(float[] vectorA, float[] vectorB) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += Math.pow(vectorA[i], 2);
normB += Math.pow(vectorB[i], 2);
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
}
代码解释:
- 知识库准备: 创建一个包含多个文档的
knowledgeBase列表。 - SentenceBertEmbedder创建: 创建
SentenceBertExample.SentenceTranslator对象,用于生成 Embedding 向量。 - Embedding 生成: 遍历
knowledgeBase列表,使用getSentenceEmbedding()方法生成每个文档的 Embedding 向量,并将它们存储在knowledgeBaseEmbeddings列表中。 - 用户查询: 定义用户查询
query。 - 查询 Embedding 生成: 使用
getSentenceEmbedding()方法生成查询的 Embedding 向量。 - 检索: 使用
retrieveTopK()方法检索与查询最相关的 topK 个文档。retrieveTopK()方法计算查询 Embedding 与知识库中每个文档 Embedding 的 Cosine 相似度。- 找到相似度最高的 topK 个文档的索引。
- 打印结果: 打印查询和检索到的文档。
注意事项:
- 简化示例: 为了简化代码,示例中省略了
Criteria和ZooModel的创建,直接调用translator.processInput和translator.processOutput方法。 实际应用中,需要使用 DJL 的完整流程。 - NDManager: 创建了一个假的
NDManager对象