U-NET模型——pytorch实现

U-NET是分割任务中的典型网络。

U-NET模型结构：

Overlap-tile策略：

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNET(nn.Module):
    """U-Net segmentation network built from unpadded ("valid") 3x3 convolutions.

    The contracting path applies five blocks of (max-pool) + conv + ReLU +
    conv + ReLU; the expanding path applies conv + ReLU + conv + ReLU +
    (transposed conv) blocks and, before each of them, concatenates the
    centre-cropped feature map of the matching contracting block.

    Because every 3x3 convolution uses padding 0, the output map is spatially
    smaller than the input: a (n, C, 572, 572) input produces a
    (n, num_classes, 388, 388) output.
    """

    def __init__(self, img_shape=(1, 388, 388), num_classes=2):
        """Build all layers.

        Args:
            img_shape: ``(channels, height, width)`` tuple. Only
                ``img_shape[0]`` (the number of input channels) is used to
                build the network; the spatial entries are stored but not
                read by any layer.
            num_classes: number of channels of the output segmentation map
                (defaults to 2, the previously hard-coded value).
        """
        super().__init__()
        self.img_shape = img_shape
        self.num_classes = num_classes

        # Contracting path (left side). The first block has no max-pool.
        self.contraction1 = self.contraction(self.img_shape[0], 64, maxpool=False)
        self.contraction2 = self.contraction(64, 128)
        self.contraction3 = self.contraction(128, 256)
        self.contraction4 = self.contraction(256, 512)
        self.contraction5 = self.contraction(512, 1024)

        # Expanding path (right side). The first up-convolution is a separate
        # layer because it precedes the first skip concatenation; every later
        # up-convolution is the last layer of the preceding expansion block.
        self.deconv1 = nn.ConvTranspose2d(1024, 512, 2, 2, 0)
        self.expansion1 = self.expansion(1024, 512)
        self.expansion2 = self.expansion(512, 256)
        self.expansion3 = self.expansion(256, 128)
        self.expansion4 = self.expansion(128, 64, upconv=False)  # last block: no up-conv

        # 1x1 convolution mapping the 64 feature channels to per-class scores.
        self.conv1 = nn.Conv2d(64, num_classes, 1, 1, 0)

    def contraction(self, in_channel, out_channel, maxpool=True):
        """Return one contracting block: (max-pool) + conv + ReLU + conv + ReLU.

        Args:
            in_channel: channels of the incoming feature map.
            out_channel: channels produced by both convolutions.
            maxpool: when True, prepend a 2x2 stride-2 max-pool.

        Returns:
            An ``nn.Sequential`` holding the layers in the order above.
        """
        layers = []
        if maxpool:
            layers.append(nn.MaxPool2d(2, 2))
        layers += [
            nn.Conv2d(in_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def expansion(self, in_channel, out_channel, upconv=True):
        """Return one expanding block: conv + ReLU + conv + ReLU + (up-conv).

        Args:
            in_channel: channels of the concatenated input feature map.
            out_channel: channels produced by both convolutions.
            upconv: when True, append a 2x2 stride-2 transposed convolution
                that halves the channel count.

        Returns:
            An ``nn.Sequential`` holding the layers in the order above.
        """
        layers = [
            nn.Conv2d(in_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 0),
            nn.ReLU(),
        ]
        if upconv:
            layers.append(nn.ConvTranspose2d(out_channel, out_channel // 2, 2, 2, 0))
        return nn.Sequential(*layers)

    def crop(self, x, target_x):
        """Centre-crop ``x`` to the spatial size of ``target_x``.

        Args:
            x: feature map of shape (n, c, h, w) from the contracting path.
            target_x: feature map whose last two dimensions (h', w') give the
                size to crop to.

        Returns:
            A view of ``x`` with shape (n, c, h', w'). When a size difference
            is odd, the extra row/column is removed from the bottom/right.

        Raises:
            ValueError: if ``x`` is smaller than ``target_x`` in height or
                width, since cropping cannot enlarge a tensor.
        """
        src_h, src_w = x.shape[2], x.shape[3]
        dst_h, dst_w = target_x.shape[2], target_x.shape[3]
        if src_h < dst_h or src_w < dst_w:
            raise ValueError(
                f"cannot crop spatial size ({src_h}, {src_w}) "
                f"to larger size ({dst_h}, {dst_w})"
            )
        top = (src_h - dst_h) // 2
        left = (src_w - dst_w) // 2
        return x[:, :, top:top + dst_h, left:left + dst_w]

    def forward(self, x):
        """Run the network on a batch of images.

        Shape comments below trace a (n, 1, 572, 572) input through the
        default configuration.

        Args:
            x: input tensor of shape (n, channels, h, w).

        Returns:
            Per-class score map of shape (n, num_classes, h_out, w_out); it is
            smaller than the input because the convolutions are unpadded.
        """
        # Contracting path; x1..x4 are kept for the skip connections.
        x1 = self.contraction1(x)   # (n,1,572,572) -> (n,64,568,568)
        x2 = self.contraction2(x1)  # -> pool (n,64,284,284) -> (n,128,280,280)
        x3 = self.contraction3(x2)  # -> pool (n,128,140,140) -> (n,256,136,136)
        x4 = self.contraction4(x3)  # -> pool (n,256,68,68) -> (n,512,64,64)
        x = self.contraction5(x4)   # -> pool (n,512,32,32) -> (n,1024,28,28)
        x = self.deconv1(x)         # up-conv -> (n,512,56,56)

        # Expanding path: crop the skip feature, concatenate on the channel
        # axis, then convolve (and up-convolve, except in the last block).
        x4 = self.crop(x4, x)            # (n,512,64,64) -> (n,512,56,56)
        x = torch.cat((x4, x), dim=1)    # -> (n,1024,56,56)
        x = self.expansion1(x)           # -> (n,512,52,52) -> (n,256,104,104)

        x3 = self.crop(x3, x)            # (n,256,136,136) -> (n,256,104,104)
        x = torch.cat((x3, x), dim=1)    # -> (n,512,104,104)
        x = self.expansion2(x)           # -> (n,256,100,100) -> (n,128,200,200)

        x2 = self.crop(x2, x)            # (n,128,280,280) -> (n,128,200,200)
        x = torch.cat((x2, x), dim=1)    # -> (n,256,200,200)
        x = self.expansion3(x)           # -> (n,128,196,196) -> (n,64,392,392)

        x1 = self.crop(x1, x)            # (n,64,568,568) -> (n,64,392,392)
        x = torch.cat((x1, x), dim=1)    # -> (n,128,392,392)
        x = self.expansion4(x)           # -> (n,64,390,390) -> (n,64,388,388)

        x = self.conv1(x)                # 1x1 conv -> (n,num_classes,388,388)
        return x

THE END