U-NET Model: a PyTorch Implementation

Paper: https://arxiv.org/pdf/1505.04597.pdf

U-NET is a classic network for segmentation tasks.

U-NET model structure:

The model as a whole is "U"-shaped and consists of three main parts:

The contraction path on the left extracts features; its overall structure resembles VGG (without BN layers);

The expansion path on the right upsamples the feature maps back to the original image size, then produces the segmentation map through a 1x1 convolution;

The copy-and-crop connections in the middle fuse multi-scale feature maps.

Overlap-tile strategy:

U-NET uses the overlap-tile strategy: the image is mirror-padded before being fed into the model, so that predictions near the image borders also stay accurate. After mirror padding, the input image (1x388x388) becomes an input image tile (1x572x572). None of the convolutions in the model use padding, so each convolution shrinks the feature map; as a result the output segmentation map (2x388x388) has the same spatial size as the original input image, and crop operations are needed when fusing multi-scale features.
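The size arithmetic behind this strategy can be checked with a short sketch (plain Python, matching the default (1, 388, 388) input used in the code below):

```python
# Mirror padding of 92 pixels on each side turns a 388x388 image into a 572x572 tile.
size = 388 + 2 * 92
print(size)  # 572

# Contraction path: each stage is a double unpadded 3x3 conv (each conv removes
# 2 pixels), preceded by a 2x2 maxpool for every stage except the first.
for stage in range(5):
    if stage > 0:
        size //= 2   # maxpool halves the spatial size
    size -= 4        # two unpadded 3x3 convs
print(size)  # 28 at the bottom of the "U"

# Expansion path: a 2x2 up-conv doubles the size, then a double 3x3 conv shrinks it by 4.
for stage in range(4):
    size *= 2
    size -= 4
print(size)  # 388, matching the original input size
```

This also shows why the padding of 92 depends only on the network depth (total shrinkage of 184 pixels) and not on the input image size.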

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNET(nn.Module):  # define the UNET model structure
    def __init__(self, img_shape=(1, 388, 388)):  # initialization
        super(UNET, self).__init__()  # call parent initializer
        self.img_shape = img_shape  # input image shape, default (1, 388, 388)

        self.reflectionpad = nn.ReflectionPad2d(92)  # mirror-pad the input image, implementing the paper's overlap-tile strategy; the padding of 92 follows from the network depth, not from the input image size
        self.contraction1 = self.contraction(self.img_shape[0], 64,
                                             maxpool=False)  # contraction block, (maxpool)+conv+relu+conv+relu; the first block has no maxpool
        self.contraction2 = self.contraction(64, 128)  # contraction block
        self.contraction3 = self.contraction(128, 256)  # contraction block
        self.contraction4 = self.contraction(256, 512)  # contraction block
        self.contraction5 = self.contraction(512, 1024)  # contraction block

        self.deconv1 = nn.ConvTranspose2d(1024, 512, 2, 2, 0)  # transposed conv implementing the up-conv (upsampling)
        self.expansion1 = self.expansion(1024, 512)  # expansion block, conv+relu+conv+relu+(transconv); a single up-conv is applied before the first block
        self.expansion2 = self.expansion(512, 256)  # expansion block
        self.expansion3 = self.expansion(256, 128)  # expansion block
        self.expansion4 = self.expansion(128, 64, upconv=False)  # expansion block; the last block has no up-conv
        self.conv1 = nn.Conv2d(64, 2, 1, 1, 0)  # 1x1 convolution, outputs the segmentation map

    def contraction(self, in_channel, out_channel, maxpool=True):  # contraction block, left side of the model
        layers = []  # list holding the layers
        if maxpool:  # add a maxpool if requested
            layers += [nn.MaxPool2d(2, 2)]  # 2x2 maxpool
        layers += [nn.Conv2d(in_channel, out_channel, 3, 1, 0),  # 3x3 conv, no padding
                   nn.ReLU(),  # relu
                   nn.Conv2d(out_channel, out_channel, 3, 1, 0),  # 3x3 conv, no padding
                   nn.ReLU()]  # relu
        return nn.Sequential(*layers)  # (maxpool)+conv+relu+conv+relu; keeping each block separate makes copy-and-crop straightforward

    def expansion(self, in_channel, out_channel, upconv=True):  # expansion block, right side of the model
        layers = []  # list holding the layers
        layers += [nn.Conv2d(in_channel, out_channel, 3, 1, 0),  # 3x3 conv, no padding
                   nn.ReLU(),  # relu
                   nn.Conv2d(out_channel, out_channel, 3, 1, 0),  # 3x3 conv, no padding
                   nn.ReLU()]  # relu
        if upconv:  # add an up-conv if requested
            layers += [nn.ConvTranspose2d(out_channel, out_channel // 2, 2, 2, 0)]  # transposed conv
        return nn.Sequential(*layers)  # conv+relu+conv+relu+(transconv); keeping each block separate makes copy-and-crop straightforward

    def crop(self, x, target_x):
        '''
        crop operation: crops the left-side feature map (n,c,h,w) to match the right-side feature map (n,c,h',w')
        :param x: input feature map
        :param target_x: target feature map
        :return: the input feature map cropped to the same spatial size as the target
        '''
        pad_h = -(x.shape[2] - target_x.shape[2]) // 2  # amount to crop along H, negative
        pad_w = -(x.shape[3] - target_x.shape[3]) // 2  # amount to crop along W, negative
        return F.pad(x, (pad_w, pad_w, pad_h, pad_h))  # F.pad orders the last dimension first, i.e. (left, right, top, bottom); negative padding crops instead of padding

    def forward(self, x):  # forward pass
        x = self.reflectionpad(x)  # mirror padding, (n,1,388,388)-->(n,1,572,572)
        x1 = self.contraction1(x)  # contraction, (n,1,572,572)-->(n,64,570,570)-->(n,64,568,568)
        x2 = self.contraction2(x1)  # contraction, (n,64,568,568)-->(n,64,284,284)-->(n,128,282,282)-->(n,128,280,280)
        x3 = self.contraction3(x2)  # contraction, (n,128,280,280)-->(n,128,140,140)-->(n,256,138,138)-->(n,256,136,136)
        x4 = self.contraction4(x3)  # contraction, (n,256,136,136)-->(n,256,68,68)-->(n,512,66,66)-->(n,512,64,64)
        x = self.contraction5(x4)  # contraction, (n,512,64,64)-->(n,512,32,32)-->(n,1024,30,30)-->(n,1024,28,28)
        x = self.deconv1(x)  # up-conv, (n,1024,28,28)-->(n,512,56,56)

        x4 = self.crop(x4, x)  # crop, (n,512,64,64)-->(n,512,56,56)
        x = torch.cat((x4, x), dim=1)  # concatenate along C, (n,512,56,56)+(n,512,56,56)-->(n,1024,56,56)
        x = self.expansion1(x)  # expansion, (n,1024,56,56)-->(n,512,54,54)-->(n,512,52,52)-->(n,256,104,104)
        x3 = self.crop(x3, x)  # crop, (n,256,136,136)-->(n,256,104,104)
        x = torch.cat((x3, x), dim=1)  # concatenate along C, (n,256,104,104)+(n,256,104,104)-->(n,512,104,104)
        x = self.expansion2(x)  # expansion, (n,512,104,104)-->(n,256,102,102)-->(n,256,100,100)-->(n,128,200,200)
        x2 = self.crop(x2, x)  # crop, (n,128,280,280)-->(n,128,200,200)
        x = torch.cat((x2, x), dim=1)  # concatenate along C, (n,128,200,200)+(n,128,200,200)-->(n,256,200,200)
        x = self.expansion3(x)  # expansion, (n,256,200,200)-->(n,128,198,198)-->(n,128,196,196)-->(n,64,392,392)
        x1 = self.crop(x1, x)  # crop, (n,64,568,568)-->(n,64,392,392)
        x = torch.cat((x1, x), dim=1)  # concatenate along C, (n,64,392,392)+(n,64,392,392)-->(n,128,392,392)
        x = self.expansion4(x)  # expansion, (n,128,392,392)-->(n,64,390,390)-->(n,64,388,388)
        x = self.conv1(x)  # 1x1 conv, (n,64,388,388)-->(n,2,388,388)

        return x  # returns a segmentation map with the same spatial size as the input image
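The crop via F.pad with negative padding can be verified in isolation (a minimal sketch using dummy tensors with the shapes from the first skip connection; note that F.pad interprets the tuple as (left, right, top, bottom), so the W amounts come first):

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 512, 64, 64)       # left-side feature map before the skip connection
target = torch.zeros(1, 512, 56, 56)  # right-side feature map after the up-conv

pad_h = -(x.shape[2] - target.shape[2]) // 2  # -4: crop 4 pixels from top and bottom
pad_w = -(x.shape[3] - target.shape[3]) // 2  # -4: crop 4 pixels from left and right

# Negative padding removes pixels; the tuple orders the last dimension (W) first.
cropped = F.pad(x, (pad_w, pad_w, pad_h, pad_h))
print(cropped.shape)  # torch.Size([1, 512, 56, 56])
```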
