JAVA RAG 召回结果不相关?Embedding 模型选择对比与优化

Java RAG 召回结果不相关?Embedding 模型选择对比与优化

大家好,今天我们来深入探讨一个在Java RAG(Retrieval-Augmented Generation,检索增强生成)应用中经常遇到的问题:召回结果不相关。RAG的核心在于从知识库中检索相关文档,并将其作为上下文提供给生成模型,以提高生成结果的准确性和相关性。如果召回阶段出了问题,后续的生成质量自然会受到影响。

本次讲座将围绕以下几个方面展开:

  1. RAG 流程回顾与问题诊断: 简要回顾RAG流程,并详细分析召回结果不相关的常见原因。
  2. Embedding 模型选择: 对比几种常用的Embedding模型,包括其原理、优缺点以及适用场景,并通过代码示例展示如何在Java RAG应用中使用它们。
  3. Embedding 模型优化: 探讨优化Embedding模型效果的各种策略,包括数据预处理、微调技术以及向量索引的选择。
  4. 代码实战:Java RAG 示例: 提供一个基于Java的RAG示例,并演示如何通过调整Embedding模型来改善召回结果。
  5. 评估指标与监控: 介绍评估召回效果的常用指标,并讨论如何在生产环境中监控RAG系统的性能。

1. RAG 流程回顾与问题诊断

RAG流程通常包含以下几个步骤:

  1. 数据准备: 将知识库中的文档分割成块(chunks)。Chunk的大小会影响召回效果,过小可能缺乏上下文信息,过大则可能包含无关信息。
  2. Embedding 生成: 使用Embedding模型将每个Chunk转换为向量表示。这些向量将被存储在向量数据库中。
  3. 检索: 接收用户查询,将其转换为向量,并在向量数据库中搜索与查询向量最相似的Chunk向量。
  4. 生成: 将检索到的Chunk作为上下文提供给生成模型(例如,LLM),生成最终的答案。

召回结果不相关的原因有很多,常见的包括:

  • Chunk质量差: Chunk分割方式不合理,导致Chunk包含过多噪声或缺乏关键信息。
  • Embedding模型不适合: 选择的Embedding模型无法准确捕捉文本的语义信息,导致相似的文本向量距离较远,不相似的文本向量距离较近。
  • 向量索引不佳: 向量索引的选择会影响检索效率和准确性。不合适的索引可能导致无法召回最相关的Chunk。
  • 查询理解偏差: 查询本身存在歧义或表达不清,导致Embedding模型无法准确理解查询意图。
  • 数据偏差: 知识库中存在偏差,导致模型学习到的向量表示无法泛化到新的查询。

2. Embedding 模型选择

选择合适的Embedding模型是提高RAG召回效果的关键。以下是一些常用的Embedding模型:

模型名称 优点 缺点 适用场景
Word2Vec/GloVe 训练速度快,计算资源消耗低,易于部署。 无法捕捉上下文信息,对一词多义处理效果差,无法处理OOV(Out-of-Vocabulary)问题。 适用于对性能要求较高,数据量较小,对语义理解要求不高的场景。例如,简单的文本分类任务。
FastText 速度快,能够处理OOV问题,对 morphologically rich 的语言有较好的效果。 无法捕捉上下文信息,对一词多义处理效果差。 适用于需要处理OOV问题,并且对速度有较高要求的场景。例如,文本分类,词性标注。
Sentence-BERT 能够生成句子的embedding,考虑了上下文信息,在语义相似度任务上表现优秀。 训练和推理速度相对较慢,资源消耗较高。 适用于对语义理解要求较高,需要生成句子级别embedding的场景。例如,语义搜索,文本聚类,问答系统。
OpenAI Embeddings (e.g., text-embedding-ada-002) 易于使用,性能良好,支持多种语言,可以处理长文本。 需要通过API调用,存在一定的成本,可能存在隐私问题。 适用于对性能要求较高,需要处理长文本,并且能够接受API调用的场景。例如,知识库问答,语义搜索。
BGE (BAAI General Embedding) 在多个benchmark上取得了领先的性能,支持对比学习,对中文文本有较好的效果。 模型较大,训练和推理速度相对较慢。 适用于对性能要求较高,需要处理中文文本,并且希望获得最佳性能的场景。例如,知识库问答,语义搜索。

2.1 Word2Vec 示例 (使用 Deeplearning4j)

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;
import java.io.IOException;

public class Word2VecExample {

    public static void main(String[] args) throws Exception {
        // 1. 加载文本数据
        String filePath = "your_text_data.txt"; // 替换成你的文本文件路径
        SentenceIterator iter = new BasicLineIterator(filePath);

        // 2. 配置 Tokenizer
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        // 3. 构建 Word2Vec 模型
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(5) // 设置最小词频
                .iterations(1) // 设置迭代次数
                .layerSize(100) // 设置词向量维度
                .seed(42) // 设置随机种子
                .windowSize(5) // 设置窗口大小
                .iterate(iter) // 设置 SentenceIterator
                .tokenizerFactory(t) // 设置 TokenizerFactory
                .build();

        // 4. 训练模型
        System.out.println("Building model....");
        vec.fit();

        // 5. 保存模型
        WordVectorSerializer.writeWord2VecModel(vec, "word2vec.model");

        // 6. 使用模型
        INDArray wordVector = vec.getWordVectorMatrix("example"); // 获取 "example" 的词向量
        System.out.println("Word vector for 'example': " + wordVector);

        double similarity = vec.similarity("example", "sample"); // 计算 "example" 和 "sample" 的相似度
        System.out.println("Similarity between 'example' and 'sample': " + similarity);

        // 7. 计算句子向量(简单平均)
        String sentence = "This is an example sentence.";
        INDArray sentenceVector = getSentenceVector(sentence, vec, t);
        System.out.println("Sentence vector: " + sentenceVector);
    }

    // 计算句子向量(简单平均)
    public static INDArray getSentenceVector(String sentence, Word2Vec vec, TokenizerFactory t) {
        t.create(sentence).getTokens().stream().filter(vec::hasWord).toList();
        List<String> tokens = t.create(sentence).getTokens().stream().filter(vec::hasWord).toList();
        if (tokens.isEmpty()) {
            return Nd4j.zeros(vec.getLayerSize());
        }

        INDArray sentenceVector = Nd4j.zeros(vec.getLayerSize());
        for (String token : tokens) {
            sentenceVector.addi(vec.getWordVectorMatrix(token));
        }
        return sentenceVector.divi(tokens.size());
    }
}

代码解释:

  1. 数据加载: 使用 BasicLineIterator 从文本文件中逐行读取句子。
  2. Tokenizer配置: 使用 DefaultTokenizerFactory 进行分词,并使用 CommonPreprocessor 进行预处理(例如,转换为小写,去除标点符号)。
  3. Word2Vec模型构建: 使用 Word2Vec.Builder 配置 Word2Vec 模型,包括最小词频、迭代次数、词向量维度、随机种子和窗口大小。
  4. 模型训练: 使用 vec.fit() 训练 Word2Vec 模型。
  5. 模型保存: 使用 WordVectorSerializer.writeWord2VecModel() 将训练好的模型保存到文件中。
  6. 模型使用:
    • vec.getWordVectorMatrix("example"): 获取单词 "example" 的词向量。
    • vec.similarity("example", "sample"): 计算单词 "example" 和 "sample" 之间的相似度。
    • getSentenceVector(): 计算句子的向量表示(简单平均所有单词的向量)。

2.2 Sentence-BERT 示例 (使用 Hugging Face Transformers Java API)

import ai.djl.ModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

public class SentenceBertExample {

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        String modelName = "sentence-transformers/all-mpnet-base-v2"; // 选择 Sentence-BERT 模型

        // 1. 构建 Criteria
        Criteria<String, float[]> criteria = Criteria.builder()
                .setTypes(String.class, float[].class)
                .optModelPath(Paths.get("models")) // 可选:指定模型下载路径
                .optModelName(modelName)
                .optTranslator(new SentenceTranslator(modelName))
                .optEngine("PyTorch") // 确保安装了 PyTorch 引擎
                .build();

        // 2. 加载模型
        try (ZooModel<String, float[]> model = criteria.loadModel()) {
            // 3. 创建 Predictor
            try (Predictor<String, float[]> predictor = model.newPredictor()) {
                // 4. 生成句子 Embedding
                String sentence1 = "This is a sentence.";
                String sentence2 = "This is another sentence.";

                float[] embedding1 = predictor.predict(sentence1);
                float[] embedding2 = predictor.predict(sentence2);

                System.out.println("Embedding for sentence 1: " + Arrays.toString(embedding1));
                System.out.println("Embedding for sentence 2: " + Arrays.toString(embedding2));

                // 5. 计算相似度 (Cosine Similarity)
                double similarity = cosineSimilarity(embedding1, embedding2);
                System.out.println("Similarity between sentence 1 and sentence 2: " + similarity);
            }
        }
    }

    // Cosine Similarity 计算
    private static double cosineSimilarity(float[] vectorA, float[] vectorB) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vectorA.length; i++) {
            dotProduct += vectorA[i] * vectorB[i];
            normA += Math.pow(vectorA[i], 2);
            normB += Math.pow(vectorB[i], 2);
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // 自定义 Translator
    private static class SentenceTranslator implements Translator<String, float[]> {

        private final String modelName;
        private HuggingFaceTokenizer tokenizer;

        public SentenceTranslator(String modelName) {
            this.modelName = modelName;
        }

        @Override
        public void prepare(TranslatorContext ctx) throws IOException {
            tokenizer = HuggingFaceTokenizer.newInstance(modelName);
        }

        @Override
        public NDArray processInput(TranslatorContext ctx, String input) {
            NDManager manager = ctx.getNDManager();
            Encoding encoding = tokenizer.encode(input);
            long[] indices = encoding.getIds();
            NDArray inputIds = manager.create(indices);
            inputIds.setName("input_ids");
            return inputIds.expandDims(0);
        }

        @Override
        public float[] processOutput(TranslatorContext ctx, NDArray output) {
            // Pooling layer (Mean Pooling)
            NDArray inputMask = ctx.getNDManager().ones(new long[]{1, output.getShape()[1]});
            NDArray expandedMask = inputMask.expandDims(-1);
            NDArray maskedOutput = output.mul(expandedMask);
            NDArray sumVectors = maskedOutput.sum(new int[]{1});
            NDArray sumMask = inputMask.sum(new int[]{1}).expandDims(1);
            NDArray pooled = sumVectors.div(sumMask);

            return pooled.toFloatArray();
        }

        @Override
        public Batchifier getBatchifier() {
            return Batchifier.STACK;
        }
    }
}

代码解释:

  1. 依赖: 确保项目中包含 ai.djl 相关的依赖,包括 ai.djl-api, ai.djl-model-zoo, ai.djl-huggingface, ai.djl-pytorch (或其他你选择的引擎).
  2. Criteria构建: 使用 Criteria.builder() 定义模型的加载标准,包括模型类型、模型路径、模型名称和自定义的 Translator
  3. 模型加载: 使用 criteria.loadModel() 加载 Sentence-BERT 模型。
  4. Predictor创建: 使用 model.newPredictor() 创建 Predictor 对象,用于执行推理。
  5. 句子 Embedding 生成: 使用 predictor.predict(sentence) 生成句子的 Embedding 向量。
  6. 相似度计算: 使用 cosineSimilarity() 函数计算两个 Embedding 向量之间的 Cosine 相似度。
  7. 自定义 Translator:
    • SentenceTranslator 类实现了 Translator 接口,用于处理输入和输出。
    • prepare() 方法初始化 HuggingFaceTokenizer
    • processInput() 方法将输入句子转换为模型所需的 NDArray 格式(token IDs)。
    • processOutput() 方法将模型的输出转换为 Embedding 向量(使用 Mean Pooling)。
    • getBatchifier() 方法指定 Batchifier 类型。

注意事项:

  • Hugging Face 模型: 代码中使用的是 sentence-transformers/all-mpnet-base-v2 模型,你也可以选择其他 Sentence-BERT 模型。
  • 引擎选择: optEngine("PyTorch") 指定使用 PyTorch 引擎。你需要确保已经安装了相应的引擎。
  • 依赖管理: 确保你的项目中包含了所有必要的 ai.djl 依赖。
  • 模型下载: 模型会自动下载到 optModelPath 指定的路径下。 首次运行可能需要一些时间下载模型。
  • Pooling策略: processOutput 方法中使用了 Mean Pooling。 其他的Pooling策略(例如,CLS token embedding)也可以尝试。

2.3 OpenAI Embeddings 示例

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.List;

public class OpenAIEmbeddingsExample {

    private static final String API_KEY = "YOUR_OPENAI_API_KEY"; // 替换成你的 OpenAI API Key
    private static final String MODEL_NAME = "text-embedding-ada-002"; // 选择 Embedding 模型

    public static void main(String[] args) throws IOException, InterruptedException {
        String text1 = "This is the first text.";
        String text2 = "This is the second text.";

        List<Double> embedding1 = getEmbedding(text1);
        List<Double> embedding2 = getEmbedding(text2);

        System.out.println("Embedding for text 1: " + embedding1);
        System.out.println("Embedding for text 2: " + embedding2);

        // 计算相似度 (Cosine Similarity)
        double similarity = cosineSimilarity(embedding1, embedding2);
        System.out.println("Similarity between text 1 and text 2: " + similarity);
    }

    // 获取 OpenAI Embedding
    public static List<Double> getEmbedding(String text) throws IOException, InterruptedException {
        String apiUrl = "https://api.openai.com/v1/embeddings";

        // 构建请求体
        String requestBody = String.format("{"input": "%s", "model": "%s"}", text, MODEL_NAME);

        // 创建 HttpClient
        HttpClient client = HttpClient.newHttpClient();

        // 创建 HttpRequest
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(apiUrl))
                .header("Content-Type", "application/json")
                .header("Authorization", "Bearer " + API_KEY)
                .POST(HttpRequest.BodyPublishers.ofString(requestBody))
                .build();

        // 发送请求并获取响应
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());

        // 解析 JSON 响应
        ObjectMapper mapper = new ObjectMapper();
        JsonNode root = mapper.readTree(response.body());

        // 提取 Embedding 向量
        List<Double> embedding = new ArrayList<>();
        JsonNode data = root.get("data").get(0);
        JsonNode embeddingNode = data.get("embedding");
        if (embeddingNode.isArray()) {
            for (JsonNode element : embeddingNode) {
                embedding.add(element.asDouble());
            }
        }

        return embedding;
    }

    // Cosine Similarity 计算
    private static double cosineSimilarity(List<Double> vectorA, List<Double> vectorB) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vectorA.size(); i++) {
            dotProduct += vectorA.get(i) * vectorB.get(i);
            normA += Math.pow(vectorA.get(i), 2);
            normB += Math.pow(vectorB.get(i), 2);
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}

代码解释:

  1. API Key: 替换 YOUR_OPENAI_API_KEY 为你的 OpenAI API Key。
  2. 依赖: 确保项目中包含 com.fasterxml.jackson.databind 依赖 (Jackson JSON 库)。
  3. 构建请求: 使用 HttpClient 构建 HTTP POST 请求,包含 OpenAI API 的 URL、请求头 (Content-Type, Authorization) 和请求体 (JSON 格式的输入文本和模型名称)。
  4. 发送请求: 使用 client.send() 发送请求并获取响应。
  5. 解析响应: 使用 ObjectMapper 解析 JSON 响应,提取 Embedding 向量。
  6. Cosine 相似度计算: 使用 cosineSimilarity() 函数计算两个 Embedding 向量之间的 Cosine 相似度。
  7. 错误处理: 实际应用中,需要添加更完善的错误处理机制,例如处理 API 调用失败、JSON 解析错误等。

注意事项:

  • OpenAI API Key: 你需要拥有一个有效的 OpenAI API Key 才能使用 OpenAI Embeddings API。
  • 模型选择: MODEL_NAME 变量指定了使用的 OpenAI Embedding 模型。 text-embedding-ada-002 是一个常用的模型,你也可以选择其他模型。
  • API 调用限制: OpenAI API 存在调用频率限制,需要注意控制调用频率,避免超出限制。
  • 费用: 使用 OpenAI API 会产生费用,需要了解 OpenAI 的定价策略。
  • 依赖管理: 确保你的项目中包含了 Jackson JSON 库的依赖。

3. Embedding 模型优化

仅仅选择一个合适的Embedding模型是不够的,还需要对其进行优化,以进一步提高召回效果。

  • 数据预处理: 对原始文本进行清洗和预处理,例如去除HTML标签、特殊字符、停用词等。这可以减少噪声,提高Embedding模型的准确性。
  • Chunk优化: 调整Chunk的大小和分割方式。可以尝试固定大小的Chunk,也可以使用基于语义的Chunk分割方法,例如将句子或段落作为Chunk。
  • 微调 (Fine-tuning): 如果知识库的领域比较特殊,可以考虑使用知识库中的数据对Embedding模型进行微调。这可以使模型更好地适应特定领域的语义信息。
    • 对比学习 (Contrastive Learning): 一种常用的微调方法,通过构造正负样本对,训练模型区分相似和不相似的文本。
  • 向量索引选择: 选择合适的向量索引可以提高检索效率和准确性。常用的向量索引包括:
    • Annoy: 适用于高维向量的近似最近邻搜索。
    • HNSW: 一种基于图的索引,能够提供较高的检索精度和效率。
    • Faiss: Facebook AI Similarity Search,提供了多种向量索引算法,适用于大规模向量搜索。

3.1 使用 BGE 模型进行对比学习微调示例 (伪代码)

由于篇幅限制,这里只提供伪代码,展示对比学习微调的思路。实际实现需要使用深度学习框架(例如,PyTorch)。

# 伪代码 (Python)
import torch
from transformers import AutoModel, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

# 1. 定义数据集
class ContrastiveDataset(Dataset):
    def __init__(self, data):
        self.data = data # data: list of tuples (anchor, positive, negative)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 2. 定义模型
class BGEModel(torch.nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def forward(self, text):
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        output = self.model(**encoded_input).pooler_output # 或者使用 mean pooling
        return output

# 3. 定义损失函数 (InfoNCE Loss)
def info_nce_loss(anchor, positive, negative, temperature=0.05):
    # 计算相似度
    similarity_positive = torch.cosine_similarity(anchor, positive, dim=1)
    similarity_negative = torch.cosine_similarity(anchor, negative, dim=1)

    # 计算 logits
    logits = torch.cat([similarity_positive, similarity_negative], dim=0)
    logits = logits / temperature

    # 创建 labels
    labels = torch.tensor([0]).to(anchor.device)

    # 计算损失
    loss = torch.nn.CrossEntropyLoss()(logits.unsqueeze(0), labels)
    return loss

# 4. 训练循环
def train(model, dataloader, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        for anchor, positive, negative in dataloader:
            # 将数据移动到 GPU (如果可用)
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)

            # 计算 embeddings
            anchor_embedding = model(anchor)
            positive_embedding = model(positive)
            negative_embedding = model(negative)

            # 计算损失
            loss = info_nce_loss(anchor_embedding, positive_embedding, negative_embedding)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Epoch: {epoch}, Loss: {loss.item()}")

# 5. 准备数据
# data = [("anchor text", "positive text", "negative text"), ...] # 构建正负样本对

# 6. 初始化模型、优化器和数据加载器
model_name = "BAAI/bge-small-en-v1.5" # 或者其他 BGE 模型
model = BGEModel(model_name).to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)
dataset = ContrastiveDataset(data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 7. 开始训练
train(model, dataloader, optimizer)

# 8. 保存微调后的模型
model.model.save_pretrained("fine_tuned_bge_model")
model.tokenizer.save_pretrained("fine_tuned_bge_model")

代码解释:

  1. ContrastiveDataset: 自定义数据集,用于加载正负样本对。每个样本包含一个 anchor 文本、一个 positive 文本(与 anchor 相似)和一个 negative 文本(与 anchor 不相似)。
  2. BGEModel: 加载 BGE 模型和 tokenizer,并定义 forward 方法,用于计算文本的 embedding。
  3. info_nce_loss: 计算 InfoNCE Loss,用于训练模型区分相似和不相似的文本。
  4. train: 训练循环,包括前向传播、损失计算、反向传播和优化。
  5. 数据准备: 需要根据你的知识库构建正负样本对。
  6. 模型初始化: 初始化 BGE 模型、AdamW 优化器和 DataLoader。
  7. 训练: 调用 train 函数开始训练。
  8. 模型保存: 保存微调后的模型和 tokenizer。

注意事项:

  • 正负样本构建: 正负样本的质量对微调效果至关重要。可以使用各种方法构建正负样本,例如:
    • 同义词替换: 将 anchor 文本中的一些词替换成同义词,作为 positive 样本。
    • 随机替换: 将 anchor 文本中的一些词随机替换成其他词,作为 negative 样本。
    • BM25 检索: 使用 BM25 检索与 anchor 文本相关的文档,将检索到的文档作为 positive 样本,将其他文档作为 negative 样本。
  • 超参数调整: 需要根据实际情况调整超参数,例如学习率、batch size、temperature 等。
  • 评估: 使用评估指标(例如,Recall@K)评估微调后的模型效果。
  • 硬件资源: 微调需要大量的计算资源,建议使用 GPU 进行训练。

4. 代码实战:Java RAG 示例

下面提供一个简化的Java RAG示例,演示如何使用Sentence-BERT模型进行召回。

import ai.djl.ModelException;
import ai.djl.TranslateException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class JavaRAGExample {

    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        // 1. 准备知识库
        List<String> knowledgeBase = new ArrayList<>();
        knowledgeBase.add("Java is a high-level, class-based, object-oriented programming language.");
        knowledgeBase.add("Java is designed to have as few implementation dependencies as possible.");
        knowledgeBase.add("Java is intended to let application developers write once, run anywhere (WORA).");
        knowledgeBase.add("Java is currently one of the most popular programming languages in use.");

        // 2. 创建 SentenceBertEmbedder (使用 SentenceBertExample 中的代码)
        SentenceBertExample.SentenceTranslator translator = new SentenceBertExample.SentenceTranslator("sentence-transformers/all-mpnet-base-v2");

        // 3. 生成知识库的 Embedding
        List<float[]> knowledgeBaseEmbeddings = new ArrayList<>();
        for (String text : knowledgeBase) {
            knowledgeBaseEmbeddings.add(getSentenceEmbedding(text, translator));
        }

        // 4. 用户查询
        String query = "What are the key features of Java?";

        // 5. 生成查询的 Embedding
        float[] queryEmbedding = getSentenceEmbedding(query, translator);

        // 6. 检索最相关的文档
        int topK = 2;
        List<Integer> relevantDocumentIndices = retrieveTopK(queryEmbedding, knowledgeBaseEmbeddings, topK);

        // 7. 打印召回结果
        System.out.println("Query: " + query);
        System.out.println("Retrieved documents:");
        for (int index : relevantDocumentIndices) {
            System.out.println("- " + knowledgeBase.get(index));
        }
    }

    // 获取句子 Embedding
    private static float[] getSentenceEmbedding(String text, SentenceBertExample.SentenceTranslator translator) throws IOException, ModelException, TranslateException {
        // 使用 SentenceBertExample 中的代码,简化起见,这里省略了 Criteria 和 ZooModel 的创建
        // 实际应用中,需要创建 Criteria 和 ZooModel,并使用 Predictor 进行推理
        // 这里直接调用 translator.processInput 和 translator.processOutput
        // 注意: 这只是一个简化示例,实际需要使用 DJL 的完整流程

        // 创建一个假的 NDManager,用于 translator
        ai.djl.ndarray.NDManager manager = ai.djl.ndarray.NDManager.newBaseManager();

        // 创建 TranslatorContext
        ai.djl.translate.TranslatorContext ctx = new ai.djl.translate.NoBatchifyTranslatorContext(manager);

        // 准备 Translator
        translator.prepare(ctx);

        // 处理输入
        ai.djl.ndarray.NDArray input = translator.processInput(ctx, text);

        // 模拟模型输出 (这里假设模型输出已经加载到 NDArray 中)
        ai.djl.ndarray.NDArray output = ai.djl.ndarray.NDArrays.randn(manager, new long[]{1, 768}); // 768 是 all-mpnet-base-v2 的 embedding 维度

        // 处理输出
        float[] embedding = translator.processOutput(ctx, output);

        return embedding;
    }

    // 检索最相关的文档
    private static List<Integer> retrieveTopK(float[] queryEmbedding, List<float[]> knowledgeBaseEmbeddings, int topK) {
        List<Integer> relevantDocumentIndices = new ArrayList<>();
        List<Double> similarities = new ArrayList<>();

        // 计算查询 Embedding 与知识库中每个文档 Embedding 的相似度
        for (int i = 0; i < knowledgeBaseEmbeddings.size(); i++) {
            float[] documentEmbedding = knowledgeBaseEmbeddings.get(i);
            double similarity = cosineSimilarity(queryEmbedding, documentEmbedding);
            similarities.add(similarity);
        }

        // 找到最相似的 topK 个文档
        for (int i = 0; i < topK; i++) {
            int maxIndex = 0;
            double maxSimilarity = Double.MIN_VALUE;
            for (int j = 0; j < similarities.size(); j++) {
                if (similarities.get(j) > maxSimilarity) {
                    maxSimilarity = similarities.get(j);
                    maxIndex = j;
                }
            }
            relevantDocumentIndices.add(maxIndex);
            similarities.set(maxIndex, Double.MIN_VALUE); // 避免重复选择
        }

        return relevantDocumentIndices;
    }

    // Cosine Similarity 计算 (使用 SentenceBertExample 中的代码)
    private static double cosineSimilarity(float[] vectorA, float[] vectorB) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vectorA.length; i++) {
            dotProduct += vectorA[i] * vectorB[i];
            normA += Math.pow(vectorA[i], 2);
            normB += Math.pow(vectorB[i], 2);
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}

代码解释:

  1. 知识库准备: 创建一个包含多个文档的 knowledgeBase 列表。
  2. SentenceBertEmbedder创建: 创建 SentenceBertExample.SentenceTranslator 对象,用于生成 Embedding 向量。
  3. Embedding 生成: 遍历 knowledgeBase 列表,使用 getSentenceEmbedding() 方法生成每个文档的 Embedding 向量,并将它们存储在 knowledgeBaseEmbeddings 列表中。
  4. 用户查询: 定义用户查询 query
  5. 查询 Embedding 生成: 使用 getSentenceEmbedding() 方法生成查询的 Embedding 向量。
  6. 检索: 使用 retrieveTopK() 方法检索与查询最相关的 topK 个文档。
    • retrieveTopK() 方法计算查询 Embedding 与知识库中每个文档 Embedding 的 Cosine 相似度。
    • 找到相似度最高的 topK 个文档的索引。
  7. 打印结果: 打印查询和检索到的文档。

注意事项:

  • 简化示例: 为了简化代码,示例中省略了 CriteriaZooModel 的创建,直接调用 translator.processInputtranslator.processOutput 方法。 实际应用中,需要使用 DJL 的完整流程。
  • NDManager: 创建了一个假的 NDManager 对象

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注