详解机器翻译任务中的BLEU
目录
一、
n
n
n 元语法(N-Gram)
n
n
n 元语法(n-gram)是指文本中连续出现的
n
n
n 个词元。当
n
n
n 分别为
1
,
2
,
3
1,2,3
1,2,3 时,n-gram 又叫作 unigram(一元语法)、bigram(二元语法)和 trigram(三元语法)。
n
n
n 元语法模型是基于
n
−
1
n-1
n−1 阶马尔可夫链的一种概率语言模型(即只考虑前
n
−
1
n-1
n−1 个词出现的情况下,后一个词出现的概率):
unigram:
P
(
w
1
,
w
2
,
⋯
,
w
T
)
=
∏
i
=
1
T
P
(
w
i
)
bigram:
P
(
w
1
,
w
2
,
⋯
,
w
T
)
=
P
(
x
1
)
∏
i
=
1
T
−
1
P
(
w
i
+
1
∣
w
i
)
trigram:
P
(
w
1
,
w
2
,
⋯
,
w
T
)
=
P
(
x
1
)
P
(
x
2
∣
x
1
)
∏
i
=
1
T
−
2
P
(
w
i
+
2
∣
w
i
,
w
i
+
1
)
begin{aligned} text{unigram:}quad&P(w_1,w_2,cdots,w_T)=prod_{i=1}^T P(w_i) \ text{bigram:}quad&P(w_1,w_2,cdots,w_T)=P(x_1)prod_{i=1}^{T-1} P(w_{i+1}|w_i) \ text{trigram:}quad&P(w_1,w_2,cdots,w_T)=P(x_1)P(x_2|x_1)prod_{i=1}^{T-2} P(w_{i+2}|w_{i},w_{i+1}) \ end{aligned}
unigram:bigram:trigram:P(w1,w2,⋯,wT)=i=1∏TP(wi)P(w1,w2,⋯,wT)=P(x1)i=1∏T−1P(wi+1∣wi)P(w1,w2,⋯,wT)=P(x1)P(x2∣x1)i=1∏T−2P(wi+2∣wi,wi+1)
二、BLEU(Bilingual Evaluation Understudy)
2.1 BLEU 定义
BLEU(发音与单词 blue 相同) 最早是用于评估机器翻译的结果, 但现在它已经被广泛用于评估许多应用的输出序列的质量。对于预测序列 pred
中的任意
n
n
n 元语法, BLEU 的评估都是这个
n
n
n 元语法是否出现在标签序列 label
中。
BLEU 定义如下:
BLEU
=
exp
(
min
(
0
,
1
−
len(label)
len(pred)
)
)
∏
n
=
1
k
p
n
1
/
2
n
text{BLEU}=expleft(minleft(0,1-frac{text{len(label)}}{text{len(pred)}}right)right)prod_{n=1}^kp_n^{1/2^n}
BLEU=exp(min(0,1−len(pred)len(label)))n=1∏kpn1/2n
其中
len(*)
text{len(*)}
len(*) 代表序列
∗
*
∗ 中的词元个数,
k
k
k 用于匹配最长的
n
n
n 元语法(常取
4
4
4),
p
n
p_n
pn 表示
n
n
n 元语法的精确度。
具体而言,给定 label
:
A
,
B
,
C
,
D
,
E
,
F
A,B,C,D,E,F
A,B,C,D,E,F 和 pred
:
A
,
B
,
B
,
C
,
D
A,B,B,C,D
A,B,B,C,D,取
k
=
3
k=3
k=3。
首先看
p
1
p_1
p1 如何计算。我们先将 pred
中的每个 unigram 都统计出来:
(
A
)
,
(
B
)
,
(
B
)
,
(
C
)
,
(
D
)
(A),(B),(B),(C),(D)
(A),(B),(B),(C),(D),再将 label
中的每个 unigram 都统计出来:
(
A
)
,
(
B
)
,
(
C
)
,
(
D
)
,
(
E
)
,
(
F
)
(A),(B),(C),(D),(E),(F)
(A),(B),(C),(D),(E),(F),然后看它们之间有多少匹配的(不可以重复匹配,即必须保持一一对应的关系)。可以看出一共有
4
4
4 个匹配的,而 pred
中一共有
5
5
5 个 unigram,于是
p
1
=
4
/
5
p_1=4/5
p1=4/5。
再来看
p
2
p_2
p2 如何计算。我们先将 pred
中的每个 bigram 都统计出来:
(
A
,
B
)
,
(
B
,
B
)
,
(
B
,
C
)
,
(
C
,
D
)
(A,B),(B,B),(B,C),(C,D)
(A,B),(B,B),(B,C),(C,D),再将 label
中的每个 bigram 都统计出来:
(
A
,
B
)
,
(
B
,
C
)
,
(
C
,
D
)
,
(
D
,
E
)
,
(
E
,
F
)
(A,B),(B,C),(C,D),(D,E),(E,F)
(A,B),(B,C),(C,D),(D,E),(E,F),然后看它们之间有多少匹配的。可以看出一共有
3
3
3 个匹配的,而 pred
中一共有
4
4
4 个 bigram,于是
p
2
=
3
/
4
p_2=3/4
p2=3/4。
最后看
p
3
p_3
p3 如何计算。我们先将 pred
中的每个 trigram 都统计出来:
(
A
,
B
,
B
)
,
(
B
,
B
,
C
)
,
(
B
,
C
,
D
)
(A,B,B),(B,B,C),(B,C,D)
(A,B,B),(B,B,C),(B,C,D),再将 label
中的每个 trigram 都统计出来:
(
A
,
B
,
C
)
,
(
B
,
C
,
D
)
,
(
C
,
D
,
E
)
,
(
D
,
E
,
F
)
(A,B,C),(B,C,D),(C,D,E),(D,E,F)
(A,B,C),(B,C,D),(C,D,E),(D,E,F),然后看它们之间有多少匹配的。可以看出只有
1
1
1 个匹配,而 pred
中一共有
3
3
3 个 trigram,于是
p
3
=
1
/
3
p_3=1/3
p3=1/3。
因此此例的 BLEU 分数为
BLEU
=
exp
(
min
(
0
,
1
−
6
/
5
)
)
⋅
p
1
1
/
2
⋅
p
2
1
/
4
⋅
p
3
1
/
8
=
e
−
0.2
⋅
(
4
5
)
1
/
2
⋅
(
3
4
)
1
/
4
⋅
(
1
3
)
1
/
8
≈
0.5940
begin{aligned} text{BLEU}&=exp(min(0,1-6/5))cdot p_1^{1/2}cdot p_2^{1/4}cdot p_3^{1/8} \ &=e^{-0.2}cdot left(frac45right)^{1/2}cdot left(frac34right)^{1/4}cdotleft(frac13right)^{1/8} \ &approx0.5940 end{aligned}
BLEU=exp(min(0,1−6/5))⋅p11/2⋅p21/4⋅p31/8=e−0.2⋅(54)1/2⋅(43)1/4⋅(31)1/8≈0.5940
2.2 BLEU 的探讨
根据 BLEU 的定义,当预测序列与标签序列完全相同时,BLEU 的值为
1
1
1。另一方面,由于
e
x
>
0
e^x>0
ex>0 且
p
n
≥
0
p_ngeq0
pn≥0,因此有
BLEU
∈
[
0
,
1
]
text{BLEU}in[0,1]
BLEU∈[0,1]
BLEU 的值越接近
1
1
1,则代表预测效果越好;BLEU 的值越接近
0
0
0,则代表预测效果越差。
此外,由于
n
n
n 元语法越长匹配难度越大, 所以 BLEU 为更长的
n
n
n 元语法的精确度分配更大的权重(固定
a
∈
(
0
,
1
)
ain(0,1)
a∈(0,1),则
a
1
/
2
n
a^{1/2^n}
a1/2n 会随着
n
n
n 的增加而增加)。而且,由于预测序列越短获得的
p
n
p_n
pn 值越高,所以系数
exp
(
⋅
)
exp(cdot)
exp(⋅) 这一项用于惩罚较短的预测序列。
2.3 BLEU 的简单实现
import math
from collections import Counter
def bleu(label, pred, k=4):
# 我们假设输入的label和pred都已经进行了分词
score = math.exp(min(0, 1 - len(label) / len(pred)))
for n in range(1, k + 1):
# 使用哈希表用来存放label中所有的n-gram
hashtable = Counter([' '.join(label[i:i + n]) for i in range(len(label) - n + 1)])
# 匹配成功的个数
num_matches = 0
for i in range(len(pred) - n + 1):
ngram = ' '.join(pred[i:i + n])
if ngram in hashtable and hashtable[ngram] > 0:
num_matches += 1
hashtable[ngram] -= 1
score *= math.pow(num_matches / (len(pred) - n + 1), math.pow(0.5, n))
return score
例如:
label = 'A B C D E F'
pred = 'A B B C D'
for i in range(4):
print(bleu(label.split(), pred.split(), k=i + 1))
# 0.7322950476607851
# 0.6814773296495302
# 0.5940339360503315
# 0.0