JAVA 构建可插拔召回链路:Embedding 模型检索融合与扩展
大家好,今天我们来探讨如何在 Java 中构建一个可插拔的召回链路,重点在于支持不同 Embedding 模型的检索融合与扩展。召回是推荐系统和搜索系统中的关键环节,它的目标是从海量数据中快速筛选出与用户 query 或者用户画像相关的候选集。传统的基于规则或者关键词的召回方法已经难以满足复杂业务的需求,而基于 Embedding 的向量检索则能够更好地捕捉语义信息,提高召回的准确率。
一、召回链路的核心组件与设计原则
一个完整的召回链路通常包含以下几个核心组件:
- Embedding 模型服务: 负责将文本、图像、视频等各种类型的数据转换为向量表示。这部分通常独立部署,提供 API 接口。
- 向量索引: 用于存储 Embedding 向量,并支持高效的相似度检索。常用的向量索引包括 Faiss、Annoy、HNSW 等。
- 检索服务: 接收查询请求,调用 Embedding 模型服务获取 query 的向量表示,然后在向量索引中进行检索,返回相似的候选集。
- 融合策略: 如果使用多个 Embedding 模型,需要定义融合策略,将不同模型召回的结果进行合并和排序。
- 数据预处理: 对原始数据进行清洗、转换和增强,使其更适合 Embedding 模型的训练和使用。
- 评估与监控: 对召回链路的效果进行评估,并监控其性能指标,及时发现和解决问题。
在设计可插拔的召回链路时,我们需要遵循以下几个原则:
- 模块化: 将不同的组件解耦,使其可以独立开发、测试和部署。
- 可配置化: 允许通过配置文件或者 API 来调整组件的行为,例如选择不同的 Embedding 模型、向量索引类型、融合策略等。
- 可扩展性: 方便地添加新的 Embedding 模型、向量索引类型和融合策略,以适应业务的变化。
- 高性能: 保证召回链路的响应速度和吞吐量,满足在线服务的需求。
二、JAVA 实现可插拔的 Embedding 模型服务
首先,我们定义一个 EmbeddingModel 接口,用于抽象不同的 Embedding 模型:
/**
 * Abstraction over a text embedding model, so that callers can swap model
 * implementations without changing their own code.
 *
 * <p>The implementations shown alongside this interface return {@code null}
 * from {@link #getEmbedding(String)} when the embedding cannot be produced,
 * so callers should check for {@code null}.
 */
public interface EmbeddingModel {
/**
* Returns the embedding vector for the given text.
* @param text the input text
* @return the embedding vector
*/
float[] getEmbedding(String text);
/**
* Returns the name of the model.
* @return the model name
*/
String getModelName();
}
然后,我们可以实现不同的 EmbeddingModel,例如基于 Sentence Transformers 的实现:
import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link EmbeddingModel} backed by a DJL model loaded from a model URL.
 *
 * <p>The instance owns a {@code ZooModel} and one {@code Predictor}. Both hold
 * resources, so call {@link #close()} (or use try-with-resources) when the
 * model is no longer needed.
 *
 * <p>NOTE(review): a single predictor is shared by every call to
 * {@link #getEmbedding(String)}; DJL predictors are presumably not safe for
 * concurrent use - confirm before sharing one instance across threads.
 */
public class SentenceTransformerEmbeddingModel implements EmbeddingModel, AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(SentenceTransformerEmbeddingModel.class);

    private final String modelName;
    private final Criteria<String, float[]> criteria;
    private ZooModel<String, float[]> model;
    private Predictor<String, float[]> predictor;

    /**
     * Loads the model and creates the predictor used for all later calls.
     *
     * @param modelName name reported by {@link #getModelName()}
     * @param modelUrl location the model is loaded from
     * @throws ModelException if the model cannot be found or is malformed
     * @throws IOException if the model files cannot be read
     */
    public SentenceTransformerEmbeddingModel(String modelName, String modelUrl)
            throws ModelException, IOException, MalformedModelException {
        this.modelName = modelName;
        this.criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.SENTENCE_ENCODING)
                        .setTypes(String.class, float[].class)
                        .optModelUrls(modelUrl)
                        .optOption("embeddingLayer", "bert_pooler")
                        .build();
        try {
            this.model = criteria.loadModel();
            this.predictor = model.newPredictor();
        } catch (Exception e) {
            logger.error("Error loading model", e);
            // Do not leak a loaded model when predictor creation fails.
            releaseResources();
            throw e;
        }
    }

    /**
     * Computes the embedding of {@code text}.
     *
     * @param text the input text
     * @return the embedding vector, or {@code null} when {@code text} is
     *     {@code null} or inference fails (the failure is logged)
     * @throws IllegalStateException if this model has already been closed
     */
    @Override
    public float[] getEmbedding(String text) {
        if (predictor == null) {
            throw new IllegalStateException("Embedding model has been closed: " + modelName);
        }
        if (text == null) {
            logger.warn("Cannot generate embedding for null text");
            return null;
        }
        try {
            return predictor.predict(text);
        } catch (TranslateException e) {
            logger.error("Error generating embedding", e);
            return null;
        }
    }

    @Override
    public String getModelName() {
        return modelName;
    }

    /** Releases the predictor and the model. Calling it more than once is harmless. */
    @Override
    public void close() {
        releaseResources();
    }

    /** Closes whichever of predictor/model exist; private so the constructor can call it safely. */
    private void releaseResources() {
        if (predictor != null) {
            predictor.close();
            predictor = null;
        }
        if (model != null) {
            model.close();
            model = null;
        }
    }

    public static void main(String[] args) throws ModelException, IOException, MalformedModelException {
        try (SentenceTransformerEmbeddingModel model =
                new SentenceTransformerEmbeddingModel(
                        "all-MiniLM-L6-v2",
                        "https://resources.djl.ai/test-models/sentence-transformer/all-MiniLM-L6-v2.zip")) {
            String text = "This is an example sentence.";
            float[] embedding = model.getEmbedding(text);
            System.out.println("Embedding: " + Arrays.toString(embedding));
        }
    }
}
或者基于 OpenAI API 的实现:
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.embedding.EmbeddingRequest;
import com.theokanning.openai.embedding.Embedding;
import com.theokanning.openai.service.OpenAiService;
import java.util.Arrays;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link EmbeddingModel} that obtains embeddings from the OpenAI embeddings
 * endpoint through {@code OpenAiService}.
 */
public class OpenAIEmbeddingModel implements EmbeddingModel {
    private static final Logger logger = LoggerFactory.getLogger(OpenAIEmbeddingModel.class);

    private final String modelName;
    private final OpenAiService service;

    /**
     * @param modelName OpenAI embedding model identifier sent with every request
     * @param apiKey API key used to create the client
     */
    public OpenAIEmbeddingModel(String modelName, String apiKey) {
        this.modelName = modelName;
        this.service = new OpenAiService(apiKey);
    }

    /**
     * Requests an embedding for {@code text}.
     *
     * @param text the input text
     * @return the first embedding returned by the service converted to
     *     {@code float[]}, or {@code null} when nothing is returned or the
     *     call fails (the failure is logged)
     */
    @Override
    public float[] getEmbedding(String text) {
        EmbeddingRequest embeddingRequest =
                EmbeddingRequest.builder().model(modelName).input(text).build();
        try {
            List<Embedding> embeddings = service.createEmbeddings(embeddingRequest).getData();
            if (embeddings == null || embeddings.isEmpty()) {
                logger.warn("No embeddings returned for text: {}", text);
                return null;
            }
            return toFloatArray(embeddings.get(0).getEmbedding());
        } catch (OpenAiHttpException e) {
            // Pass the exception as the last argument so the stack trace is kept.
            logger.error("Error generating embedding: {}", e.getMessage(), e);
            return null;
        } catch (Exception e) {
            logger.error("Unexpected error during embedding generation", e);
            return null;
        }
    }

    /**
     * Converts boxed doubles to a primitive float array. A stream
     * {@code toArray(Float[]::new)} would produce {@code Float[]}, which is
     * not assignable to {@code float[]}, hence the explicit loop.
     */
    private static float[] toFloatArray(List<Double> values) {
        float[] vector = new float[values.size()];
        for (int i = 0; i < vector.length; i++) {
            vector[i] = values.get(i).floatValue();
        }
        return vector;
    }

    @Override
    public String getModelName() {
        return modelName;
    }

    public static void main(String[] args) {
        OpenAIEmbeddingModel model = new OpenAIEmbeddingModel("text-embedding-ada-002", "YOUR_API_KEY"); // Replace with your actual API key
        String text = "This is an example sentence.";
        float[] embedding = model.getEmbedding(text);
        System.out.println("Embedding: " + Arrays.toString(embedding));
    }
}
注意:你需要替换 YOUR_API_KEY 为你自己的 OpenAI API Key。
为了实现可插拔性,我们可以使用工厂模式来创建 EmbeddingModel 实例:
import java.util.HashMap;
import java.util.Map;
/**
 * Factory that creates {@link EmbeddingModel} instances by type name and
 * caches them per (type, configuration) pair.
 *
 * <p>The cache is a plain static {@code HashMap} with no synchronization, so
 * concurrent callers must coordinate access themselves.
 */
public class EmbeddingModelFactory {
    private static final Map<String, EmbeddingModel> modelCache = new HashMap<>();

    /**
     * Returns the cached model for this type and configuration, creating it on first use.
     *
     * @param modelType {@code "sentence-transformer"} (needs {@code modelName},
     *     {@code modelUrl}) or {@code "openai"} (needs {@code modelName}, {@code apiKey})
     * @param config model settings; must not be {@code null}
     * @return the model instance
     * @throws IllegalArgumentException if the type is unsupported or a required key is missing
     * @throws RuntimeException if constructing the model fails for any other reason
     */
    public static EmbeddingModel getEmbeddingModel(String modelType, Map<String, String> config) {
        java.util.Objects.requireNonNull(config, "config");
        // Sort the entries so equal configurations always produce the same key,
        // independent of the iteration order of the map the caller passed in.
        String cacheKey = modelType + "-" + new java.util.TreeMap<>(config);
        EmbeddingModel cached = modelCache.get(cacheKey);
        if (cached != null) {
            return cached;
        }
        EmbeddingModel model;
        try {
            switch (modelType) {
                case "sentence-transformer":
                    model = new SentenceTransformerEmbeddingModel(
                            require(config, "modelName"), require(config, "modelUrl"));
                    break;
                case "openai":
                    model = new OpenAIEmbeddingModel(
                            require(config, "modelName"), require(config, "apiKey"));
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported embedding model type: " + modelType);
            }
        } catch (IllegalArgumentException e) {
            // Configuration mistakes are reported as-is instead of being buried as a cause.
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("Failed to create embedding model: " + modelType, e);
        }
        modelCache.put(cacheKey, model);
        return model;
    }

    /** Returns the value for {@code key} or fails with a message naming the missing key. */
    private static String require(Map<String, String> config, String key) {
        String value = config.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required config key: " + key);
        }
        return value;
    }
}
使用示例:
import java.util.HashMap;
import java.util.Map;
/** Demo: obtains two embedding models from the factory and prints one embedding from each. */
public class Main {
    public static void main(String[] args) {
        // Local Sentence Transformer model.
        EmbeddingModel localModel =
                EmbeddingModelFactory.getEmbeddingModel("sentence-transformer", sentenceTransformerSettings());
        float[] localVector = localModel.getEmbedding("This is a test sentence.");
        System.out.println("Sentence Transformer Embedding: " + java.util.Arrays.toString(localVector));

        // Remote OpenAI model.
        EmbeddingModel remoteModel = EmbeddingModelFactory.getEmbeddingModel("openai", openAiSettings());
        float[] remoteVector = remoteModel.getEmbedding("This is another test sentence.");
        System.out.println("OpenAI Embedding: " + java.util.Arrays.toString(remoteVector));
    }

    /** Builds the configuration for the Sentence Transformer model. */
    private static Map<String, String> sentenceTransformerSettings() {
        Map<String, String> settings = new HashMap<>();
        settings.put("modelName", "all-MiniLM-L6-v2");
        settings.put(
                "modelUrl",
                "https://resources.djl.ai/test-models/sentence-transformer/all-MiniLM-L6-v2.zip");
        return settings;
    }

    /** Builds the configuration for the OpenAI model. */
    private static Map<String, String> openAiSettings() {
        Map<String, String> settings = new HashMap<>();
        settings.put("modelName", "text-embedding-ada-002");
        settings.put("apiKey", "YOUR_API_KEY"); // replace with your OpenAI API key
        return settings;
    }
}
这样,我们就可以通过配置文件或者代码来选择不同的 Embedding 模型,而无需修改召回链路的核心代码。
三、JAVA 实现可插拔的向量索引
我们定义一个 VectorIndex 接口,用于抽象不同的向量索引类型:
/**
 * Abstraction over a vector index that stores embedding vectors by ID and
 * answers nearest-neighbour queries.
 *
 * <p>Extends {@link AutoCloseable} so an index can be managed with
 * try-with-resources; {@link #close()} declares no checked exception.
 */
public interface VectorIndex extends AutoCloseable {
    /**
     * Adds a vector to the index.
     *
     * @param id ID of the vector
     * @param vector the vector
     */
    void add(String id, float[] vector);

    /**
     * Removes a vector from the index.
     *
     * @param id ID of the vector
     */
    void delete(String id);

    /**
     * Finds the K vectors most similar to the query vector.
     *
     * @param queryVector the query vector
     * @param topK number of vectors to return
     * @return IDs of the similar vectors
     */
    java.util.List<String> search(float[] queryVector, int topK);

    /** Closes the index. */
    @Override
    void close();
}
然后,我们可以实现不同的 VectorIndex,例如基于 Faiss 的实现 (需要引入 Faiss 的 JAVA 绑定):
import com.facebook.jni.HybridData;
import java.util.ArrayList;
import java.util.List;
/**
 * {@link VectorIndex} skeleton that forwards every operation to native code
 * loaded from the {@code faissjni} library.
 */
public class FaissVectorIndex implements VectorIndex {
    static {
        // Load the native library that provides the methods declared below.
        System.loadLibrary("faissjni");
    }

    private HybridData hybridData;
    private final int dimension;
    private final String indexType;

    /**
     * @param dimension vector dimension handed to the native initializer
     * @param indexType index type string handed to the native initializer
     */
    public FaissVectorIndex(int dimension, String indexType) {
        this.dimension = dimension;
        this.indexType = indexType;
        this.hybridData = initHybrid(dimension, indexType);
    }

    // Native entry points; their implementations are not part of this class.
    private native HybridData initHybrid(int dimension, String indexType);

    private native void addNative(String id, float[] vector);

    private native void deleteNative(String id);

    private native String[] searchNative(float[] queryVector, int topK);

    private native void closeNative();

    /** Adds a vector by delegating to the native implementation. */
    @Override
    public void add(String id, float[] vector) {
        addNative(id, vector);
    }

    /** Deletes a vector by delegating to the native implementation. */
    @Override
    public void delete(String id) {
        deleteNative(id);
    }

    /** Runs the native search and returns the IDs as a mutable list. */
    @Override
    public List<String> search(float[] queryVector, int topK) {
        String[] nativeIds = searchNative(queryVector, topK);
        List<String> ids = new ArrayList<>(List.of(nativeIds));
        return ids;
    }

    /** Closes the native index, then resets the hybrid data handle. */
    @Override
    public void close() {
        closeNative();
        this.hybridData.resetNative();
    }
}
注意: Faiss 的 JAVA 绑定需要使用 JNI 技术,这里只是一个示例,你需要根据 Faiss 的官方文档来实现完整的 JNI 代码。 faissjni 是一个假设的库名称,你需要替换成你实际使用的 Faiss JNI 库的名称。 为了编译和运行这个代码,你需要安装 Faiss 并配置好 JNI 环境。
或者基于 HNSW 的实现 (使用 hnswlib 的纯 Java 实现 com.github.jelmerk:hnswlib-core,需要在项目中引入该依赖):
import com.github.jelmerk.knn.DistanceFunction;
import com.github.jelmerk.knn.Index;
import com.github.jelmerk.knn.SearchResult;
import com.github.jelmerk.knn.hnsw.HnswIndex;
import java.util.List;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;
/**
 * {@link VectorIndex} backed by an in-memory HNSW index using Euclidean distance.
 *
 * <p>NOTE(review): the builder and {@code Index} calls are kept as written in
 * the original snippet apart from the added {@code withEf}; verify them
 * against the hnswlib-core version actually on the classpath.
 */
public class HnswVectorIndex implements VectorIndex {
    private final int dimension;
    private final int maxItemCount;
    private final int m;
    private final int efConstruction;
    private final int efSearch;
    private Index<String, float[], Float> index;

    /**
     * @param dimension length every stored and query vector must have
     * @param maxItemCount capacity passed to the index builder
     * @param m HNSW {@code M} parameter
     * @param efConstruction HNSW {@code efConstruction} parameter
     * @param efSearch HNSW search-time {@code ef} parameter
     */
    public HnswVectorIndex(int dimension, int maxItemCount, int m, int efConstruction, int efSearch) {
        this.dimension = dimension;
        this.maxItemCount = maxItemCount;
        this.m = m;
        this.efConstruction = efConstruction;
        this.efSearch = efSearch;
        this.index =
                HnswIndex.newBuilder(getDistanceFunction(), dimension)
                        .withMaxItemCount(maxItemCount)
                        .withM(m)
                        .withEfConstruction(efConstruction)
                        // efSearch was previously stored but never handed to the index.
                        .withEf(efSearch)
                        .withReentrantReadWriteLock(new ReentrantReadWriteLock())
                        .build();
    }

    /** Returns the Euclidean distance over the first {@code dimension} components. */
    private DistanceFunction<float[]> getDistanceFunction() {
        return (u, v) -> {
            double sum = 0;
            for (int i = 0; i < dimension; i++) {
                double diff = u[i] - v[i];
                sum += diff * diff;
            }
            return (float) Math.sqrt(sum);
        };
    }

    /**
     * Rejects vectors whose length differs from the index dimension. Without
     * this check a longer vector is silently truncated by the distance
     * function and a shorter one fails with an index-out-of-bounds error.
     */
    private void requireDimension(float[] vector) {
        if (vector == null || vector.length != dimension) {
            throw new IllegalArgumentException(
                    "Expected a vector of dimension " + dimension + " but got "
                            + (vector == null ? "null" : String.valueOf(vector.length)));
        }
    }

    /**
     * @throws IllegalArgumentException if {@code vector} does not have the index dimension
     */
    @Override
    public void add(String id, float[] vector) {
        requireDimension(vector);
        index.add(id, vector);
    }

    @Override
    public void delete(String id) {
        index.remove(id);
    }

    /**
     * @throws IllegalArgumentException if {@code queryVector} does not have the index dimension
     */
    @Override
    public List<String> search(float[] queryVector, int topK) {
        requireDimension(queryVector);
        List<SearchResult<String, Float>> results = index.findNearest(queryVector, topK);
        return results.stream().map(SearchResult::id).collect(Collectors.toList());
    }

    @Override
    public void close() {
        // Nothing to release: this class holds no resource beyond the in-memory index.
    }

    public static void main(String[] args) {
        int dimension = 3;
        int maxItemCount = 1000;
        int m = 16;
        int efConstruction = 200;
        int efSearch = 50;
        HnswVectorIndex index = new HnswVectorIndex(dimension, maxItemCount, m, efConstruction, efSearch);
        float[] vector1 = {1.0f, 2.0f, 3.0f};
        float[] vector2 = {4.0f, 5.0f, 6.0f};
        float[] vector3 = {7.0f, 8.0f, 9.0f};
        index.add("item1", vector1);
        index.add("item2", vector2);
        index.add("item3", vector3);
        float[] queryVector = {1.1f, 2.1f, 3.1f};
        List<String> searchResults = index.search(queryVector, 2);
        System.out.println("Search results: " + searchResults);
    }
}
同样,我们可以使用工厂模式来创建 VectorIndex 实例:
import java.util.HashMap;
import java.util.Map;
/**
 * Factory that creates {@link VectorIndex} instances by type name and caches
 * them per (type, configuration) pair.
 *
 * <p>The cache is a plain static {@code HashMap} with no synchronization.
 * NOTE(review): cached instances are shared, so a caller that closes an index
 * obtained here also closes it for every later caller with the same
 * configuration - confirm that this is intended.
 */
public class VectorIndexFactory {
    private static final Map<String, VectorIndex> indexCache = new HashMap<>();

    /**
     * Returns the cached index for this type and configuration, creating it on first use.
     *
     * @param indexType {@code "faiss"} (needs {@code dimension}, {@code indexType}) or
     *     {@code "hnsw"} (needs {@code dimension}, {@code maxItemCount}, {@code m},
     *     {@code efConstruction}, {@code efSearch})
     * @param config index settings; must not be {@code null}
     * @return the index instance
     * @throws IllegalArgumentException if the type is unsupported, or a required key is
     *     missing or not an integer
     * @throws RuntimeException if constructing the index fails for any other reason
     */
    public static VectorIndex getVectorIndex(String indexType, Map<String, String> config) {
        java.util.Objects.requireNonNull(config, "config");
        // Sort the entries so equal configurations always map to the same cache key.
        String cacheKey = indexType + "-" + new java.util.TreeMap<>(config);
        VectorIndex cached = indexCache.get(cacheKey);
        if (cached != null) {
            return cached;
        }
        VectorIndex index;
        try {
            switch (indexType) {
                case "faiss":
                    index = new FaissVectorIndex(requireInt(config, "dimension"), require(config, "indexType"));
                    break;
                case "hnsw":
                    index = new HnswVectorIndex(
                            requireInt(config, "dimension"),
                            requireInt(config, "maxItemCount"),
                            requireInt(config, "m"),
                            requireInt(config, "efConstruction"),
                            requireInt(config, "efSearch"));
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported vector index type: " + indexType);
            }
        } catch (IllegalArgumentException e) {
            // Configuration mistakes are reported as-is instead of being buried as a cause.
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("Failed to create vector index: " + indexType, e);
        }
        indexCache.put(cacheKey, index);
        return index;
    }

    /** Returns the value for {@code key} or fails with a message naming the missing key. */
    private static String require(Map<String, String> config, String key) {
        String value = config.get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required config key: " + key);
        }
        return value;
    }

    /** Parses the value for {@code key} as an int, naming the key when it is absent or malformed. */
    private static int requireInt(Map<String, String> config, String key) {
        String raw = require(config, key);
        try {
            return Integer.parseInt(raw.trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Config key '" + key + "' is not an integer: " + raw, e);
        }
    }
}
使用示例:
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/** Demo: builds an HNSW index through the factory, stores two random vectors and queries it. */
public class Main {
    public static void main(String[] args) {
        // HNSW settings.
        Map<String, String> settings = new HashMap<>();
        settings.put("dimension", "128");
        settings.put("maxItemCount", "10000");
        settings.put("m", "16");
        settings.put("efConstruction", "200");
        settings.put("efSearch", "50");

        VectorIndex hnsw = VectorIndexFactory.getVectorIndex("hnsw", settings);

        // Two random 128-dimensional vectors, filled component by component.
        float[] first = new float[128];
        float[] second = new float[128];
        int pos = 0;
        while (pos < 128) {
            first[pos] = (float) Math.random();
            second[pos] = (float) Math.random();
            pos++;
        }
        hnsw.add("item1", first);
        hnsw.add("item2", second);

        // Random query vector.
        float[] probe = new float[128];
        for (int k = 0; k < probe.length; k++) {
            probe[k] = (float) Math.random();
        }
        List<String> hits = hnsw.search(probe, 10);
        System.out.println("Search Results: " + hits);

        // Release the index.
        hnsw.close();
    }
}
四、JAVA 实现可插拔的融合策略
当使用多个 Embedding 模型进行召回时,我们需要一种机制来融合不同模型的结果。 我们可以定义一个 FusionStrategy 接口:
import java.util.List;
import java.util.Map;
/**
 * Strategy for merging the recall results of several models into one list.
 */
public interface FusionStrategy {
/**
* Fuses the recall results of multiple models.
* @param results recall results per model: model name mapped to the list of recalled IDs
* @param weights weight per model: model name mapped to its weight
* @param topK number of final results to return
* @return the fused result list
*/
List<String> fuse(Map<String, List<String>> results, Map<String, Double> weights, int topK);
}
然后,我们可以实现不同的 FusionStrategy,例如基于加权平均的融合策略 (这里的实现实际上是加权投票: 某个 item 每被一个模型召回一次,就把该模型的权重累加到它的得分上,并不做平均):
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
 * Fusion strategy that scores each item by the sum of the weights of the
 * models that recalled it, then returns the highest-scoring items.
 *
 * <p>Despite the class name no averaging is performed: an item's score grows
 * by the model weight each time the item appears in that model's list.
 */
public class WeightedAverageFusionStrategy implements FusionStrategy {
    /**
     * Fuses per-model result lists into one ranking.
     *
     * @param results recalled IDs per model name; {@code null} lists and {@code null} IDs are skipped
     * @param weights weight per model name; models without an entry (or when the map is
     *     {@code null}) get weight 1.0
     * @param topK maximum number of IDs to return
     * @return up to {@code topK} IDs ordered by descending score, ties ordered by ascending ID;
     *     empty when there is nothing to fuse or {@code topK <= 0}
     */
    @Override
    public List<String> fuse(Map<String, List<String>> results, Map<String, Double> weights, int topK) {
        List<String> fusedResults = new ArrayList<>();
        if (results == null || results.isEmpty() || topK <= 0) {
            return fusedResults;
        }

        // Accumulate the total score of every item.
        Map<String, Double> scores = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : results.entrySet()) {
            List<String> items = entry.getValue();
            if (items == null) {
                continue;
            }
            Double configured = weights == null ? null : weights.get(entry.getKey());
            double weight = configured != null ? configured : 1.0; // default weight is 1.0
            for (String item : items) {
                if (item != null) {
                    scores.merge(item, weight, Double::sum);
                }
            }
        }

        // Min-heap of the best topK seen so far: its head is the WORST kept item,
        // so polling on overflow evicts the lowest score, never the highest.
        PriorityQueue<ResultItem> worstFirst = new PriorityQueue<>();
        for (Map.Entry<String, Double> entry : scores.entrySet()) {
            worstFirst.offer(new ResultItem(entry.getKey(), entry.getValue()));
            if (worstFirst.size() > topK) {
                worstFirst.poll();
            }
        }

        // The heap drains worst-to-best; reverse once to get best-first.
        while (!worstFirst.isEmpty()) {
            fusedResults.add(worstFirst.poll().item);
        }
        java.util.Collections.reverse(fusedResults);
        return fusedResults;
    }

    /** Item with its accumulated score; ordered worst-first for use in the min-heap. */
    private static class ResultItem implements Comparable<ResultItem> {
        private final String item;
        private final double score;

        ResultItem(String item, double score) {
            this.item = item;
            this.score = score;
        }

        @Override
        public int compareTo(ResultItem other) {
            // Lower score is "smaller". Among equal scores the larger ID is "smaller",
            // which yields ascending IDs once the drained heap is reversed.
            int byScore = Double.compare(this.score, other.score);
            return byScore != 0 ? byScore : other.item.compareTo(this.item);
        }
    }
}
或者基于 Rank Fusion 的融合策略 (例如 Reciprocal Rank Fusion):
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
/**
 * Reciprocal Rank Fusion: each item receives {@code 1 / (k + rank)} from every
 * model list it appears in (rank starts at 1) and items are returned by
 * descending total score.
 *
 * <p>Model weights are accepted for interface compatibility but are not used.
 */
public class ReciprocalRankFusionStrategy implements FusionStrategy {
    private final double k; // RRF smoothing constant

    /**
     * @param k RRF constant added to every rank
     * @throws IllegalArgumentException if {@code k} is negative or NaN, which would allow a
     *     zero or negative denominator
     */
    public ReciprocalRankFusionStrategy(double k) {
        if (Double.isNaN(k) || k < 0) {
            throw new IllegalArgumentException("RRF parameter k must be >= 0 but was " + k);
        }
        this.k = k;
    }

    /**
     * Fuses per-model ranked lists into one ranking.
     *
     * @param results ranked IDs per model name, best first; {@code null} lists and {@code null}
     *     IDs are skipped
     * @param weights ignored by this strategy
     * @param topK maximum number of IDs to return
     * @return up to {@code topK} IDs ordered by descending RRF score, ties ordered by ascending
     *     ID; empty when there is nothing to fuse or {@code topK <= 0}
     */
    @Override
    public List<String> fuse(Map<String, List<String>> results, Map<String, Double> weights, int topK) {
        List<String> fusedResults = new ArrayList<>();
        if (results == null || results.isEmpty() || topK <= 0) {
            return fusedResults;
        }

        // Accumulate the RRF score of every item across all model lists.
        Map<String, Double> scores = new HashMap<>();
        for (List<String> items : results.values()) {
            if (items == null) {
                continue;
            }
            for (int i = 0; i < items.size(); i++) {
                String item = items.get(i);
                if (item == null) {
                    continue;
                }
                double rank = i + 1; // ranks are 1-based
                scores.merge(item, 1.0 / (k + rank), Double::sum);
            }
        }

        // Min-heap of the best topK seen so far: its head is the WORST kept item,
        // so polling on overflow evicts the lowest score, never the highest.
        PriorityQueue<ResultItem> worstFirst = new PriorityQueue<>();
        for (Map.Entry<String, Double> entry : scores.entrySet()) {
            worstFirst.offer(new ResultItem(entry.getKey(), entry.getValue()));
            if (worstFirst.size() > topK) {
                worstFirst.poll();
            }
        }

        // The heap drains worst-to-best; reverse once to get best-first.
        while (!worstFirst.isEmpty()) {
            fusedResults.add(worstFirst.poll().item);
        }
        java.util.Collections.reverse(fusedResults);
        return fusedResults;
    }

    /** Item with its accumulated score; ordered worst-first for use in the min-heap. */
    private static class ResultItem implements Comparable<ResultItem> {
        private final String item;
        private final double score;

        ResultItem(String item, double score) {
            this.item = item;
            this.score = score;
        }

        @Override
        public int compareTo(ResultItem other) {
            // Lower score is "smaller". Among equal scores the larger ID is "smaller",
            // which yields ascending IDs once the drained heap is reversed.
            int byScore = Double.compare(this.score, other.score);
            return byScore != 0 ? byScore : other.item.compareTo(this.item);
        }
    }

    public static void main(String[] args) {
        // Example
        Map<String, List<String>> results = new HashMap<>();
        List<String> model1Results = List.of("A", "B", "C", "D");
        List<String> model2Results = List.of("B", "E", "A", "F");
        results.put("model1", model1Results);
        results.put("model2", model2Results);
        Map<String, Double> weights = new HashMap<>();
        weights.put("model1", 0.6);
        weights.put("model2", 0.4);
        ReciprocalRankFusionStrategy rrf = new ReciprocalRankFusionStrategy(60);
        List<String> fusedResults = rrf.fuse(results, weights, 5);
        System.out.println("Fused Results: " + fusedResults);
    }
}
同样,我们可以使用工厂模式来创建 FusionStrategy 实例:
import java.util.HashMap;
import java.util.Map;
/**
 * Factory that creates {@link FusionStrategy} instances by type name and
 * caches them per (type, configuration) pair.
 *
 * <p>The cache is a plain static {@code HashMap} with no synchronization, so
 * concurrent callers must coordinate access themselves.
 */
public class FusionStrategyFactory {
    private static final Map<String, FusionStrategy> strategyCache = new HashMap<>();

    /**
     * Returns the cached strategy for this type and configuration, creating it on first use.
     *
     * @param strategyType {@code "weighted-average"} (no settings) or
     *     {@code "reciprocal-rank-fusion"} (needs a numeric {@code k})
     * @param config strategy settings; must not be {@code null}
     * @return the strategy instance
     * @throws IllegalArgumentException if the type is unsupported, or {@code k} is missing or
     *     not a number
     * @throws RuntimeException if constructing the strategy fails for any other reason
     */
    public static FusionStrategy getFusionStrategy(String strategyType, Map<String, String> config) {
        java.util.Objects.requireNonNull(config, "config");
        // Sort the entries so equal configurations always map to the same cache key.
        String cacheKey = strategyType + "-" + new java.util.TreeMap<>(config);
        FusionStrategy cached = strategyCache.get(cacheKey);
        if (cached != null) {
            return cached;
        }
        FusionStrategy strategy;
        try {
            switch (strategyType) {
                case "weighted-average":
                    strategy = new WeightedAverageFusionStrategy();
                    break;
                case "reciprocal-rank-fusion":
                    strategy = new ReciprocalRankFusionStrategy(requireDouble(config, "k"));
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported fusion strategy type: " + strategyType);
            }
        } catch (IllegalArgumentException e) {
            // Configuration mistakes are reported as-is instead of being buried as a cause.
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("Failed to create fusion strategy: " + strategyType, e);
        }
        strategyCache.put(cacheKey, strategy);
        return strategy;
    }

    /** Parses the value for {@code key} as a double, naming the key when it is absent or malformed. */
    private static double requireDouble(Map<String, String> config, String key) {
        String raw = config.get(key);
        if (raw == null) {
            throw new IllegalArgumentException("Missing required config key: " + key);
        }
        try {
            return Double.parseDouble(raw.trim());
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Config key '" + key + "' is not a number: " + raw, e);
        }
    }
}
使用示例:
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/** Demo: fuses the same two model result lists with RRF and with the weighted strategy. */
public class Main {
    public static void main(String[] args) {
        // Recall results per model.
        Map<String, List<String>> recalled = new HashMap<>();
        recalled.put("model1", Arrays.asList("item1", "item2", "item3"));
        recalled.put("model2", Arrays.asList("item2", "item4", "item5"));

        // Weight per model.
        Map<String, Double> modelWeights = new HashMap<>();
        modelWeights.put("model1", 0.6);
        modelWeights.put("model2", 0.4);

        // Reciprocal Rank Fusion with k = 60.
        Map<String, String> rrfSettings = new HashMap<>();
        rrfSettings.put("k", "60");
        FusionStrategy rrf =
                FusionStrategyFactory.getFusionStrategy("reciprocal-rank-fusion", rrfSettings);
        List<String> rrfOutput = rrf.fuse(recalled, modelWeights, 5);
        System.out.println("Fused Results: " + rrfOutput);

        // Weighted strategy; it takes no settings.
        Map<String, String> weightedSettings = new HashMap<>();
        FusionStrategy weighted =
                FusionStrategyFactory.getFusionStrategy("weighted-average", weightedSettings);
        List<String> weightedOutput = weighted.fuse(recalled, modelWeights, 5);
        System.out.println("Weighted Average Results: " + weightedOutput);
    }
}
五、构建完整的召回链路
现在,我们可以将上述组件组合起来,构建一个完整的可插拔召回链路:
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Recall pipeline: embeds the query with every configured model, searches a
 * vector index with each query vector, and fuses the per-model results.
 *
 * <p>Vectors produced by different embedding models generally differ in
 * dimension and are not comparable, so each model should search an index built
 * from its own vectors. Use the constructor taking {@code modelIndexes} for
 * that; the original single-index constructor is kept for compatibility.
 */
public class RecallPipeline {
    private final List<EmbeddingModel> embeddingModels;
    private final VectorIndex vectorIndex; // shared fallback index; may be null
    private final Map<String, VectorIndex> modelIndexes; // model name -> dedicated index
    private final FusionStrategy fusionStrategy;
    private final Map<String, Double> modelWeights; // model weights

    /**
     * Creates a pipeline in which every model searches the same index.
     *
     * @param embeddingModels models used to embed the query
     * @param vectorIndex index searched for every model
     * @param fusionStrategy strategy used to merge the per-model results
     * @param modelWeights weight per model name, passed to the fusion strategy
     */
    public RecallPipeline(List<EmbeddingModel> embeddingModels, VectorIndex vectorIndex, FusionStrategy fusionStrategy, Map<String, Double> modelWeights) {
        this(embeddingModels, vectorIndex, null, fusionStrategy, modelWeights);
    }

    /**
     * Creates a pipeline in which a model may have its own index.
     *
     * @param embeddingModels models used to embed the query
     * @param defaultIndex index used for models without an entry in {@code modelIndexes};
     *     may be {@code null} when every model has a dedicated index
     * @param modelIndexes dedicated index per model name; {@code null} means none
     * @param fusionStrategy strategy used to merge the per-model results
     * @param modelWeights weight per model name, passed to the fusion strategy
     */
    public RecallPipeline(List<EmbeddingModel> embeddingModels, VectorIndex defaultIndex, Map<String, VectorIndex> modelIndexes, FusionStrategy fusionStrategy, Map<String, Double> modelWeights) {
        this.embeddingModels = embeddingModels;
        this.vectorIndex = defaultIndex;
        this.modelIndexes = modelIndexes != null ? modelIndexes : new HashMap<>();
        this.fusionStrategy = fusionStrategy;
        this.modelWeights = modelWeights;
    }

    /**
     * Recalls candidates for a query.
     *
     * @param query the query text
     * @param topK number of candidates requested from each index and from the fusion step
     * @return the fused candidate IDs
     */
    public List<String> recall(String query, int topK) {
        Map<String, List<String>> modelResults = new HashMap<>();
        // Recall with every embedding model.
        for (EmbeddingModel model : embeddingModels) {
            String name = model.getModelName();
            float[] queryVector = model.getEmbedding(query);
            VectorIndex index = modelIndexes.getOrDefault(name, vectorIndex);
            if (queryVector == null) {
                System.err.println("Failed to get embedding from model: " + name);
                modelResults.put(name, new ArrayList<>()); // every model keeps an entry
            } else if (index == null) {
                System.err.println("No vector index configured for model: " + name);
                modelResults.put(name, new ArrayList<>());
            } else {
                modelResults.put(name, index.search(queryVector, topK));
            }
        }
        // Fuse the results of all models.
        return fusionStrategy.fuse(modelResults, modelWeights, topK);
    }
}
使用示例:
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Main {
public static void main(String[] args) {
// 1. 配置 Embedding 模型
Map<String, String> sentenceTransformerConfig = new HashMap<>();
sentenceTransformerConfig.put("modelName", "all-MiniLM-L6-v2");
sentenceTransformerConfig.put("modelUrl", "https://resources.djl.ai/test-models/sentence-transformer/all-MiniLM-L6-v2.zip");
EmbeddingModel sentenceTransformerModel = EmbeddingModelFactory.getEmbeddingModel("sentence-transformer", sentenceTransformerConfig);
Map<String, String> openAIConfig = new HashMap<>();
openAIConfig.put("modelName", "text-embedding-ada-002");
openAIConfig.put("apiKey", "YOUR_API_KEY"); // 替换为你的 OpenAI API Key
EmbeddingModel openAIModel = EmbeddingModelFactory.getEmbeddingModel("openai", openAIConfig);
List<EmbeddingModel> embeddingModels = new ArrayList<>();
embeddingModels.add(sentenceTransformerModel);
embeddingModels.add(openAIModel);
// 2. 配置向量索引
Map<String, String> hnswConfig = new HashMap<>();
hnswConfig.put("dimension", "128"); // 假设 Embedding 维度是 128
hnswConfig.put("maxItemCount", "10000");
hnswConfig.put("m", "16");
hnswConfig.put("efConstruction", "200");
hnswConfig.put("efSearch", "50");
VectorIndex vectorIndex = VectorIndexFactory.getVectorIndex("hnsw", hnswConfig);
// 预先添加一些 item 到索引中
float[] vector1 = new float[128];
float[] vector2 = new float[128];
float[] vector3 = new float[128];
for (int i = 0; i < 128; i++) {
vector1[i] = (float) Math.random();
vector2[i] = (float) Math.random();
vector3[i] = (float) Math.random();
}
vectorIndex.add("item1", vector1);
vectorIndex.add("item2", vector2);
vectorIndex.add("item3", vector3);
// 3. 配置融合策略
Map<String, String> rrfConfig = new HashMap<>();
rrfConfig.put("k", "60");
FusionStrategy fusionStrategy = FusionStrategyFactory.getFusionStrategy("reciprocal-rank