在 PyTorch 中使用 torchvision 实现 deform_conv2d（可变形卷积）

class net(nn.Module):
    """Baseline layer: a 3x3 convolution (3 -> 64 channels, stride 1, padding 1) followed by ReLU."""

    def __init__(self):
        super(net, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        # forward() applies self.relu; the original never defined it, so the
        # first forward pass raised AttributeError.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(conv(x)); spatial size is preserved by padding=1."""
        out = self.relu(self.conv(x))
        return out

class net(nn.Module):
    """Deformable (modulated) drop-in replacement for a plain ``nn.Conv2d``.

    A regular ``nn.Conv2d`` holds the kernel weight and bias. Two auxiliary
    convolutions predict, for every output position, the sampling offsets and
    the modulation mask that ``torchvision.ops.deform_conv2d`` consumes.

    Args:
        in_channels: channels of the input tensor (default 3).
        out_channels: channels produced by the layer (default 64).
        kernel_size: side of the square kernel (default 3).
        stride: stride shared by the main conv and the offset/mask convs (default 1).
        padding: padding shared by the main conv and the offset/mask convs (default 1).
    """

    def __init__(self, in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1):
        # The original called super(dcn, self): there is no class ``dcn``,
        # so construction raised NameError.
        super().__init__()
        # Original convolution; its weight and bias are reused by deform_conv2d.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)

        n_taps = kernel_size * kernel_size
        # One (dy, dx) pair per kernel tap -> 2*k*k channels (18 for 3x3).
        # Same stride/padding as self.conv so the offset map matches the output size.
        self.conv_offset = nn.Conv2d(in_channels, 2 * n_taps, kernel_size=kernel_size,
                                     stride=stride, padding=padding)
        # Zero weight AND bias so the initial offsets are exactly 0 (the bias
        # was previously left at its random default).
        nn.init.zeros_(self.conv_offset.weight)
        nn.init.zeros_(self.conv_offset.bias)

        # One modulation scalar per kernel tap -> k*k channels (9 for 3x3).
        self.conv_mask = nn.Conv2d(in_channels, n_taps, kernel_size=kernel_size,
                                   stride=stride, padding=padding)
        # Fill the mask weights with 0.5, as the original intended; the old
        # tensor had shape [9, 1, 3, 3] (wrong in-channel count) and was never
        # assigned to the layer.
        nn.init.constant_(self.conv_mask.weight, 0.5)

    def forward(self, x):
        """Apply the modulated deformable convolution to ``x``."""
        offset = self.conv_offset(x)
        # sigmoid bounds each modulation weight to (0, 1), as in DCNv2.
        mask = torch.sigmoid(self.conv_mask(x))
        # deform_conv2d defaults to stride=1, padding=0, dilation=1. Pass the
        # main conv's geometry explicitly; otherwise the output shrinks and no
        # longer matches the spatial size of offset/mask.
        return torchvision.ops.deform_conv2d(
            input=x,
            offset=offset,
            weight=self.conv.weight,
            bias=self.conv.bias,
            stride=self.conv.stride,
            padding=self.conv.padding,
            dilation=self.conv.dilation,
            mask=mask,
        )

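需要注意：deform_conv2d 的 stride 默认为 (1, 1)，padding 默认为 (0, 0)，dilation 默认为 (1, 1)。如果被替换的卷积使用了非默认参数（例如这里的 padding=1），必须在调用 deform_conv2d 时显式传入相同的 stride/padding/dilation。否则输出尺寸会变小，与 offset（以及 mask）的空间尺寸不一致，导致报错。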

OK！这样就可以完美地将 normal_conv 替换成 deform_conv 了！不需要再去 GitHub 上看别人冗长的实现代码了，感谢 PyTorch！

normal_conv

deform_conv

THE END