U-Net 论文笔记

U-Net 论文笔记

Abstract

论文链接: https://arxiv.org/pdf/1505.04597.pdf

这个文章是做医疗方面的,挺感兴趣的。 网络结构非常像“U”,见下图

avator

属于一个segmentation的任务,估计医疗的许多任务都和segmentation相关,所以很有必要知道了解一下相关的技术,做一下技术储备。

###

文章中说这个network是在FCN的基础上建立的。医疗领域不同于其它领域的一个重要的地方是数据相对比较少,所以如何在非常有限的数据集上训练出比较好的模型就显得非常的关键,所以数据增强在这个里面很重要。FCN主要的想法是用一些successive layers 来实现一个压缩的网络,其中用upsampling 来取代了pooling, 因此这些层增加了输出的resolution。然后为了位置的精确性,从压缩network来的高resolution的feature要和上采样的feature组合在一起,然后依赖于这些信息,后面再接上一系列的cnn。

这里在结构上的一个不同是在上采样的时候feature channels变得更多了,这里的解释是这将会给从前面来的高resolution的layers更多的context的信息。然后expansive path基本上和前面的是对称的。整个网络没有用到fc层。

为了预测图像边界区域的像素,用了”tiling strategy”,具体的细节需要对着代码来更清晰的理解。

###

  • 输入进来时的结构是这样的
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


这里的padding给的默认值是1,paper里面显示的是cnn之后shape变了,是没有padding的,不过这在初始化的时候可以传入padding=0. paper里指的tile操作应该也是这个意思吧,即连续作了两次cnn,个人理解这是。

  • 然后是4个down的操作,其中相比上面的结构多了maxpooling
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
    return x

从上面的网络结构中可以看出,这4个down的操作,channels的数目一直在增加。

self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 1024)

  • 然后是up的操作
class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        
        # for padding issues, see 
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

从网络结构图中可以看出,先做一个up的采样操作,然后要和之前的特征有一个concat的操作,然后再一起做一个二次cnn的操作,up的操作也有4层 因为和左边是对称的。

self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)

  • 最后是输出层了,即接一个1×1的卷积操作。
class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

打赏,谢谢~~

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码打赏,多谢支持~

打开微信扫一扫,即可进行扫码打赏哦