构建自动问答评估系统:BLEU + 语义相似度
大家好,今天我们来聊聊如何构建一个自动问答(QA)系统的评估系统。评估QA系统的好坏,不能只靠人工判断,我们需要一套自动化的、可量化的评估指标来衡量。今天我们将重点讨论两种评估指标:BLEU (Bilingual Evaluation Understudy) 和语义相似度,并将它们结合起来,构建一个更完善的评估体系。
1. 为什么需要自动评估系统?
在QA系统开发过程中,我们需要不断地改进模型。每次改进后,都需要评估新模型的效果,判断改进是否有效。如果完全依赖人工评估,效率低下,且容易受到主观因素影响。自动评估系统可以:
- 提高效率: 快速评估大量问题和答案。
- 客观性: 减少主观偏差,提供更一致的评估结果。
- 可重复性: 方便比较不同模型的性能,进行实验验证。
- 自动化流程: 可以集成到持续集成/持续部署 (CI/CD) 流程中。
2. BLEU 指标
BLEU 是一种广泛应用于机器翻译领域的评估指标,它通过比较模型生成的答案(candidate)与参考答案(reference)之间的n-gram overlap来评估答案的质量。简单来说,BLEU衡量的是candidate答案与reference答案的相似程度。
2.1 BLEU 的核心思想
BLEU 的核心思想是:如果一个模型生成的答案与人工提供的参考答案越接近,那么这个答案的质量就越高。
2.2 BLEU 的计算公式
BLEU 的计算公式如下:
BLEU = BP * exp(Σ (wn * log(pn)))
其中:
- BP (Brevity Penalty): 长度惩罚因子,用于惩罚生成答案过短的情况。
- pn (Precision): n-gram 精确率,表示 candidate 答案中出现在 reference 答案中的 n-gram 的比例。
- wn (Weight): 每个 n-gram 精确率的权重。通常,我们会为不同的 n-gram 设置相同的权重,例如,对于 1-gram 到 4-gram,每个权重为 1/4。
2.3 Brevity Penalty (BP) 的计算
BP 的计算公式如下:
BP =
1, if c > r
exp(1 - r/c), if c <= r
其中:
- c: candidate 答案的长度。
- r: reference 答案的长度。
如果 candidate 答案的长度大于 reference 答案的长度,则 BP 为 1,不进行惩罚。如果 candidate 答案的长度小于 reference 答案的长度,则进行惩罚,candidate 答案越短,惩罚越大。
2.4 n-gram 精确率 (pn) 的计算
n-gram 精确率的计算公式如下:
pn = (number of n-gram in candidate that are also in reference) / (total number of n-gram in candidate)
例如,如果 candidate 答案是 "the cat sat",reference 答案是 "the cat is on the mat",那么:
- 1-gram 精确率 (p1) = 3/3 = 1 ("the", "cat", "sat" 都出现在 reference 答案中)
- 2-gram 精确率 (p2) = 1/2 = 0.5 ("the cat" 出现在 reference 答案中,"cat sat" 没有)
2.5 Java 代码实现 BLEU
下面是一个简单的 Java 代码实现 BLEU 的例子:
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class BLEU {
public static double calculateBLEU(List<String> candidate, List<String> reference, int maxNgram) {
double brevityPenalty = calculateBrevityPenalty(candidate.size(), reference.size());
double sumOfLogs = 0.0;
double weight = 1.0 / maxNgram;
for (int i = 1; i <= maxNgram; i++) {
double precision = calculateNgramPrecision(candidate, reference, i);
sumOfLogs += weight * Math.log(precision);
}
return brevityPenalty * Math.exp(sumOfLogs);
}
private static double calculateBrevityPenalty(int candidateLength, int referenceLength) {
if (candidateLength > referenceLength) {
return 1.0;
} else if (candidateLength == 0) {
return 0.0; // Avoid division by zero
} else {
return Math.exp(1.0 - (double) referenceLength / candidateLength);
}
}
private static double calculateNgramPrecision(List<String> candidate, List<String> reference, int n) {
Map<String, Integer> candidateNgrams = getNgrams(candidate, n);
Map<String, Integer> referenceNgrams = getNgrams(reference, n);
int commonNgrams = 0;
for (String ngram : candidateNgrams.keySet()) {
if (referenceNgrams.containsKey(ngram)) {
commonNgrams += Math.min(candidateNgrams.get(ngram), referenceNgrams.get(ngram));
}
}
int totalCandidateNgrams = 0;
for (int count : candidateNgrams.values()) {
totalCandidateNgrams += count;
}
if (totalCandidateNgrams == 0) {
return 0.0; // Avoid division by zero
}
return (double) commonNgrams / totalCandidateNgrams;
}
private static Map<String, Integer> getNgrams(List<String> tokens, int n) {
Map<String, Integer> ngrams = new HashMap<>();
for (int i = 0; i <= tokens.size() - n; i++) {
StringBuilder ngramBuilder = new StringBuilder();
for (int j = 0; j < n; j++) {
ngramBuilder.append(tokens.get(i + j));
if (j < n - 1) {
ngramBuilder.append(" ");
}
}
String ngram = ngramBuilder.toString();
ngrams.put(ngram, ngrams.getOrDefault(ngram, 0) + 1);
}
return ngrams;
}
public static void main(String[] args) {
List<String> candidate = Arrays.asList("the", "cat", "sat", "on", "the", "mat");
List<String> reference = Arrays.asList("the", "cat", "is", "on", "the", "mat");
double bleuScore = calculateBLEU(candidate, reference, 4);
System.out.println("BLEU score: " + bleuScore);
}
}
2.6 BLEU 的优点和缺点
优点:
- 计算速度快,易于实现。
- 广泛应用于机器翻译和 QA 领域。
缺点:
- 只考虑了 n-gram 的精确率,忽略了召回率。
- 对短答案的惩罚过于严厉。
- 无法捕捉语义信息,可能出现以下情况:生成的答案与参考答案的 n-gram overlap 很低,但语义上非常相似。
3. 语义相似度指标
由于 BLEU 无法捕捉语义信息,我们需要引入语义相似度指标来弥补这一缺陷。语义相似度指标旨在衡量两个句子在语义上的相似程度。
3.1 常见的语义相似度指标
- 余弦相似度 (Cosine Similarity): 将句子表示成向量,然后计算两个向量之间的余弦值。
- 编辑距离 (Edit Distance): 计算将一个句子转换为另一个句子所需的最小编辑操作(插入、删除、替换)次数。
- Word Mover’s Distance (WMD): 基于词嵌入,计算一个句子中的词移动到另一个句子中的词所需的最小距离。
- Sentence-BERT (SBERT): 使用预训练的 BERT 模型生成句子嵌入,然后计算嵌入向量之间的余弦相似度。
3.2 Sentence-BERT (SBERT)
Sentence-BERT (SBERT) 是一种常用的语义相似度计算方法。它使用预训练的 BERT 模型生成句子嵌入,然后计算嵌入向量之间的余弦相似度。SBERT 的优点是可以捕捉句子中的语义信息,并且计算速度较快。
3.3 Java 代码实现 SBERT 语义相似度
要实现 SBERT,我们需要使用一个 Java NLP 库,比如 Deeplearning4j (DL4J) 或者 Hugging Face 的 Transformers 的 Java 版本 (DJL)。这里以 DJL 为例,展示如何计算两个句子的 SBERT 相似度。
首先,你需要添加 DJL 的依赖到你的项目中。如果你使用 Maven,可以在 pom.xml 文件中添加以下依赖:
<dependency>
<groupId>ai.djl.sentencepiece</groupId>
<artifactId>sentencepiece</artifactId>
<version>0.24.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.24.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-auto</artifactId>
<version>2.1.1</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.24.0</version>
</dependency>
<dependency>
<groupId>ai.djl.basicdataset</groupId>
<artifactId>basicdataset</artifactId>
<version>0.24.0</version>
</dependency>
<dependency>
<groupId>ai.djl.modelzoo</groupId>
<artifactId>modelzoo</artifactId>
<version>0.24.0</version>
</dependency>
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>huggingface</artifactId>
<version>0.24.0</version>
</dependency>
然后,可以使用以下代码计算两个句子的 SBERT 相似度:
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.Tokenizer;
import ai.djl.inference.InferenceModel;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
public class SBERT {
private static final String MODEL_NAME = "sentence-transformers/all-mpnet-base-v2";
public static double calculateSBERTSimilarity(String sentence1, String sentence2) throws IOException, ModelNotFoundException, TranslateException {
// Load the model
Criteria<String[], float[][]> criteria = Criteria.builder()
.setTypes(String[].class, float[][].class)
.optModelName(MODEL_NAME)
.optTranslator(new SentenceTranslator())
.optEngine("PyTorch")
.build();
try (ZooModel<String[], float[][]> model = criteria.loadModel()) {
// Prepare input
String[] input = new String[]{sentence1, sentence2};
// Inference
float[][] embeddings = model.newPredictor().predict(input);
// Calculate cosine similarity
return cosineSimilarity(embeddings[0], embeddings[1]);
}
}
private static double cosineSimilarity(float[] vectorA, float[] vectorB) {
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += Math.pow(vectorA[i], 2);
normB += Math.pow(vectorB[i], 2);
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
private static class SentenceTranslator implements Translator<String[], float[][]> {
private Tokenizer tokenizer;
@Override
public void prepare(TranslatorContext ctx) throws IOException {
Path modelDir = Paths.get(ctx.getModelDir().toString(), "sentence-transformers", "all-mpnet-base-v2");
tokenizer = Tokenizer.newInstance(modelDir.toString());
}
@Override
public NDList processInput(TranslatorContext ctx, String[] input) {
List<Encoding> encodings = Arrays.stream(input)
.map(tokenizer::encode)
.collect(Collectors.toList());
long maxLen = encodings.stream()
.mapToLong(Encoding::getLength)
.max()
.orElse(0);
NDArray inputIds = ctx.getNDManager().create(new Shape(input.length, maxLen), ai.djl.dtype.DataType.INT64);
NDArray attentionMask = ctx.getNDManager().create(new Shape(input.length, maxLen), ai.djl.dtype.DataType.INT64);
for (int i = 0; i < encodings.size(); i++) {
Encoding encoding = encodings.get(i);
long length = encoding.getLength();
long[] ids = encoding.getIds();
byte[] mask = encoding.getAttentionMask();
inputIds.set(i, ctx.getNDManager().create(ids));
attentionMask.set(i, ctx.getNDManager().create(mask));
}
return new NDList(inputIds, attentionMask);
}
@Override
public float[][] processOutput(TranslatorContext ctx, NDList list) {
NDArray embeddings = list.get(0).get("mean");
return embeddings.toFloatArray();
}
@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
}
public static void main(String[] args) throws IOException, ModelNotFoundException, TranslateException {
String sentence1 = "The cat sat on the mat.";
String sentence2 = "A cat is sitting on a rug.";
double similarity = calculateSBERTSimilarity(sentence1, sentence2);
System.out.println("SBERT Similarity: " + similarity);
}
}
注意:
- 这段代码需要下载预训练的 SBERT 模型,所以第一次运行可能会比较慢。
- 你可能需要根据你的硬件环境调整 DJL 的配置。
- 这个例子使用了
sentence-transformers/all-mpnet-base-v2模型,你可以根据需要选择其他的 SBERT 模型。
3.4 语义相似度指标的优点和缺点
优点:
- 可以捕捉语义信息,更准确地衡量答案的质量。
- 对答案的长度不敏感。
缺点:
- 计算速度相对较慢,特别是对于复杂的模型。
- 需要大量的训练数据来训练模型。
4. 结合 BLEU 和语义相似度
为了构建一个更完善的 QA 系统评估系统,我们可以将 BLEU 和语义相似度指标结合起来。一种简单的方法是:
综合得分 = α * BLEU + (1 - α) * 语义相似度
其中:
- α: 是一个权重因子,用于控制 BLEU 和语义相似度在综合得分中的比例。可以根据实际情况调整 α 的值。例如,如果更看重答案的准确性,可以设置 α 较高;如果更看重答案的语义相似度,可以设置 α 较低。
4.1 代码示例
public class QAEvaluation {
public static double calculateCombinedScore(List<String> candidate, List<String> reference, String candidateText, String referenceText, int maxNgram, double alpha) throws IOException, ModelNotFoundException, TranslateException {
double bleuScore = BLEU.calculateBLEU(candidate, reference, maxNgram);
double semanticSimilarity = SBERT.calculateSBERTSimilarity(candidateText, referenceText);
return alpha * bleuScore + (1 - alpha) * semanticSimilarity;
}
public static void main(String[] args) throws IOException, ModelNotFoundException, TranslateException {
List<String> candidate = Arrays.asList("the", "cat", "sat", "on", "the", "mat");
List<String> reference = Arrays.asList("the", "cat", "is", "on", "the", "mat");
String candidateText = "the cat sat on the mat";
String referenceText = "the cat is on the mat";
double alpha = 0.5; // Adjust the weight factor as needed
double combinedScore = calculateCombinedScore(candidate, reference, candidateText, referenceText, 4, alpha);
System.out.println("Combined Score: " + combinedScore);
}
}
4.2 如何选择合适的 α 值?
选择合适的 α 值需要根据实际情况进行调整。可以尝试不同的 α 值,然后观察评估结果,选择一个能够反映 QA 系统性能的最佳值。
一种常用的方法是:
- 人工评估: 随机抽取一部分问题和答案,进行人工评估,得到一个人工评估结果。
- 自动评估: 使用不同的 α 值,运行自动评估系统,得到多个自动评估结果。
- 比较: 将自动评估结果与人工评估结果进行比较,选择一个与人工评估结果最接近的 α 值。
5. 其他需要考虑的因素
除了 BLEU 和语义相似度,还有一些其他的因素需要考虑:
- 答案的完整性: 答案是否完整地回答了问题。
- 答案的相关性: 答案是否与问题相关。
- 答案的流畅性: 答案是否流畅自然。
- 答案的正确性: 答案是否正确。
可以将这些因素也纳入评估体系中,构建一个更全面的 QA 系统评估系统。
6. 总结
我们讨论了如何使用 BLEU 和语义相似度指标来构建一个自动问答评估系统。BLEU 衡量 n-gram overlap,速度快但忽略语义;语义相似度指标(如 SBERT)捕捉语义信息,但计算较慢。将两者结合可以提升评估的准确性,并提供了 Java 代码示例方便实践。