MatchNet论文复现过程记录

MatchNet论文复现过程记录

原文为《Matchnet: Unifying feature and metric learning for patch-based matching》1:本文复现基于PyTorch深度学习框架,版本(1.7.1+cu110)。

I.Network architecture

matchnet结构

根据论文中描述,MatchNet包括:

A. Feature network

该特征提取网络类似AlexNet2,具体结构如下:
FeatureNet其中,PS: patch size for convolution and pooling layers; S: stride. Layer types: C: convolution, MP: max-pooling, FC: fully-connected.

B. Metric network

包括三个全连接层,FC3后接Softmax作为输出。

C. MatchNet in training

基于patch的匹配任务通常假设patch在计算相似度之前,先经过相同的特征编码。因此,论文中采用Two-tower structure with tied parameters结构,即,仅采用一个特征提取网络,在训练过程中,可以理解为同时使用了两个参数共享的特征提取网络去连接度量网络,更新任何一个特征提取网络,将会使得两个网络的参数都发生变化。(这里直接讲比较难理解,具体可以看代码实现。)


具体代码实现如下:

import torch
import torch.nn as nn

class FeatureNet(nn.Module):
    """特征提取网络
    """
    def __init__(self):
        super(FeatureNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=24, kernel_size=7, padding=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
            nn.Conv2d(in_channels=24, out_channels=64, kernel_size=5, padding=2, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),   
            nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1, stride=1),  
            nn.ReLU(),
            nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        )
    
    def forward(self, x):
        return self.features(x)

class MetricNet(nn.Module):
    """度量网络
    """
    def __init__(self):
        super(MetricNet, self).__init__()
        self.features = nn.Sequential(
            nn.Linear(in_features=6272, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=2),
            # nn.Softmax(dim=1) 
            ''' 这里原本应该接Softmax,但损失函数采用的是交叉熵损失,
            而Pytorch中的torch.nn.CrossEntropyLoss()方法包括Softmax,
            具体可参考文档https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=nn%20crossentropyloss#torch.nn.CrossEntropyLoss
            '''
        )

    def forward(self, x):
        return self.features(x)

class MatchNet(nn.Module):
    def __init__(self):
        super(MatchNet, self).__init__()
        
        # 只添加一个特征提取网络
        self.input_ = FeatureNet()
        self.input_.apply(weights_init)
        
        self.matric_network = MetricNet()
        self.matric_network.apply(weights_init)
    
    def forward(self, x):
        """x.shape = (2, C, H, W),即两个patch
        """
        # 两个patch进入同一个FeatureNet,相当于two-tower sharing same parameters
        feature1 = self.input_(x[0]).reshape((x[0].shape[0], -1)) #[256, 3136]
        feature2 = self.input_(x[1]).reshape((x[1].shape[0], -1))
        
        features = torch.cat((feature1, feature2), 1) #[256, 6272]
        
        return self.matric_network(features)
        
def weights_init(m):
		'''
		自定义权重初始化
		'''
    if isinstance(m, nn.Conv2d):
        nn.init.orthogonal_(m.weight.data, gain=0.6)
        try:
            nn.init.constant_(m.bias.data, 0.01)
        except Exception:
            pass
    return

参考文献


  1. Han等, 《Matchnet: Unifying feature and metric learning for patch-based matching》. ↩︎

  2. A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In NIPS, 2012. ↩︎

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
分享
二维码
< <上一篇

)">
下一篇>>