Original paper
U-Net: Convolutional Networks for Biomedical Image Segmentation
Network architecture
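- The paper's convolutions use no padding, so every convolution shrinks the feature map. Presumably padding was not yet common practice at the time. This reproduction uses padding, so convolutions keep the spatial size unchanged and the network's output has the same height and width as its input, as long as both are multiples of 16 (because of the four 2 x 2 poolings). For example, a 64 x 64 input produces a 64 x 64 output.
- The paper does not use Batch Normalization, presumably because the medical image segmentation datasets of the time were small enough not to need it. This reproduction adds it.
- The skip connections concatenate the encoder and decoder feature maps along the channel dimension, which is implemented here with torch.cat().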
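A minimal sketch of the first and third points. It uses the paper's 572 x 572 single-channel input size for the convolution comparison. The tensor sizes in the concatenation part are arbitrary, chosen only for illustration.

- An unpadded 3 x 3 convolution trims one pixel from each border, while padding=1 keeps the size.
- torch.cat along dim=1 adds the channel counts and leaves height and width alone.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 1, 572, 572)  # 572 x 572 single-channel input, as in the paper's figure

# Unpadded 3x3 convolution (as in the paper): output shrinks by 2 in each dimension
no_pad = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=0)
# Padded 3x3 convolution (as in this reproduction): spatial size is preserved
same_pad = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)

y_no_pad = no_pad(x)
y_same_pad = same_pad(x)
print(y_no_pad.shape)    # torch.Size([1, 64, 570, 570])
print(y_same_pad.shape)  # torch.Size([1, 64, 572, 572])

# Skip connection: concatenate upsampled decoder features with encoder features on dim=1
up = torch.rand(1, 512, 28, 28)    # arbitrary example sizes
skip = torch.rand(1, 512, 28, 28)
merged = torch.cat((up, skip), dim=1)
print(merged.shape)      # torch.Size([1, 1024, 28, 28])
```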
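First, wrap the pair of blue arrows (Conv + ReLU) that appears most often in the architecture diagram into a reusable module. Following the changes above, each convolution uses padding 1 and is followed by Batch Normalization.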
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # (3x3 conv, stride 1, padding 1) -> BatchNorm -> ReLU, applied twice
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)
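Next, split the whole network into left, center, and right parts:

- The left part (encoder) has 4 double convolutions, each followed by a downsampling step (2 x 2 max pooling).
- The center is a single double convolution.
- The right part (decoder) has 4 upsampling steps (transposed convolutions), each followed by a double convolution.

The network ends with a 1 x 1 convolution that produces the output.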
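The hypothetical helper below, unet_shapes, is not part of the original code. It is a sketch that traces (channels, height, width) through the padded network described above. It assumes the channel widths used in the code that follows (64 doubling up to 1024) and 3 output channels.

```python
def unet_shapes(height, width, in_channels=3, out_channels=3, base_channels=64, depth=4):
    """Hypothetical helper: list (layer, channels, H, W) for the padded U-Net."""
    if height % (2 ** depth) or width % (2 ** depth):
        raise ValueError(f"height and width must be divisible by {2 ** depth}")
    shapes = [("input", in_channels, height, width)]
    channels, h, w = base_channels, height, width
    for level in range(1, depth + 1):
        shapes.append((f"left_conv_{level}", channels, h, w))  # padded DoubleConv keeps H, W
        h, w = h // 2, w // 2                                  # 2x2 max pooling halves H, W
        channels *= 2                                          # next DoubleConv doubles channels
    shapes.append(("center_conv", channels, h, w))
    for level in range(1, depth + 1):
        h, w = h * 2, w * 2  # stride-2 transposed conv doubles H, W and halves channels;
        channels //= 2       # concat then doubles them and DoubleConv halves them again
        shapes.append((f"right_conv_{level}", channels, h, w))
    shapes.append(("output", out_channels, h, w))  # 1x1 conv changes only the channel count
    return shapes


for name, c, h, w in unet_shapes(64, 64):
    print(f"{name:>12}: {c} x {h} x {w}")
```

For a 64 x 64 input the deepest feature map is 1024 x 4 x 4. The decoder then mirrors the encoder's sizes exactly, which is why every torch.cat in the skip connections lines up.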
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    # Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W unchanged
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # left (encoder): each DoubleConv doubles the channels, each MaxPool2d(2, 2) halves H and W
        self.left_conv_1 = DoubleConv(3, 64)
        self.down_1 = nn.MaxPool2d(2, 2)
        self.left_conv_2 = DoubleConv(64, 128)
        self.down_2 = nn.MaxPool2d(2, 2)
        self.left_conv_3 = DoubleConv(128, 256)
        self.down_3 = nn.MaxPool2d(2, 2)
        self.left_conv_4 = DoubleConv(256, 512)
        self.down_4 = nn.MaxPool2d(2, 2)

        # center
        self.center_conv = DoubleConv(512, 1024)

        # right (decoder): ConvTranspose2d(kernel 2, stride 2) doubles H and W and halves the
        # channels; concatenating the skip connection doubles them again, hence DoubleConv(2c, c)
        self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.right_conv_1 = DoubleConv(1024, 512)
        self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.right_conv_2 = DoubleConv(512, 256)
        self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.right_conv_3 = DoubleConv(256, 128)
        self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.right_conv_4 = DoubleConv(128, 64)

        # output: 1x1 conv; 3 output channels here so the output shape matches the input
        # (for segmentation this is usually set to the number of classes)
        self.output = nn.Conv2d(64, 3, 1, 1, 0)

    def forward(self, x):
        # left
        x1 = self.left_conv_1(x)
        x1_down = self.down_1(x1)

        x2 = self.left_conv_2(x1_down)
        x2_down = self.down_2(x2)

        x3 = self.left_conv_3(x2_down)
        x3_down = self.down_3(x3)

        x4 = self.left_conv_4(x3_down)
        x4_down = self.down_4(x4)

        # center
        x5 = self.center_conv(x4_down)

        # right: upsample, concatenate with the matching encoder map along channels, DoubleConv
        x6_up = self.up_1(x5)
        temp = torch.cat((x6_up, x4), dim=1)
        x6 = self.right_conv_1(temp)

        x7_up = self.up_2(x6)
        temp = torch.cat((x7_up, x3), dim=1)
        x7 = self.right_conv_2(temp)

        x8_up = self.up_3(x7)
        temp = torch.cat((x8_up, x2), dim=1)
        x8 = self.right_conv_3(temp)

        x9_up = self.up_4(x8)
        temp = torch.cat((x9_up, x1), dim=1)
        x9 = self.right_conv_4(temp)

        # output
        output = self.output(x9)

        return output
Testing
If the implementation is correct, the network's output should have the same shape as its input.
if __name__ == "__main__":
    a = torch.rand(10, 3, 32, 32)
    model = UNet()
    b = model(a)
    print(b.size())  # torch.Size([10, 3, 32, 32])
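The same-size property only holds when height and width are multiples of 16. Each of the four 2 x 2 max poolings floors odd sizes, so the upsampled decoder map and the matching encoder map stop lining up and torch.cat fails.

One common workaround is to zero-pad the input on the bottom and right before the forward pass, then crop the output back. The sketch below shows this with a hypothetical helper, pad_to_multiple, which is not part of the original code.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, multiple=16):
    """Hypothetical helper: zero-pad H and W (bottom/right) up to a multiple of `multiple`."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # For the last two dims, F.pad takes (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)


x = torch.rand(2, 3, 33, 50)
padded, (h, w) = pad_to_multiple(x)
print(padded.shape)  # torch.Size([2, 3, 48, 64])

# With the UNet above one would run: out = model(padded)[..., :h, :w]
restored = padded[..., :h, :w]  # cropping back recovers the original region
print(torch.equal(restored, x))  # True
```

Padding on the bottom and right keeps the top-left corner aligned, so a single slice recovers the original region. Zero padding is one choice among several; reflect padding is another option with F.pad.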
Notes
Problem
A network is commonly initialized with a function like the following:
import torch.nn as nn


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


model = UNet()
model.apply(weights_init)  # apply() calls weights_init on every submodule
However, this raises an error:
ModuleAttributeError: 'DoubleConv' object has no attribute 'weight'
Cause
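We defined a class named DoubleConv. Its class name contains 'Conv', so classname.find('Conv') matches it and weights_init treats it as a convolution layer. DoubleConv, however, is only a container for other layers and has no weight attribute of its own, hence the error.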
Solution: change 'Conv' in classname.find('Conv') to 'Conv2d', or rename DoubleConv so that its name no longer contains 'Conv'.
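Matching on class-name substrings is fragile in another way too. 'Conv2d' is not a substring of 'ConvTranspose2d', so with the first fix the decoder's transposed convolutions keep PyTorch's default initialization.

The sketch below matches on layer types with isinstance instead. It assumes you want Conv2d, ConvTranspose2d and BatchNorm2d initialized as in the original function. The small DoubleConv here is a hypothetical cut-down stand-in for the article's module, used only to show that containers are now skipped.

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Match on layer types rather than class-name substrings,
    # so custom containers such as DoubleConv are skipped
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)


class DoubleConv(nn.Module):
    # Hypothetical stand-in with the same name as the article's module (one stage only)
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.double_conv(x)


model = nn.Sequential(DoubleConv(3, 8), nn.ConvTranspose2d(8, 4, 2, 2))
model.apply(weights_init)  # no error: DoubleConv itself is not a Conv2d instance
```

The nn.init functions modify parameters in place without tracking gradients, so passing m.weight directly works the same as m.weight.data here.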