Expectation Estimation via Importance Sampling: A Derivation of Sampled Softmax
1. Background
In recommendation retrieval systems, a two-tower model is typically trained with a log-softmax loss. Let $[B]$ denote the mini-batch, $[C]$ the global corpus, and $s(x, y)$ the similarity score between query $x$ and item $y$. The loss function is:
$$
\begin{aligned}
\mathcal{L} &= -\frac{1}{B} \sum_{i \in [B]} \log \frac{e^{s(x_i, y_i)}}{\sum_{j \in [C]} e^{s(x_i, y_j)}} \\
&= -\frac{1}{B} \sum_{i \in [B]} \Big\{ s(x_i, y_i) - \log \sum_{j \in [C]} e^{s(x_i, y_j)} \Big\}
\end{aligned}
$$
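To make the setup concrete, here is a minimal NumPy sketch of this full-corpus loss; the toy sizes, random embeddings, and variable names are assumptions for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (assumptions): B queries, C corpus items, d-dim embeddings
B, C, d = 4, 1000, 16
x = rng.normal(size=(B, d))        # query-tower outputs x_i
y = rng.normal(size=(C, d))        # item-tower outputs y_j
pos = np.arange(B)                 # assume item i is the positive for query i

s = x @ y.T                        # similarity scores s(x_i, y_j), shape (B, C)

# Numerically stable log partition: log sum_j e^{s(x_i, y_j)}
m = s.max(axis=1, keepdims=True)
log_Z = (m + np.log(np.exp(s - m).sum(axis=1, keepdims=True))).squeeze(1)

# L = -(1/B) sum_i { s(x_i, y_i) - log Z_i }
loss = -np.mean(s[np.arange(B), pos] - log_Z)
print(loss)
```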
Taking the gradient of the loss with respect to the model parameters $\theta$:
$$
\begin{aligned}
\nabla_\theta \mathcal{L} &= -\frac{1}{B} \sum_{i \in [B]} \Big\{ \nabla_\theta s(x_i, y_i) - \sum_{j \in [C]} \frac{e^{s(x_i, y_j)}}{\sum_{k \in [C]} e^{s(x_i, y_k)}}\, \nabla_\theta s(x_i, y_j) \Big\} \\
&= -\frac{1}{B} \sum_{i \in [B]} \Big\{ \nabla_\theta s(x_i, y_i) - \sum_{j \in [C]} P(y_j \mid x_i)\, \nabla_\theta s(x_i, y_j) \Big\} \\
&= -\frac{1}{B} \sum_{i \in [B]} \Big\{ \underbrace{\nabla_\theta s(x_i, y_i)}_{\text{part one}} - \underbrace{\mathbb{E}_{P}\big[\nabla_\theta s(x_i, y_j)\big]}_{\text{part two}} \Big\}
\end{aligned}
$$
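The step from the first to the second line relies on the identity that the gradient of the log-partition term with respect to each score is exactly the softmax $P(y_j \mid x_i)$. A quick finite-difference check of this identity on toy scores (a sketch, not part of any training code):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 10
s = rng.normal(size=n)             # toy scores s(x_i, y_j) for one query

def log_Z(s):
    # Numerically stable log sum_k e^{s_k}
    m = s.max()
    return m + np.log(np.exp(s - m).sum())

P = np.exp(s - log_Z(s))           # softmax, i.e. P(y_j | x_i)

# Central finite difference of d logZ / d s_j for each j
eps = 1e-6
grad = np.array([(log_Z(s + eps * np.eye(n)[j]) - log_Z(s - eps * np.eye(n)[j])) / (2 * eps)
                 for j in range(n)])
print(np.abs(P - grad).max())      # close to 0: the softmax is this gradient
```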
Observe that the second part of the gradient is the expectation of $\nabla_\theta s(x_i, y_j)$ under the target distribution $P$. Because the corpus is enormous, evaluating the partition function is prohibitively expensive, so the expectation (and hence the gradient) must be approximated. A common approach is importance sampling: draw a small set of items and use them to approximate the expectation, which yields sampled softmax. This post derives the sampled softmax formula step by step; it is offered for study, and corrections are welcome.
2. Derivation
Let $P$ be the target distribution and $Q$ a proposal distribution. The basic idea of importance sampling is to sample from the easier-to-sample distribution $Q$ instead:
$$
\begin{aligned}
\mathbb{E}_{P}\big[\nabla_\theta s(x_i, y_j)\big] &= \sum_{j \in [C]} P(y_j \mid x_i)\, \nabla_\theta s(x_i, y_j) \\
&= \sum_{j \in [C]} \frac{P(y_j \mid x_i)}{Q(y_j \mid x_i)}\, Q(y_j \mid x_i)\, \nabla_\theta s(x_i, y_j) \\
&= \mathbb{E}_{Q}\Big[\frac{P(y_j \mid x_i)}{Q(y_j \mid x_i)}\, \nabla_\theta s(x_i, y_j)\Big] \\
&\approx \frac{1}{B} \sum_{j \in [B]} \frac{P(y_j \mid x_i)}{Q(y_j \mid x_i)}\, \nabla_\theta s(x_i, y_j)
\end{aligned}
$$
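A minimal sketch of this Monte Carlo estimator, assuming a uniform proposal $Q$ and a random scalar stand-in for each $\nabla_\theta s(x_i, y_j)$ (both toy choices):

```python
import numpy as np

rng = np.random.default_rng(2)
C = 1000
s = rng.normal(size=C)                  # scores for one query over the corpus
g = rng.normal(size=C)                  # scalar stand-in for grad s(x_i, y_j)

P = np.exp(s - s.max())
P /= P.sum()                            # target distribution P(y_j | x_i)
Q = np.full(C, 1.0 / C)                 # proposal: uniform over the corpus

exact = (P * g).sum()                   # E_P[g]

B = 5000
idx = rng.choice(C, size=B, p=Q)        # draw B items from Q
approx = np.mean(P[idx] / Q[idx] * g[idx])   # (1/B) sum (P/Q) g
print(exact, approx)
```

Note that this estimator still evaluates the fully normalized $P$, partition function and all; removing that dependence is exactly what the remaining steps accomplish.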
Here $\frac{P(y_j \mid x_i)}{Q(y_j \mid x_i)}$ is the importance weight. The closer $Q$ is to $P$, the closer the weights are to 1 and the lower the variance of the estimator. In the last step we draw $B$ samples from $Q$ and average to approximate the expectation.
With this approximation in hand, we substitute the softmax form of $P(y_j \mid x_i)$:
$$
\begin{aligned}
\mathbb{E}_{P}\big[\nabla_\theta s(x_i, y_j)\big] &\approx \frac{1}{B} \sum_{j \in [B]} \frac{P(y_j \mid x_i)}{Q(y_j \mid x_i)}\, \nabla_\theta s(x_i, y_j) \\
&= \frac{1}{B} \sum_{j \in [B]} \frac{e^{s(x_i, y_j)}}{Q(y_j \mid x_i) \sum_{k \in [C]} e^{s(x_i, y_k)}}\, \nabla_\theta s(x_i, y_j) \\
&= \frac{1}{B} \sum_{j \in [B]} \frac{e^{s(x_i, y_j) - \ln Q(y_j \mid x_i)}}{\sum_{k \in [C]} e^{s(x_i, y_k)}}\, \nabla_\theta s(x_i, y_j)
\end{aligned}
$$
The problem is that computing $P(y_j \mid x_i)$ brings the partition function back in, so the cost is still prohibitive. The fix is to also rewrite the partition function as an expectation and approximate it with the same $B$ samples:
$$
\begin{aligned}
\sum_{k \in [C]} e^{s(x_i, y_k)} &= \sum_{k \in [C]} Q(y_k \mid x_i) \cdot \frac{1}{Q(y_k \mid x_i)}\, e^{s(x_i, y_k)} \\
&= \mathbb{E}_{Q}\Big[\frac{e^{s(x_i, y_k)}}{Q(y_k \mid x_i)}\Big] \\
&= \mathbb{E}_{Q}\big[e^{s(x_i, y_k) - \ln Q(y_k \mid x_i)}\big] \\
&\approx \frac{1}{B} \sum_{k \in [B]} e^{s(x_i, y_k) - \ln Q(y_k \mid x_i)}
\end{aligned}
$$
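A quick numerical sanity check of this partition-function estimator, again with an assumed uniform proposal:

```python
import numpy as np

rng = np.random.default_rng(3)
C = 1000
s = rng.normal(size=C)                  # scores for one query

Q = np.full(C, 1.0 / C)                 # uniform proposal (assumption)
exact_Z = np.exp(s).sum()               # true partition function

B = 5000
idx = rng.choice(C, size=B, p=Q)
approx_Z = np.mean(np.exp(s[idx] - np.log(Q[idx])))   # (1/B) sum e^{s - ln Q}
print(exact_Z, approx_Z)
```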
Define the corrected score $s^c(x_i, y_j) = s(x_i, y_j) - \ln Q(y_j \mid x_i)$. Substituting both approximations yields the final formula:
$$
\begin{aligned}
\mathbb{E}_{P}\big[\nabla_\theta s(x_i, y_j)\big] &\approx \frac{1}{B} \sum_{j \in [B]} \frac{e^{s^c(x_i, y_j)}}{\frac{1}{B} \sum_{k \in [B]} e^{s^c(x_i, y_k)}}\, \nabla_\theta s(x_i, y_j) \\
&= \sum_{j \in [B]} \frac{e^{s^c(x_i, y_j)}}{\sum_{k \in [B]} e^{s^c(x_i, y_k)}}\, \nabla_\theta s(x_i, y_j)
\end{aligned}
$$
This completes the derivation. In practice, sampled softmax only needs a small number of negative samples: apply the $\ln Q$ correction to the scores and feed the corrected scores into an ordinary log-softmax, as sketched below, which drastically reduces the computation. The approximation does introduce bias, however, so much follow-up research focuses on improving the quality of the sampling distribution and correcting the bias.
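A minimal sketch of such a corrected loss, assuming the scores and sampling probabilities have already been computed; the function and argument names are illustrative, not from any particular library:

```python
import numpy as np

def sampled_softmax_loss(pos_scores, neg_scores, log_q_pos, log_q_neg):
    """Log-softmax over logQ-corrected scores (a sketch; names are illustrative).

    pos_scores: (B,)   s(x_i, y_i) for each query's positive item
    neg_scores: (B, N) s(x_i, y_j) for N sampled negatives per query
    log_q_*:    ln Q(y | x) for the corresponding items under the sampler
    """
    # Corrected scores s^c = s - ln Q; column 0 holds the positive
    logits = np.concatenate([(pos_scores - log_q_pos)[:, None],
                             neg_scores - log_q_neg], axis=1)
    # Numerically stable log-softmax over the (1 + N) corrected scores
    m = logits.max(axis=1, keepdims=True)
    log_softmax = logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))
    return -log_softmax[:, 0].mean()    # -(1/B) sum_i log p(positive_i)

# Toy usage: 4 queries, 64 negatives sampled uniformly from a 10k-item corpus
rng = np.random.default_rng(4)
B, N, C = 4, 64, 10_000
loss = sampled_softmax_loss(rng.normal(size=B), rng.normal(size=(B, N)),
                            np.full(B, np.log(1 / C)), np.full((B, N), np.log(1 / C)))
print(loss)
```

The correction is applied to the positive as well as the negatives here, which treats the positive as just one more term in the corrected softmax.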
References
[1] Yang J, Yi X, Zhiyuan Cheng D, et al. Mixed negative sampling for learning two-tower neural networks in recommendations[C]. Companion Proceedings of the Web Conference 2020, 2020: 441-447.
[2] Bengio Y, Senécal J S. Adaptive importance sampling to accelerate training of a neural probabilistic language model[J]. IEEE Transactions on Neural Networks, 2008, 19(4): 713-722.
[3] Jean S, Cho K, Memisevic R, et al. On using very large target vocabulary for neural machine translation[J]. arXiv preprint arXiv:1412.2007, 2014.