[半监督学习] Deep Co-Training for Semi-Supervised Image Recognition
在监督学习领域, 深度神经网络在许多任务中已经取得了先进水平, 因此将其引入半监督学习, 并结合 Co-Training 思想, 用于处理半监督图像分类问题.
论文地址: Deep Co-Training for Semi-SupervisedImage Recognition
代码地址: https://github.com/AlanChou/Deep-Co-Training-for-Semi-Supervised-Image-Recognition
会议: ECCV 2018
任务: 分类
Co-Training 假设
D
=
S
∪
U
mathcal{D}=mathcal{S} cup mathcal{U}
D=S∪U 中的每个数据
x
x
x 有两个视图, 即
x
=
(
v
1
,
v
2
)
x = (v_1, v_2)
x=(v1,v2), 每个视图
v
i
v_i
vi 都足以学习一个有效的模型. 其中
S
mathcal{S}
S,
U
mathcal{U}
U 分别表示标记数据集和未标记数据集. 给定
D
mathcal{D}
D 的分布
X
mathcal{X}
X, Co-Training 假设表示如下:
f
(
x
)
=
f
1
(
v
1
)
=
f
2
(
v
2
)
,
∀
x
=
(
v
1
,
v
2
)
∼
X
f(x)=f_1(v_1)=f_2(v_2),forall x=(v_1,v_2) simmathcal{X}
f(x)=f1(v1)=f2(v2),∀x=(v1,v2)∼X
即对于在每个视图
v
i
v_i
vi 上训练的模型
f
i
f_i
fi, 都有一致的输出, 每个模型都能做出正确的预测. 在给定类标签的情况下, 两个视图条件独立. 基于这个假设, Co-Training 训练简述如下: 首先为
S
mathcal{S}
S 上的每个视图学习一个单独的分类器, 然后将两个分类器对
U
mathcal{U}
U 的预测逐渐加到
S
mathcal{S}
S 上继续进行训练.
将 Co-Training 扩展到深度神经网络中, 一个简单的办法是在
D
mathcal{D}
D 上训练两个神经网络, 但是这种方法有两个严重的缺点:
- 不能保证两个网络的视图是不同和互补的.
- 协同训练会使得两个网络在训练过程中趋于一致, 即 collapsed neural networks 现象.
基于此, 提出 Deep Co-Training(DCT), 通过最小化两个网络在
U
mathcal{U}
U 上的预测之间的 JS 散度来模拟 Co-Training 假设. 为了避免 collapsed neural networks, 通过训练对抗样本来施加视图差异约束(View Difference Constraint).
Deep Co-Training 算法
Co-Training Assumption in DCT
在 DCT 中,
v
1
(
x
)
v_1(x)
v1(x) 和
v
2
(
x
)
v_2(x)
v2(x) 是
x
x
x 在最终全连接层
f
i
(
⋅
)
f_i(·)
fi(⋅) 之前的卷积表示. 在标记数据集
S
mathcal{S}
S 上的标准交叉熵损失函数定义为:
L
s
u
p
(
x
,
y
)
=
H
(
y
,
f
1
(
v
1
(
x
)
)
)
+
H
(
y
,
f
2
(
v
2
(
x
)
)
)
mathcal{L}_{mathrm{sup}}(x,y)=H(y,f_1(v_1(x)))+H(y,f_2(v_2(x)))
Lsup(x,y)=H(y,f1(v1(x)))+H(y,f2(v2(x)))
其中
H
(
p
,
q
)
H(p,q)
H(p,q) 表示交叉熵. 而对于未标记数据集
U
mathcal{U}
U, 基于 Co-Training 假设, 期望
f
1
(
v
1
(
x
)
)
f_1(v_1(x))
f1(v1(x)) 和
f
2
(
v
2
(
x
)
)
f_2(v_2(x))
f2(v2(x)) 有相似的预测, 使用 JS 散度来进行
f
1
(
v
1
(
x
)
)
f_1(v_1(x))
f1(v1(x)) 和
f
2
(
v
2
(
x
)
)
f_2(v_2(x))
f2(v2(x)) 之间的相似性度量, 损失函数定义如下:
L
c
o
t
(
x
)
=
H
(
1
2
(
f
1
(
v
1
(
x
)
)
+
f
2
(
v
2
(
x
)
)
)
)
−
1
2
(
H
(
f
1
(
v
1
(
x
)
)
)
+
H
(
f
2
(
v
2
(
x
)
)
)
)
mathcal{L}_{mathrm{cot}}(x)=H(frac{1}{2}(f_1(v_1(x))+f_2(v_2(x))))-frac{1}{2}(H(f_1(v_1(x)))+H(f_2(v_2(x))))
Lcot(x)=H(21(f1(v1(x))+f2(v2(x))))−21(H(f1(v1(x)))+H(f2(v2(x))))
其中
H
(
p
)
H(p)
H(p) 表示
p
p
p 的熵.
View Difference Constraint in DCT
利用
g
(
x
)
g(x)
g(x) 从
D
mathcal{D}
D 中生成对抗样本数据集
D
′
mathcal{D}'
D′, 在
D
′
mathcal{D}'
D′ 中
f
1
(
v
1
(
g
(
x
)
)
)
≠
f
2
(
v
2
(
g
(
x
)
)
)
f_1(v_1(g(x))) neq f_2(v_2(g(x)))
f1(v1(g(x)))=f2(v2(g(x))). 希望
g
(
x
)
g(x)
g(x) 与
x
x
x 之间足够小, 以便于对抗样本还能保持自然的图像特征. 不过当
g
(
x
)
−
x
g(x)-x
g(x)−x 很小时, 有很大概率会出现
f
1
(
v
1
(
g
(
x
)
)
=
f
1
(
v
1
(
x
)
)
f_1(v_1(g(x))=f_1(v_1(x))
f1(v1(g(x))=f1(v1(x)) 和
f
2
(
v
2
(
g
(
x
)
)
=
f
2
(
v
2
(
x
)
)
f_2(v_2(g(x))=f_2(v_2(x))
f2(v2(g(x))=f2(v2(x)), 这就与我们的想法违背. 即希望当
f
1
(
v
1
(
g
(
x
)
)
=
f
1
(
v
1
(
x
)
)
f_1(v_1(g(x))=f_1(v_1(x))
f1(v1(g(x))=f1(v1(x)) 出现时, 需满足
f
2
(
v
2
(
g
(
x
)
)
≠
f
2
(
v
2
(
x
)
)
f_2(v_2(g(x))neq f_2(v_2(x))
f2(v2(g(x))=f2(v2(x)).
通过交叉熵来训练网络
f
1
f_1
f1,
f
2
f_2
f2, 使得可以抵抗相互的对抗示例:
L
d
i
f
(
x
)
=
H
(
f
1
(
v
1
(
x
)
)
,
f
2
(
v
2
(
g
1
(
x
)
)
)
)
+
H
(
f
1
(
v
1
(
g
2
(
x
)
)
)
,
f
2
(
v
2
(
x
)
)
)
mathcal{L}_{mathrm{dif}}(x)=H(f_1(v_1(x)), f_2(v_2(g_1(x))))+H(f_1(v_1(g_2(x))), f_2(v_2(x)))
Ldif(x)=H(f1(v1(x)),f2(v2(g1(x))))+H(f1(v1(g2(x))),f2(v2(x)))
其他文献中, 使用对抗技术可以作为正则化技术来平滑输出, 如 VAT. 或者创建负示例来收紧决策边界.
最终的损失函数定义为:
L
=
E
(
x
,
y
)
∈
S
L
s
u
p
(
x
,
y
)
+
λ
c
o
t
E
x
∈
U
L
c
o
t
(
x
)
+
λ
d
i
f
E
x
∈
D
L
d
i
f
(
x
)
mathcal{L}=mathbb{E}_{(x,y)inmathcal{S}}mathcal{L}_{mathrm{sup}}(x,y)+lambda_{mathrm{cot}}mathbb{E}_{xinmathcal{U}}mathcal{L}_{mathrm{cot}}(x)+lambda_{mathrm{dif}}mathbb{E}_{xinmathcal{D}}mathcal{L}_{mathrm{dif}}(x)
L=E(x,y)∈SLsup(x,y)+λcotEx∈ULcot(x)+λdifEx∈DLdif(x)
DCT 训练迭代过程
在 DCT 训练循环的每次迭代中, 两个神经网络
p
1
p_1
p1,
p
2
p_2
p2 接收不同的标记数据
(
x
b
1
,
y
b
1
)
(x_{b_1},y_{b_1})
(xb1,yb1),
(
x
b
2
,
y
b
2
)
(x_{b_2},y_{b_2})
(xb2,yb2). 通过 FGSM 分别生成对抗样本
g
1
(
x
b
1
∪
x
u
)
g_1(x_{b_1} cup x_u)
g1(xb1∪xu),
g
2
(
x
b
2
∪
x
u
)
g_2(x_{b_2} cup x_u)
g2(xb2∪xu). 使用梯度下降计算
L
mathcal{L}
L, 并更新
p
1
p_1
p1,
p
2
p_2
p2 的参数.