活体检测 SSDG 代码学习记录
SSDG 模型整体结构图如下:
1)数据读取
读取三个不同源域的数据
import os
import random
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils.dataset import YunpeiDataset
from utils.utils import sample_frames
def get_dataset(src1_data, src1_train_num_frames, src2_data, src2_train_num_frames, src3_data, src3_train_num_frames,
                tgt_data, tgt_test_num_frames, batch_size):
    """Build the data loaders for SSDG training.

    Returns, in order: (fake, real) shuffled train loaders for each of the
    three source domains, then one non-shuffled test loader for the target
    domain — seven loaders in total.
    """
    def _train_loader(frames):
        # All training splits share the same loader settings.
        return DataLoader(YunpeiDataset(frames, train=True),
                          batch_size=batch_size, shuffle=True)

    print('Load Source Data')
    train_loaders = []
    for name, num_frames in ((src1_data, src1_train_num_frames),
                             (src2_data, src2_train_num_frames),
                             (src3_data, src3_train_num_frames)):
        print('Source Data: ', name)
        # flag=0 selects attack (fake) frames, flag=1 selects live (real) frames.
        fake_frames = sample_frames(flag=0, num_frames=num_frames, dataset_name=name)
        real_frames = sample_frames(flag=1, num_frames=num_frames, dataset_name=name)
        train_loaders.append(_train_loader(fake_frames))
        train_loaders.append(_train_loader(real_frames))

    print('Load Target Data')
    print('Target Data: ', tgt_data)
    # flag=2 selects the full test split of the target domain.
    tgt_test_data = sample_frames(flag=2, num_frames=tgt_test_num_frames, dataset_name=tgt_data)
    tgt_dataloader = DataLoader(YunpeiDataset(tgt_test_data, train=False),
                                batch_size=batch_size, shuffle=False)

    return (train_loaders[0], train_loaders[1],
            train_loaders[2], train_loaders[3],
            train_loaders[4], train_loaders[5],
            tgt_dataloader)
2)数据处理
######### data prepare #########
# Draw one mini-batch of real and one of fake faces from each source domain.
# FIX: iterator.next() is Python-2-only (and modern PyTorch dataloader
# iterators have no .next() method); the builtin next() works everywhere.
src1_img_real, src1_label_real = next(src1_train_iter_real)
src1_img_real = src1_img_real.cuda()
src1_label_real = src1_label_real.cuda()
input1_real_shape = src1_img_real.shape[0]  # real batch size, domain 1

src2_img_real, src2_label_real = next(src2_train_iter_real)
src2_img_real = src2_img_real.cuda()
src2_label_real = src2_label_real.cuda()
input2_real_shape = src2_img_real.shape[0]  # real batch size, domain 2

src3_img_real, src3_label_real = next(src3_train_iter_real)
src3_img_real = src3_img_real.cuda()
src3_label_real = src3_label_real.cuda()
input3_real_shape = src3_img_real.shape[0]  # real batch size, domain 3

src1_img_fake, src1_label_fake = next(src1_train_iter_fake)
src1_img_fake = src1_img_fake.cuda()
src1_label_fake = src1_label_fake.cuda()
input1_fake_shape = src1_img_fake.shape[0]  # fake batch size, domain 1

src2_img_fake, src2_label_fake = next(src2_train_iter_fake)
src2_img_fake = src2_img_fake.cuda()
src2_label_fake = src2_label_fake.cuda()
input2_fake_shape = src2_img_fake.shape[0]  # fake batch size, domain 2

src3_img_fake, src3_label_fake = next(src3_train_iter_fake)
src3_img_fake = src3_img_fake.cuda()
src3_label_fake = src3_label_fake.cuda()
input3_fake_shape = src3_img_fake.shape[0]  # fake batch size, domain 3

# Concatenate as (real, fake) per domain; every later slicing/labeling step
# relies on exactly this ordering.
input_data = torch.cat([src1_img_real, src1_img_fake, src2_img_real, src2_img_fake, src3_img_real, src3_img_fake], dim=0)
source_label = torch.cat([src1_label_real, src1_label_fake,
                          src2_label_real, src2_label_fake,
                          src3_label_real, src3_label_fake], dim=0)
3)DG 网络搭建
调用 DG_model
import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, BasicBlock
import sys
import numpy as np
from torch.autograd import Variable
import random
import os
def l2_norm(input, axis=1):
    """Scale `input` so each slice along `axis` has unit L2 norm."""
    return input / torch.norm(input, 2, axis, True)
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained weights from the
            local checkpoint file below (adjust the path for your machine).
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    # change your path
    # FIX: the path separators were lost in the original text; a Windows path
    # needs backslashes (kept literal here via the raw string) or forward slashes.
    model_path = r'D:\Projects\face_anti_spoofing\SSDG-CVPR2020-master\pretrained_model\resnet18-5c106cde.pth'
    if pretrained:
        # Print before the (possibly slow or failing) load so the log shows
        # which checkpoint was being read if load_state_dict raises.
        print("loading model: ", model_path)
        model.load_state_dict(torch.load(model_path))
    return model
class Feature_Generator_ResNet18(nn.Module):
    """Backbone front-end: pretrained ResNet-18 stem plus residual stages 1-3."""

    def __init__(self):
        super(Feature_Generator_ResNet18, self).__init__()
        backbone = resnet18(pretrained=True)
        # Reuse the pretrained stem and the first three residual stages;
        # layer4 lives in the embedder, not here.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3

    def forward(self, input):
        """Run stem and stages 1-3 in order; returns the layer3 feature map."""
        out = input
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.layer3):
            out = stage(out)
        return out
class Feature_Embedder_ResNet18(nn.Module):
    """Backbone tail: ResNet-18 layer4, global pooling, and a 512-d bottleneck.

    When `norm_flag` is set, the output feature is rescaled by an SSDG-specific
    normalization factor (see note in forward()).
    """

    def __init__(self):
        super(Feature_Embedder_ResNet18, self).__init__()
        base = resnet18(pretrained=False)  # weights come from the generator's training
        self.layer4 = base.layer4
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        fc = nn.Linear(512, 512)
        fc.weight.data.normal_(0, 0.005)
        fc.bias.data.fill_(0.1)
        self.bottleneck_layer_fc = fc
        self.bottleneck_layer = nn.Sequential(fc, nn.ReLU(), nn.Dropout(0.5))

    def forward(self, input, norm_flag):
        out = self.avgpool(self.layer4(input))
        out = self.bottleneck_layer(out.view(out.size(0), -1))
        if norm_flag:
            # NOTE(review): this divides by sqrt(||f||_2) * sqrt(2), not by the
            # plain L2 norm — unusual, but kept byte-for-byte equivalent to the
            # original SSDG normalization.
            denom = torch.norm(out, p=2, dim=1, keepdim=True).clamp(min=1e-12) ** 0.5 * (2) ** 0.5
            out = torch.div(out, denom)
        return out
class Classifier(nn.Module):
    """Binary real/fake linear head over 512-d features."""

    def __init__(self):
        super(Classifier, self).__init__()
        self.classifier_layer = nn.Linear(512, 2)
        self.classifier_layer.weight.data.normal_(0, 0.01)
        self.classifier_layer.bias.data.fill_(0.0)

    def forward(self, input, norm_flag):
        # When requested, L2-normalize the weight columns in place before
        # classifying; the linear mapping itself is the same either way.
        if norm_flag:
            self.classifier_layer.weight.data = l2_norm(self.classifier_layer.weight, axis=0)
        return self.classifier_layer(input)
class DG_model(nn.Module):
    """Full SSDG network: backbone + embedder producing a feature vector,
    plus a real/fake classifier head.

    forward() returns (classifier_logits, feature).
    """

    def __init__(self, model):
        super(DG_model, self).__init__()
        if model == 'resnet18':
            self.backbone = Feature_Generator_ResNet18()
            self.embedder = Feature_Embedder_ResNet18()
        elif model == 'maddg':
            self.backbone = Feature_Generator_MADDG()
            self.embedder = Feature_Embedder_MADDG()
        else:
            # FIX: the original only printed 'Wrong Name!' and then crashed
            # later with an AttributeError on self.backbone; fail fast instead.
            raise ValueError("unknown model name: %r (expected 'resnet18' or 'maddg')" % (model,))
        self.classifier = Classifier()

    def forward(self, input, norm_flag):
        feature = self.backbone(input)
        feature = self.embedder(feature, norm_flag)
        # FIX: removed the leftover debug print(feature.shape) from the hot path.
        classifier_out = self.classifier(feature, norm_flag)
        return classifier_out, feature
实例化网络:
# Instantiate the full SSDG network with the ResNet-18 backbone.
model = DG_model('resnet18')
4)DG_model 前向传播
# Forward pass over the concatenated batch: returns (classification logits, embedded features).
classifier_label_out, feature = net(input_data, config.norm_flag)
5)判别器网络
在判别器的反向传播中引入 GRL,作用是在训练早期阶段抑制噪声信号的影响。在训练初期,GRL 的系数很小,随着迭代次数的增加,系数逐渐增大
import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, BasicBlock
import sys
import numpy as np
from torch.autograd import Variable
import random
import os
class GRL(torch.autograd.Function):
    """Gradient Reversal Layer: identity in forward, gradient scaled by -coeff
    in backward.

    The coefficient ramps smoothly from `low` to `high` over `max_iter`
    iterations, so the adversarial signal is suppressed early in training and
    grows as training progresses.

    NOTE(review): this uses the legacy stateful autograd Function API (instance
    forward/backward); modern PyTorch expects @staticmethod forward/backward
    with a ctx argument — confirm the installed torch version accepts calling
    an instance of this class.
    """

    def __init__(self):
        self.iter_num = 0   # incremented once per forward call
        self.alpha = 10     # steepness of the ramp
        self.low = 0.0
        self.high = 1.0
        self.max_iter = 4000  # be same to the max_iter of config.py

    def forward(self, input):
        self.iter_num += 1
        return input * 1.0

    def backward(self, gradOutput):
        # FIX: np.float was removed in NumPy 1.20+; the builtin float is
        # equivalent here (the expression is already a Python scalar).
        coeff = float(2.0 * (self.high - self.low) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iter))
                      - (self.high - self.low) + self.low)
        return -coeff * gradOutput
class Discriminator(nn.Module):
    """3-way source-domain discriminator over 512-d features, fed through a
    gradient-reversal layer for adversarial training."""

    def __init__(self):
        super(Discriminator, self).__init__()
        fc1 = nn.Linear(512, 512)
        fc1.weight.data.normal_(0, 0.01)
        fc1.bias.data.fill_(0.0)
        fc2 = nn.Linear(512, 3)
        fc2.weight.data.normal_(0, 0.3)
        fc2.bias.data.fill_(0.0)
        self.fc1 = fc1
        self.fc2 = fc2
        self.ad_net = nn.Sequential(fc1, nn.ReLU(), nn.Dropout(0.5), fc2)
        self.grl_layer = GRL()

    def forward(self, feature):
        # Reverse gradients first, then classify which source domain the
        # feature came from.
        reversed_feature = self.grl_layer(feature)
        return self.ad_net(reversed_feature)
实例化判别器网络
# Instantiate the domain discriminator and move it to the training device.
ad_net_real = Discriminator().to(device)
6)单边对抗学习
feature 为 DG_model 的第二个输出 (DG_model 有两个输出,一个是分类结果,一个是经 backbone 和 embedder 网络提取的特征)
######### single side adversarial learning #########
# Total rows contributed by each of the first two domains (real + fake);
# used as row offsets into the concatenated feature tensor.
input1_shape = input1_real_shape + input1_fake_shape
input2_shape = input2_real_shape + input2_fake_shape
# torch.narrow(input, dim, start, length) returns a view of `input` restricted
# to [start, start+length) along dimension `dim` — like an array slice.
# Select only the REAL-face features of each domain and concatenate them; only
# these are fed to the discriminator (single-side adversarial learning).
feature_real_1 = feature.narrow(0, 0, input1_real_shape)
feature_real_2 = feature.narrow(0, input1_shape, input2_real_shape)
feature_real_3 = feature.narrow(0, input1_shape+input2_shape, input3_real_shape)
feature_real = torch.cat([feature_real_1, feature_real_2, feature_real_3], dim=0)
discriminator_out_real = ad_net_real(feature_real)
7)非对称三元组损失
######### unbalanced triplet loss #########
# Asymmetric label scheme: all real faces share label 0 regardless of domain,
# while fake faces get a distinct label per domain (1, 2, 3) — reals are
# pulled together across domains, fakes stay separated by source.
real_domain_label_1 = torch.LongTensor(input1_real_shape, 1).fill_(0).cuda()
real_domain_label_2 = torch.LongTensor(input2_real_shape, 1).fill_(0).cuda()
real_domain_label_3 = torch.LongTensor(input3_real_shape, 1).fill_(0).cuda()
fake_domain_label_1 = torch.LongTensor(input1_fake_shape, 1).fill_(1).cuda()
fake_domain_label_2 = torch.LongTensor(input2_fake_shape, 1).fill_(2).cuda()
fake_domain_label_3 = torch.LongTensor(input3_fake_shape, 1).fill_(3).cuda()
# Concatenation order mirrors the (real, fake) per-domain order used to build
# `input_data` / `feature`.
source_domain_label = torch.cat([real_domain_label_1, fake_domain_label_1,
real_domain_label_2, fake_domain_label_2,
real_domain_label_3, fake_domain_label_3], dim=0).view(-1)
triplet = criterion["triplet"](feature, source_domain_label)
8)分类损失及单边对抗学习损失 (均为交叉熵损失)
######### cross-entropy loss #########
# Per-domain real-face batch sizes, in the order the real features were
# concatenated; Real_AdLoss uses them to assign domain targets.
real_shape_list = [input1_real_shape, input2_real_shape, input3_real_shape]
real_adloss = Real_AdLoss(discriminator_out_real, criterion["softmax"], real_shape_list)
cls_loss = criterion["softmax"](classifier_label_out.narrow(0, 0, input_data.size(0)), source_label)
9)总体损失
# Final objective: classification loss + weighted asymmetric triplet loss
# + weighted single-side adversarial loss (weights from the config).
total_loss = cls_loss + config.lambda_triplet * triplet + config.lambda_adreal * real_adloss
10)配置文件
I_C_M_to_O 协议下的配置文件如下:
class DefaultConfigs(object):
    """Hyper-parameters and paths for the I_C_M_to_O protocol
    (train on CASIA / Replay / MSU, test on OULU)."""
    seed = 666
    # SGD optimizer
    weight_decay = 5e-4
    momentum = 0.9
    # learning rate
    init_lr = 0.01
    lr_epoch_1 = 0
    lr_epoch_2 = 150
    # model
    pretrained = True
    model = 'resnet18'  # resnet18 or maddg
    # training parameters
    gpus = "3"
    batch_size = 10
    norm_flag = True
    max_iter = 4000
    lambda_triplet = 2
    lambda_adreal = 0.1
    # test model name
    tgt_best_model_name = 'model_best_0.08_29.pth.tar'
    # source data information
    src1_data = 'casia'
    src1_train_num_frames = 1
    src2_data = 'replay'
    src2_train_num_frames = 1
    src3_data = 'msu'
    src3_train_num_frames = 1
    # target data information
    tgt_data = 'oulu'
    tgt_test_num_frames = 2
    # paths information
    checkpoint_path = './' + tgt_data + '_checkpoint/' + model + '/DGFANet/'
    best_model_path = './' + tgt_data + '_checkpoint/' + model + '/best_model/'
    logs = './logs/'


config = DefaultConfigs()