对抗生成网络GAN系列——GANomaly原理及源码解析
🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题
🍊往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例 对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例 对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战 对抗生成网络GAN系列——EGBAD原理及缺陷检测实战
🍊近期目标:写好专栏的每一篇文章
🍊支持小苏:点赞👍🏼、收藏⭐、留言📩
文章目录
对抗生成网络GAN系列——GANomaly原理及源码解析
写在前面
在前面,我已经介绍过好几篇有关GAN的文章,链接如下:
- [1]对抗生成网络GAN系列——GAN原理及手写数字生成小案例 🍁🍁🍁
- [2]对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例🍁🍁🍁
- [3]对抗生成网络GAN系列——CycleGAN原理🍁🍁🍁
- [4] 对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战 🍁🍁🍁
- [5]对抗生成网络GAN系列——EGBAD原理及缺陷检测实战🍁🍁🍁
- [6]对抗生成网络GAN系列——WGAN原理及实战演练🍁🍁🍁
这篇文章我将来为大家介绍GANomaly,论文名为:Semi-Supervised Anomaly Detection via Adversarial Training。这篇文章同样是实现缺陷检测的,因此在阅读本文之前建议你对使用GAN网络实现缺陷检测有一定的了解,可以参考上文链接中的[4]和[5]。
准备好了吗,嘟嘟嘟,开始发车。🚖🚖🚖
GANomaly原理解析
【阅读此部分前建议对GAN的原理及GAN在缺陷检测上的应用有所了解,详情点击写在前面中的链接查看,本篇文章我不会再介绍GAN的一些先验知识。】
GANomaly结构
这部分为大家介绍GANomaly的原理,其实我们一起来看下图就足够了:
图1 GANomaly结构图
我们还是先来对上图中的结构做一些解释。从直观的颜色上来看,我们可以分成两类,一类是红色的Encoder结构,一类是蓝色的Decoder结构。Encoder主要就是降维的作用啦,如将一张张图片数据压缩成一个个潜在向量;相反,Decoder就是升维的作用,如将一个个潜在向量重建成一张张图片。按照论文描述的结构来分,可以分成三个子结构,分别为生成器网络G,编码器网络E和判别器网络D。下面分别来介绍介绍这三个子结构:
-
生成器网络G
生成器网络G由两个部分组成,分别为编码器G
E
(
x
)
)
G_E(x))
GE(x))和解码器
G
D
(
z
)
G_D(z)
GD(z),其实这就是一个自动编码器结构,主要用来学习输入x的数据分布并重建图像
x
^
{hat x}
x^。我们一个个来看,先看
G
E
(
x
)
G_E(x)
GE(x)结构,假设我们的输入x维度为
R
C
×
H
×
W
mathbb{R}^{C×H×W}
RC×H×W,经过
G
E
(
x
)
G_E(x)
GE(x)结构后,变成一个向量
z
z
z,其维度为
R
d
mathbb{R}^d
Rd。【
G
E
(
x
)
G_E(x)
GE(x)具体结构很简单啦,这里就不详细介绍了。我会在源码解析部分给出,大家肯定一看就会。】接着我们来看
G
D
(
z
)
G_D(z)
GD(z)结构,它会将刚刚得到的向量z上采样成
x
^
hat x
x^,
x
^
hat x
x^的维度和
x
x
x一致,都为
R
C
×
H
×
W
mathbb{R}^{C×H×W}
RC×H×W。关于
G
D
(
Z
)
G_D(Z)
GD(Z)结构也很简单,其主要用到了转置卷积,对于转置卷积不了解的可以看博客[2]了解详情。生成器网络G就为大家介绍完了,是不是发现很简单呢。总结下来就两步,第一步让输入x通过
G
E
(
x
)
G_E(x)
GE(x)得到z,第二步让z通过
G
D
(
Z
)
G_D(Z)
GD(Z)变成
x
^
hat x
x^。这两步也可以用一步表示,即
x
^
=
G
(
x
)
hat x=G(x)
x^=G(x)。
思来想去我还是想在这里给大家抛出一个问题,我们传统的GAN是怎么通过生成器来构建假图像的呢?和GANomaly有区别吗?其实这个问题的答案很简单,大家都稍稍思考一下,我就不给答案了,不明白的评论区见吧!!!🥂🥂🥂
-
编码器网络E
编码器网络E的作用是将生成器得到的
x
^
hat x
x^压缩成一个向量
z
^
hat z
z^,是不是发现和生成器网络中的
G
E
(
x
)
G_E(x)
GE(x)很像呢,其实呀,它俩的结构就是完全一样的,生成的
z
^
hat z
z^ 和
x
^
hat x
x^ 的维度一致,这是方便后面的损失比较。
-
判别器网络D
判别器网络D和我们之前介绍DCGAN时的结构是一样的,都是将真实数据
x
x
x和生成数据
x
^
hat x
x^输入网络,然后得出一个分数。
GANomaly损失函数
GANomaly的损失函数分为两部分,第一部分是生成器损失,第二部分为判别器损失,下面我们分别来进行介绍:
-
生成器损失函数
生成器损失函数又由三个部分组成,分别如下:
-
Adversari Loss
我还是直接上公式吧,如下:
L
a
d
v
=
E
x
∼
p
x
∣
∣
f
(
x
)
−
E
x
∼
p
x
f
(
G
(
x
)
)
∣
∣
2
L_{adv}=E_{x sim px}||f(x)-E_{x sim px}f(G(x))||_2
Ladv=Ex∼px∣∣f(x)−Ex∼pxf(G(x))∣∣2
这个公式对应图一中的
L
a
d
v
=
∣
∣
f
(
x
)
−
f
(
x
^
)
∣
∣
2
L_{adv}=||f(x)-f(hat x)||_2
Ladv=∣∣f(x)−f(x^)∣∣2🍵🍵🍵这个损失函数应该很好理解,在前面介绍的GAN网络都有提及,
f
(
∗
)
f(*)
f(∗)表示判别器网络某个中间层的输出。这个损失函数的作用就是让两张图像
x
和
x
^
x和hat x
x和x^尽可能接近,也就是让生成器生成的图片更加逼真。
-
Contextual Loss
同样的,直接来上公式,如下:
L
c
o
n
=
E
x
∼
p
x
∣
∣
x
−
G
(
x
)
∣
∣
1
L_{con}=E_{x sim px}||x-G(x)||_1
Lcon=Ex∼px∣∣x−G(x)∣∣1
这个公式对应图一中的
L
c
o
n
=
∣
∣
x
−
x
^
∣
∣
1
L_{con}=||x-hat x||_1
Lcon=∣∣x−x^∣∣1🍵🍵🍵这个函数其实也是要让两张图像
x
和
x
^
x和hat x
x和x^尽可能接近。至于这里为什么用的是L1范数而不是L2范数,作者在论文中说这里使用L1范数的效果要比使用L2范数的效果好,这属于实验得到的结论,大家也不用过于纠结。
-
Encoder Loss
话不多说,上公式,如下:
L
e
n
c
=
E
x
∼
p
x
∣
∣
G
E
(
x
)
−
E
(
G
(
x
)
)
∣
∣
2
L_{enc}=E_{x sim px}||G_E(x)-E(G(x))||_2
Lenc=Ex∼px∣∣GE(x)−E(G(x))∣∣2
这个公式对应图一中的
L
e
n
c
=
∣
∣
z
−
z
^
∣
∣
2
L_{enc}=||z-hat z||_2
Lenc=∣∣z−z^∣∣2🍵🍵🍵这里的损失函数在我看来主要作用就是让我们在推理过程中的效果更好,这里就像AnoGAN中不断搜索最优的那个z的作用。
如果大家这里读过cycleGAN的论文的话,可能会觉得这个损失函数有点类似cycleGAN中的循环一致性损失。我觉得这篇文章的思想可能借鉴了cycleGAN中的思想,感兴趣的可以去阅读一下,非常有意思的一篇文章!!!🥃🥃🥃
生成器总的损失是上述三种损失的加权和,如下:
L
=
w
a
d
v
L
a
d
v
+
w
c
o
n
L
c
o
n
+
w
e
n
c
L
e
n
c
L=w_{adv}L_{adv}+w_{con}L_{con}+w_{enc}L_{enc}
L=wadvLadv+wconLcon+wencLenc
在论文提供的源码中,默认
w
c
o
n
=
50
,
w
a
d
v
=
w
e
n
c
=
1
w_{con}=50,w_{adv}=w_{enc}=1
wcon=50,wadv=wenc=1。
-
-
判别器损失函数
判别器的损失函数就和原始GAN一样,如下:【不清楚的点击☞☞☞了解详情】
这部分我直接先放上代码吧,不多,也很容易理解,如下:
self.l_bce = nn.BCELoss() # Real - Fake Loss self.err_d_real = self.l_bce(self.pred_real, self.real_label) self.err_d_fake = self.l_bce(self.pred_fake, self.fake_label) # NetD Loss & Backward-Pass self.err_d = (self.err_d_real + self.err_d_fake) * 0.5
GANomaly测试阶段
在上一小节,为大家介绍了GANomaly的损失函数,这是在测试阶段使用的。GANomaly针对的是异常检测任务,在测试阶段我们会对输入的数据进行评分,根据评分的结果来判定输入是否异常。在GANomaly中使用的评分函数就是我们上一小节介绍的Encoder Loss,对于一个测试数据x,用
A
(
x
)
A(x)
A(x)表示其异常得分,则:
A
(
x
)
=
∣
∣
G
E
(
x
)
−
E
(
G
(
x
)
)
∣
∣
2
A(x)=||G_E(x)-E(G(x))||_2
A(x)=∣∣GE(x)−E(G(x))∣∣2
这里大家需要注意以下,论文中
A
(
x
)
A(x)
A(x)的表达式使用的是L1范数,但是从我阅读论文提供的源码来看,代码中使用的是L2范数。这里保持和源码一致,使用L2范数。代码中关于此部分的描述如下:
# latent_i表示G_E(x),latent_o表示E(G(x))。torch.pow(m,2)=m^2
error = torch.mean(torch.pow((latent_i-latent_o), 2), dim=1)
GANomaly源码解析
这里直接使用论文中提供的源码地址:GANomaly源码🌱🌱🌱
GANomaly模型搭建
其实通过我前文的讲解,不知道大家能否感受到GANomaly模型其实是不复杂的。需要注意的是在介绍GANomaly结构时我们将模型分为了三个子结构,分别为生成器网络G、编码器网络E、判别器网络D。但是在代码中我们将生成器网络G和编码器网络E合并在一块儿了,也称为生成器网络G。
下面我给出这部分的代码,大家注意一下这里面的超参数比较多,为了方便大家阅读,我把这里用到超参数的整理出来,如下图所示:
""" Network architectures.
"""
# pylint: disable=W0221,W0622,C0103,R0913
##
import torch
import torch.nn as nn
import torch.nn.parallel
from options import Options
##
def weights_init(mod):
"""
Custom weights initialization called on netG, netD and netE
:param m:
:return:
"""
classname = mod.__class__.__name__
if classname.find('Conv') != -1:
mod.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
mod.weight.data.normal_(1.0, 0.02)
mod.bias.data.fill_(0)
###
class Encoder(nn.Module):
"""
DCGAN ENCODER NETWORK
"""
def __init__(self, isize, nz, nc, ndf, ngpu, n_extra_layers=0, add_final_conv=True):
super(Encoder, self).__init__()
self.ngpu = ngpu
assert isize % 16 == 0, "isize has to be a multiple of 16"
main = nn.Sequential()
# input is nc x isize x isize
main.add_module('initial-conv-{0}-{1}'.format(nc, ndf),
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
main.add_module('initial-relu-{0}'.format(ndf),
nn.LeakyReLU(0.2, inplace=True))
csize, cndf = isize / 2, ndf # csize=16,cndf=64
# Extra layers
for t in range(n_extra_layers):
main.add_module('extra-layers-{0}-{1}-conv'.format(t, cndf),
nn.Conv2d(cndf, cndf, 3, 1, 1, bias=False))
main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cndf),
nn.BatchNorm2d(cndf))
main.add_module('extra-layers-{0}-{1}-relu'.format(t, cndf),
nn.LeakyReLU(0.2, inplace=True))
while csize > 4:
in_feat = cndf
out_feat = cndf * 2
main.add_module('pyramid-{0}-{1}-conv'.format(in_feat, out_feat),
nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
main.add_module('pyramid-{0}-batchnorm'.format(out_feat),
nn.BatchNorm2d(out_feat))
main.add_module('pyramid-{0}-relu'.format(out_feat),
nn.LeakyReLU(0.2, inplace=True))
cndf = cndf * 2
csize = csize / 2
# state size. K x 4 x 4
if add_final_conv:
main.add_module('final-{0}-{1}-conv'.format(cndf, 1),
nn.Conv2d(cndf, nz, 4, 1, 0, bias=False))
self.main = main
def forward(self, input):
if self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output
##
class Decoder(nn.Module):
"""
DCGAN DECODER NETWORK
"""
def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
super(Decoder, self).__init__()
self.ngpu = ngpu
assert isize % 16 == 0, "isize has to be a multiple of 16"
cngf, tisize = ngf // 2, 4 #cngf=32 ,tisize=4
while tisize != isize:
cngf = cngf * 2
tisize = tisize * 2
main = nn.Sequential()
# input is Z, going into a convolution
main.add_module('initial-{0}-{1}-convt'.format(nz, cngf),
nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
main.add_module('initial-{0}-batchnorm'.format(cngf),
nn.BatchNorm2d(cngf))
main.add_module('initial-{0}-relu'.format(cngf),
nn.ReLU(True))
csize, _ = 4, cngf
while csize < isize // 2:
main.add_module('pyramid-{0}-{1}-convt'.format(cngf, cngf // 2),
nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
main.add_module('pyramid-{0}-batchnorm'.format(cngf // 2),
nn.BatchNorm2d(cngf // 2))
main.add_module('pyramid-{0}-relu'.format(cngf // 2),
nn.ReLU(True))
cngf = cngf // 2
csize = csize * 2
# Extra layers
for t in range(n_extra_layers):
main.add_module('extra-layers-{0}-{1}-conv'.format(t, cngf),
nn.Conv2d(cngf, cngf, 3, 1, 1, bias=False))
main.add_module('extra-layers-{0}-{1}-batchnorm'.format(t, cngf),
nn.BatchNorm2d(cngf))
main.add_module('extra-layers-{0}-{1}-relu'.format(t, cngf),
nn.ReLU(True))
main.add_module('final-{0}-{1}-convt'.format(cngf, nc),
nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
main.add_module('final-{0}-tanh'.format(nc),
nn.Tanh())
self.main = main
def forward(self, input):
if self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output
## 判别器网络结构
class NetD(nn.Module):
"""
DISCRIMINATOR NETWORK
"""
def __init__(self, opt):
super(NetD, self).__init__()
model = Encoder(opt.isize, 1, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
layers = list(model.main.children())
self.features = nn.Sequential(*layers[:-1])
self.classifier = nn.Sequential(layers[-1])
self.classifier.add_module('Sigmoid', nn.Sigmoid())
def forward(self, x):
features = self.features(x)
features = features
classifier = self.classifier(features)
classifier = classifier.view(-1, 1).squeeze(1)
return classifier, features
## 生成器网络结构
class NetG(nn.Module):
"""
GENERATOR NETWORK
"""
def __init__(self, opt):
super(NetG, self).__init__()
self.encoder1 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
self.decoder = Decoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
self.encoder2 = Encoder(opt.isize, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.extralayers)
def forward(self, x):
latent_i = self.encoder1(x)
gen_imag = self.decoder(latent_i)
latent_o = self.encoder2(gen_imag)
return gen_imag, latent_i, latent_o
GANomaly损失函数
我们在理论部分已经介绍了GANomaly的损失函数,那么在代码上它们都是一一对应的,实现起来也很简单,如下:
## 定义L1 Loss
def l1_loss(input, target):
return torch.mean(torch.abs(input - target))
## 定义L2 Loss
def l2_loss(input, target, size_average=True):
if size_average:
return torch.mean(torch.pow((input-target), 2))
else:
return torch.pow((input-target), 2)
self.l_adv = l2_loss
self.l_con = nn.L1Loss()
self.l_enc = l2_loss
self.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1])
self.err_g_con = self.l_con(self.fake, self.input)
self.err_g_enc = self.l_enc(self.latent_o, self.latent_i)
self.err_g = self.err_g_adv * self.opt.w_adv +
self.err_g_con * self.opt.w_con +
self.err_g_enc * self.opt.w_enc
上述代码为GANomaly生成器损失函数代码,判别器的损失函数代码已经在理论部分为大家介绍了,这里就不在赘述了。🍄🍄🍄
小结
这里我并没有很详细的为大家解读代码,但是把一些关键的部分都给大家介绍了。会了这些其实你完全可以自己实现一个GANomaly网络,或者对我之前在Anogan中的代码稍加改造也可以达到一样的效果。论文中提供的源码感兴趣的大家可以自己去调试一下,代码量也不算多,但有的地方理解起来也有一定的困难,总之大家加油吧!!!🌼🌼🌼
参考链接
GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training 🍁🍁🍁
如若文章对你有所帮助,那就🛴🛴🛴