活体检测 SSDG 代码学习记录

代码链接:GitHub - taylover-pei/SSDG-CVPR2020(https://github.com/taylover-pei/SSDG-CVPR2020),对应论文 Single-Side Domain Generalization for Face Anti-Spoofing(CVPR 2020)。

SSDG 模型整体结构图如下(原结构图未随本文保存):

1)数据读取

读取三个不同源域的训练数据(每个源域分别读取攻击人脸 flag=0 和真人脸 flag=1 两部分),以及目标域的测试数据(flag=2):

import os
import random
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from utils.dataset import YunpeiDataset
from utils.utils import sample_frames

def get_dataset(src1_data, src1_train_num_frames, src2_data, src2_train_num_frames, src3_data, src3_train_num_frames,
                tgt_data, tgt_test_num_frames, batch_size):
    """Build the DataLoaders for three source domains and one target domain.

    Every source domain is sampled twice with ``sample_frames``: once with
    ``flag=0`` (fake / spoof frames) and once with ``flag=1`` (real frames).
    The target domain is sampled once with ``flag=2`` for testing.

    Args:
        src1_data, src2_data, src3_data: dataset names of the three source domains.
        src1_train_num_frames, src2_train_num_frames, src3_train_num_frames:
            ``num_frames`` passed to ``sample_frames`` for each source domain.
        tgt_data: dataset name of the target (test) domain.
        tgt_test_num_frames: ``num_frames`` passed to ``sample_frames`` for the target.
        batch_size: batch size shared by all seven loaders.

    Returns:
        tuple: ``(src1_fake, src1_real, src2_fake, src2_real, src3_fake, src3_real, tgt)``
        DataLoaders. The six source loaders use ``train=True`` datasets and shuffle.
        The target loader uses a ``train=False`` dataset and does not shuffle.
    """
    sources = ((src1_data, src1_train_num_frames),
               (src2_data, src2_train_num_frames),
               (src3_data, src3_train_num_frames))

    print('Load Source Data')
    # Order: [src1 fake, src1 real, src2 fake, src2 real, src3 fake, src3 real].
    source_samples = []
    for dataset_name, num_frames in sources:
        print('Source Data: ', dataset_name)
        source_samples.append(sample_frames(flag=0, num_frames=num_frames, dataset_name=dataset_name))
        source_samples.append(sample_frames(flag=1, num_frames=num_frames, dataset_name=dataset_name))

    print('Load Target Data')
    print('Target Data: ', tgt_data)
    tgt_test_data = sample_frames(flag=2, num_frames=tgt_test_num_frames, dataset_name=tgt_data)

    source_loaders = [DataLoader(YunpeiDataset(samples, train=True), batch_size=batch_size, shuffle=True)
                      for samples in source_samples]
    tgt_dataloader = DataLoader(YunpeiDataset(tgt_test_data, train=False), batch_size=batch_size, shuffle=False)
    # A multi-line bare `return a, b,` is a syntax error; build the tuple explicitly.
    return tuple(source_loaders) + (tgt_dataloader,)

2)数据处理:每次迭代从 6 个训练迭代器(3 个源域 × 真/假)中各取一个 batch 并送入 GPU,然后按 [域1真, 域1假, 域2真, 域2假, 域3真, 域3假] 的顺序拼接成一个输入 batch

######### data prepare #########
# Draw one batch from each of the six training iterators (real / fake x three source
# domains), move it to the GPU and record its batch size. The recorded sizes are used
# later to slice `feature` per domain, and they can differ from batch_size because the
# loaders do not set drop_last.
# The built-in next() is used because recent PyTorch DataLoader iterators no longer
# provide a `.next()` method.
# The call order (real 1-3, then fake 1-3) is kept as in the original.
src1_img_real, src1_label_real = next(src1_train_iter_real)
src1_img_real = src1_img_real.cuda()
src1_label_real = src1_label_real.cuda()
input1_real_shape = src1_img_real.shape[0]

src2_img_real, src2_label_real = next(src2_train_iter_real)
src2_img_real = src2_img_real.cuda()
src2_label_real = src2_label_real.cuda()
input2_real_shape = src2_img_real.shape[0]

src3_img_real, src3_label_real = next(src3_train_iter_real)
src3_img_real = src3_img_real.cuda()
src3_label_real = src3_label_real.cuda()
input3_real_shape = src3_img_real.shape[0]

src1_img_fake, src1_label_fake = next(src1_train_iter_fake)
src1_img_fake = src1_img_fake.cuda()
src1_label_fake = src1_label_fake.cuda()
input1_fake_shape = src1_img_fake.shape[0]

src2_img_fake, src2_label_fake = next(src2_train_iter_fake)
src2_img_fake = src2_img_fake.cuda()
src2_label_fake = src2_label_fake.cuda()
input2_fake_shape = src2_img_fake.shape[0]

src3_img_fake, src3_label_fake = next(src3_train_iter_fake)
src3_img_fake = src3_img_fake.cuda()
src3_label_fake = src3_label_fake.cuda()
input3_fake_shape = src3_img_fake.shape[0]

# Concatenation order matters: [src1 real, src1 fake, src2 real, src2 fake, src3 real, src3 fake].
# The single-side adversarial step slices `feature` by offsets that assume this order.
input_data = torch.cat([src1_img_real, src1_img_fake,
                        src2_img_real, src2_img_fake,
                        src3_img_real, src3_img_fake], dim=0)

source_label = torch.cat([src1_label_real, src1_label_fake,
                          src2_label_real, src2_label_fake,
                          src3_label_real, src3_label_fake], dim=0)

3)DG 网络搭建

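DG_model 由三部分组成:特征生成器(ResNet-18 的 conv1 至 layer3,加载 ImageNet 预训练权重)、特征嵌入器(ResNet-18 的 layer4 + 全局平均池化 + 512 维瓶颈层)以及真/假二分类器。代码如下: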

import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, BasicBlock
import sys
import numpy as np
from torch.autograd import Variable
import random
import os

def l2_norm(input, axis=1):
    """Scale ``input`` by the reciprocal of its L2 norm along ``axis``.

    The reduced dimension is kept so the norm broadcasts back over ``input``.
    No epsilon is added, so a slice that is entirely zero produces NaN.
    """
    return input / input.norm(p=2, dim=axis, keepdim=True)

def resnet18(pretrained=False, model_path=None, **kwargs):
    """Constructs a ResNet-18 model (torchvision ``ResNet`` with ``BasicBlock`` x [2, 2, 2, 2]).

    Args:
        pretrained (bool): If True, load ImageNet weights from ``model_path``.
        model_path (str, optional): Location of the ``resnet18-5c106cde.pth`` checkpoint.
            Defaults to the author's local Windows path below; change it for your machine.
        **kwargs: Forwarded unchanged to ``ResNet``.

    Returns:
        The constructed ``ResNet`` module.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if model_path is None:
        # change your path. Keep the separators: without them the path collapses into a
        # single nonexistent file name.
        model_path = r'D:\Projects\face_anti_spoofing\SSDG-CVPR2020-master\pretrained_model\resnet18-5c106cde.pth'
    if pretrained:
        # Load onto the CPU first; load_state_dict copies the values into the model's own
        # parameters, so this also works on machines without the device the file was saved from.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        print("loading model: ", model_path)
    return model

class Feature_Generator_ResNet18(nn.Module):
    """Feature generator: the ImageNet-pretrained ResNet-18 stem followed by layer1-layer3."""

    def __init__(self):
        super(Feature_Generator_ResNet18, self).__init__()
        resnet = resnet18(pretrained=True)
        # Borrow the pretrained sub-modules under their original attribute names, in the
        # original order, so the state_dict keys stay the same.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3

    def forward(self, input):
        """Run ``input`` through conv1 -> bn1 -> relu -> maxpool -> layer1 -> layer2 -> layer3."""
        stages = (self.conv1, self.bn1, self.relu, self.maxpool,
                  self.layer1, self.layer2, self.layer3)
        out = input
        for stage in stages:
            out = stage(out)
        return out

class Feature_Embedder_ResNet18(nn.Module):
    """Embedding head: ResNet-18 ``layer4`` + global average pooling + a 512 -> 512 bottleneck.

    ``layer4`` is taken from a ResNet-18 built with ``pretrained=False``, i.e. it is not
    loaded from the ImageNet checkpoint.
    """
    def __init__(self):
        super(Feature_Embedder_ResNet18, self).__init__()
        model_resnet = resnet18(pretrained=False)
        self.layer4 = model_resnet.layer4
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        # Bottleneck Linear: weights ~ N(0, 0.005), bias filled with 0.1.
        self.bottleneck_layer_fc = nn.Linear(512, 512)
        self.bottleneck_layer_fc.weight.data.normal_(0, 0.005)
        self.bottleneck_layer_fc.bias.data.fill_(0.1)
        # Linear -> ReLU -> Dropout(0.5); the Linear is the same module object as
        # self.bottleneck_layer_fc.
        self.bottleneck_layer = nn.Sequential(
            self.bottleneck_layer_fc,
            nn.ReLU(),
            nn.Dropout(0.5)
        )

    def forward(self, input, norm_flag):
        """Map generator features to 512-d embeddings.

        Args:
            input: output of the feature generator (assumed (N, C, H, W) with C matching
                ``layer4``'s expected input channels -- TODO confirm).
            norm_flag: if truthy, rescale each embedding as described below.

        Returns:
            Tensor of shape (N, 512).
        """
        feature = self.layer4(input)
        feature = self.avgpool(feature)
        # Flatten the pooled (N, 512, 1, 1) map to (N, 512).
        feature = feature.view(feature.size(0), -1)
        feature = self.bottleneck_layer(feature)
        if (norm_flag):
            # NOTE: divides by sqrt(||f||_2) * sqrt(2), not by ||f||_2, so the result is not
            # unit-norm: its L2 norm becomes sqrt(||f||_2 / 2). clamp guards against a zero norm.
            feature_norm = torch.norm(feature, p=2, dim=1, keepdim=True).clamp(min=1e-12) ** 0.5 * (2) ** 0.5
            feature = torch.div(feature, feature_norm)
        return feature

class Classifier(nn.Module):
    """Linear 512 -> 2 real/fake head.

    When ``norm_flag`` is truthy, the weight matrix is overwritten on every forward pass
    with ``l2_norm(weight, axis=0)`` before the logits are computed.
    """

    def __init__(self):
        super(Classifier, self).__init__()
        self.classifier_layer = nn.Linear(512, 2)
        # Weights ~ N(0, 0.01), bias zero.
        nn.init.normal_(self.classifier_layer.weight, 0, 0.01)
        nn.init.zeros_(self.classifier_layer.bias)

    def forward(self, input, norm_flag):
        """Return the (N, 2) logits for ``input``, optionally after re-normalising the weights."""
        head = self.classifier_layer
        if norm_flag:
            head.weight.data = l2_norm(head.weight, axis=0)
        return head(input)


class DG_model(nn.Module):
    """SSDG network: feature generator -> feature embedder -> real/fake classifier.

    Args:
        model: backbone name, ``'resnet18'`` or ``'maddg'``.

    Raises:
        ValueError: if ``model`` is not one of the supported names.
    """

    def __init__(self, model):
        super(DG_model, self).__init__()
        if model == 'resnet18':
            self.backbone = Feature_Generator_ResNet18()
            self.embedder = Feature_Embedder_ResNet18()
        elif model == 'maddg':
            # NOTE(review): the MADDG modules are not defined in this excerpt; they must be
            # provided elsewhere for this branch to work.
            self.backbone = Feature_Generator_MADDG()
            self.embedder = Feature_Embedder_MADDG()
        else:
            # Fail fast. Printing and continuing would leave the module without a
            # backbone/embedder, and forward() would then crash with an AttributeError.
            raise ValueError("Wrong Name! Unsupported model {!r}; expected 'resnet18' or 'maddg'.".format(model))
        self.classifier = Classifier()

    def forward(self, input, norm_flag):
        """Return ``(classifier_out, feature)``: the real/fake logits and the embedding they came from."""
        feature = self.backbone(input)
        feature = self.embedder(feature, norm_flag)
        classifier_out = self.classifier(feature, norm_flag)
        # (The former per-call debug print of feature.shape was removed.)
        return classifier_out, feature

实例化网络:

# The forward-pass snippet below calls the network `net`, and its inputs are already on the
# GPU (`.cuda()`), so move the model there too. `model` is kept as an alias of the same object.
net = model = DG_model('resnet18').cuda()

4)DG_model 前向传播

# Forward pass over the concatenated six-part batch: returns the real/fake logits and the
# embedding (rescaled inside the embedder when config.norm_flag is truthy).
classifier_label_out, feature = net(input_data, config.norm_flag)

5)判别器网络

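在判别器之前引入梯度反转层(Gradient Reversal Layer, GRL):前向传播时 GRL 是恒等映射,反向传播时把梯度乘以 $-\lambda$ 后传给特征提取器。系数 $\lambda$ 随迭代次数 $i$ 变化:$$\lambda = \frac{2(h-l)}{1+\exp(-\alpha\, i / M)} - (h-l) + l$$ 其中 $\alpha=10$、$l=0$、$h=1$,$M$ 为最大迭代次数(4000)。训练初期 $\lambda$ 接近 0,可以抑制早期噪声信号的影响;随着迭代次数增加,$\lambda$ 逐渐增大并趋近于 1。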

import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, BasicBlock
import sys
import numpy as np
from torch.autograd import Variable
import random
import os


class _GradientReverseFunction(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``-coeff`` in the backward pass."""

    @staticmethod
    def forward(ctx, input, coeff):
        ctx.coeff = coeff
        return input * 1.0

    @staticmethod
    def backward(ctx, grad_output):
        # The python-float coefficient receives no gradient.
        return -ctx.coeff * grad_output, None


class GRL(nn.Module):
    """Gradient reversal layer with a warm-up schedule.

    Each forward call increments ``iter_num`` and reverses gradients with weight::

        coeff = 2 (high - low) / (1 + exp(-alpha * iter_num / max_iter)) - (high - low) + low

    With the defaults, coeff is 0 at ``iter_num == 0`` and rises towards 1.

    This replaces the legacy instance-style ``autograd.Function``, which modern PyTorch
    rejects, and the removed ``np.float`` alias. The coefficient is fixed at forward time
    for each call.

    Args:
        max_iter: schedule length; keep it equal to ``max_iter`` in config.py.
    """

    def __init__(self, max_iter=4000):
        super(GRL, self).__init__()
        self.iter_num = 0
        self.alpha = 10
        self.low = 0.0
        self.high = 1.0
        self.max_iter = max_iter  # be same to the max_iter of config.py

    def get_coeff(self):
        """Return the current reversal coefficient as a python float."""
        span = self.high - self.low
        return float(2.0 * span / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iter))
                     - span + self.low)

    def forward(self, input):
        self.iter_num += 1
        return _GradientReverseFunction.apply(input, self.get_coeff())

class Discriminator(nn.Module):
    """Adversarial domain discriminator placed behind a gradient reversal layer.

    Architecture: GRL -> Linear(512, 512) -> ReLU -> Dropout(0.5) -> Linear(512, 3).
    The 3 outputs presumably correspond to the three source domains (confirm against
    Real_AdLoss).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Same creation/initialisation order as before, so parameter init under a fixed
        # seed and the state_dict keys are unchanged.
        self.fc1 = nn.Linear(512, 512)
        nn.init.normal_(self.fc1.weight, 0, 0.01)
        nn.init.zeros_(self.fc1.bias)
        self.fc2 = nn.Linear(512, 3)
        nn.init.normal_(self.fc2.weight, 0, 0.3)
        nn.init.zeros_(self.fc2.bias)
        self.ad_net = nn.Sequential(self.fc1, nn.ReLU(), nn.Dropout(0.5), self.fc2)
        self.grl_layer = GRL()

    def forward(self, feature):
        """Return the discriminator logits for ``feature`` after gradient reversal."""
        reversed_feature = self.grl_layer(feature)
        return self.ad_net(reversed_feature)

实例化判别器网络

# `device` is never defined in these notes; the other snippets place tensors with `.cuda()`,
# so do the same here to keep the discriminator on the same GPU as `feature`.
ad_net_real = Discriminator().cuda()

6)单边对抗学习

feature 为 DG_model 的第二个输出(DG_model 有两个输出:一个是分类结果,另一个是经 backbone 和 embedder 网络提取的特征)。所谓"单边",是指只把各源域真人脸的特征送入判别器做对抗学习,攻击人脸的特征不参与对抗。

######### single side adversarial learning #########
# Rows of `feature` follow the concatenation order of the data-prepare step:
# [src1 real | src1 fake | src2 real | src2 fake | src3 real | src3 fake].
input1_shape = input1_real_shape + input1_fake_shape
input2_shape = input2_real_shape + input2_fake_shape
# Pick out only the real-face rows of each domain. Tensor.narrow(dim, start, length) returns
# the view covering [start, start + length) along dim. Only real faces go to the discriminator.
feature_real_1, feature_real_2, feature_real_3 = (
    feature.narrow(0, start, length)
    for start, length in (
        (0, input1_real_shape),
        (input1_shape, input2_real_shape),
        (input1_shape + input2_shape, input3_real_shape),
    )
)
feature_real = torch.cat([feature_real_1, feature_real_2, feature_real_3], dim=0)
discriminator_out_real = ad_net_real(feature_real)

7)非对称三元组损失(所有源域的真人脸共用同一个域标签 0,三个源域的攻击人脸分别使用标签 1、2、3)

######### unbalanced triplet loss #########
# Triplet-loss labels: every real face shares label 0, while the fake faces of source
# domains 1/2/3 get the distinct labels 1/2/3. Each piece is an int64 (n, 1) column on the GPU.
real_domain_label_1 = torch.full((input1_real_shape, 1), 0, dtype=torch.long).cuda()
real_domain_label_2 = torch.full((input2_real_shape, 1), 0, dtype=torch.long).cuda()
real_domain_label_3 = torch.full((input3_real_shape, 1), 0, dtype=torch.long).cuda()
fake_domain_label_1 = torch.full((input1_fake_shape, 1), 1, dtype=torch.long).cuda()
fake_domain_label_2 = torch.full((input2_fake_shape, 1), 2, dtype=torch.long).cuda()
fake_domain_label_3 = torch.full((input3_fake_shape, 1), 3, dtype=torch.long).cuda()
# Same row order as `feature`, flattened to a 1-D label vector.
label_parts = [real_domain_label_1, fake_domain_label_1,
               real_domain_label_2, fake_domain_label_2,
               real_domain_label_3, fake_domain_label_3]
source_domain_label = torch.cat(label_parts, dim=0).view(-1)
triplet = criterion["triplet"](feature, source_domain_label)

8)分类损失及单边对抗学习损失 (均为交叉熵损失)

######### cross-entropy loss #########
# Number of real samples contributed by each source domain (src1, src2, src3), in the
# order they were stacked into `feature_real`. Real_AdLoss is defined elsewhere.
real_shape_list = [input1_real_shape, input2_real_shape, input3_real_shape]
real_adloss = Real_AdLoss(discriminator_out_real, criterion["softmax"], real_shape_list)
# Real/fake classification loss over all rows of the concatenated batch.
cls_loss = criterion["softmax"](classifier_label_out.narrow(0, 0, input_data.size(0)), source_label)

9)总体损失

# Weighted sum of the three SSDG objectives; the weights come from the config.
total_loss = (
    cls_loss
    + config.lambda_triplet * triplet
    + config.lambda_adreal * real_adloss
)

10)配置文件

I&C&M to O 协议(源域为 Idiap Replay-Attack、CASIA、MSU,目标域为 OULU,与下方配置中的 src1/2/3_data 与 tgt_data 对应)下的配置文件如下:

class DefaultConfigs(object):
    """Training/testing settings for the I&C&M -> O protocol (sources casia/replay/msu, target oulu)."""

    # random seed
    seed = 666
    # SGD optimiser
    weight_decay = 5e-4
    momentum = 0.9
    # learning rate
    init_lr = 0.01
    lr_epoch_1 = 0
    lr_epoch_2 = 150
    # model
    pretrained = True
    model = 'resnet18'  # resnet18 or maddg
    # training parameters
    gpus = "3"
    batch_size = 10
    norm_flag = True
    max_iter = 4000  # keep equal to GRL.max_iter
    lambda_triplet = 2
    lambda_adreal = 0.1
    # checkpoint file evaluated at test time
    tgt_best_model_name = 'model_best_0.08_29.pth.tar'
    # source domains: dataset name + num_frames
    src1_data = 'casia'
    src1_train_num_frames = 1
    src2_data = 'replay'
    src2_train_num_frames = 1
    src3_data = 'msu'
    src3_train_num_frames = 1
    # target domain
    tgt_data = 'oulu'
    tgt_test_num_frames = 2
    # output locations, derived from the target dataset and model names above
    checkpoint_path = './{}_checkpoint/{}/DGFANet/'.format(tgt_data, model)
    best_model_path = './{}_checkpoint/{}/best_model/'.format(tgt_data, model)
    logs = './logs/'


config = DefaultConfigs()

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇
下一篇>>