UNet结构介绍

Jachin Zhang

引言

UNet实际上是一种比较老的架构,最初于2015年的论文U-Net: Convolutional Networks for Biomedical Image Segmentation中提出,主要用于医学图像分割。但近年来随着AIGC的兴起,UNet也逐渐被用于其他领域,如图像生成、图像修复等。Unet的结构并不复杂,但效果不错。对于扩散模型来说,UNet结构几乎是标配,仅使用残差网络的效果远不及它。

下面对这个网络结构作简要介绍,并在后面给出Pytorch实现。

UNet结构

网络结构示意图
网络结构示意图

UNet结构如上图所示,它的主体是一个Encoder-Decoder的结构。其中,编码器由多个下采样层组成,解码器由多个上采样层组成。编码器和解码器之间通过跳跃连接连接,以保留更多的特征信息。因其结构类似于字母U,故得名UNet。除此之外,它的最大的特点来自于其跳层连接(copy and crop)。

  • conv 3x3, ReLU
    卷积层,卷积核大小为3x3,激活函数为ReLU。
  • max pool 2x2
    最大池化层,池化核大小为2x2。
  • up-conv 2x2
    这里是用于图像上采样的反卷积层,卷积核大小为2x2。
  • conv 1x1
    1*1的卷积层,可以调整通道数而不改变图像尺寸。

Pytorch实现

本文对UNet结构进行简要的代码实现。(叠甲:写的比较简陋,图像尺寸设的不好可能还会出现报错。此外,该网络可以结合时间步编码和注意力机制等,这里暂未体现。)

首先导入必要的库:

1
2
import torch
import torch.nn as nn

UNet类的初始化:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class UNet(nn.Module):
def __init__(self, in_channels, out_channels, device, layers=4):
super(UNet, self).__init__()
self.layers = layers
self.device = device
self.to_be_cropped_list = []

self.input_encoder = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
self.output_decoder = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)

self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
for i in range(layers):
self.downs.append(nn.Sequential(
nn.Conv2d(64 * (2 ** i), 64 * (2 ** (i + 1)), kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64 * (2 ** (i + 1)), 64 * (2 ** (i + 1)), kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
))
self.ups.insert(0, nn.Sequential(
nn.Conv2d(64 * (2 ** (i + 1)), 64 * (2 ** i), kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64 * (2 ** i), 64 * (2 ** i), kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
))

参数:

  • in_channels:输入通道数
  • out_channels:输出通道数
  • device:设备(cpu或cuda)
  • layers:UNet的层数

input_encoderoutput_decoder用于改变图像的通道数,而downsups分别对应模型主体的编码器和解码器。这里为了方便,使用循环将编码器和解码器的采样层使用ModuleList存储,并且每轮循环(每层)对称地加入采样层。

注意到在这段代码中,我们并没有将池化和反卷积(即改变图像尺寸大小)的操作加入到downsups中,而是只有卷积的改变图像通道数的操作。这是因为该网络还需要实现“跳层连接”,即需要将最大池化之前的张量经过裁剪之后和上采样之后的张量进行拼接。如果将池化和反卷积的操作放在Sequential里面,拼接操作所需要的张量无法提取出来。

(事实上代码的写法可以改进,比如将上采样层和下采样层分别封装成类UpBlockDownBlock,这样可以将反卷积和池化操作包含其中,并提取出跳层连接所需的张量。但是我懒得改了QAQ)

还有一个小细节是,原论文中的卷积操作是将padding设为0了的,这样每次卷积会使得图像尺寸发生改变。我个人更倾向于设置padding使得图像尺寸不发生改变,这样在拼接操作时不会出现尺寸不匹配的问题。(其实是因为padding设为0会在拼接时发生报错,两个张量的通道数会相差1。这个问题暂时还没有修正,充分体现出作者的代码能力亟待提高。哭……)

将在图像编码过程中提取的张量经过裁剪操作后与解码过程中对应的张量拼接,其中的裁剪操作实现如下:

1
2
3
4
5
def crop_tensor(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target_size = target.shape[2]
x_size = x.shape[2]
delta = (x_size - target_size) // 2
return x[:, :, delta:x_size - delta, delta:x_size - delta]

注意:这个函数输入的两个张量的通道数最好为偶数,否则后续的张量拼接可能会因为通道数不匹配出现报错。因此输入的图像尺寸(张量[Batch_size, Channels, Height, Width]中的后两个)最好为2的整数次幂。

图像编码和解码操作的实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def encode(self, x: torch.Tensor) -> torch.Tensor:
for idx, down in enumerate(self.downs):
# print(f'Down {idx + 1}.1\t', x.shape)
x = down(x) # conv
# print(f'Down {idx + 1}.2\t', x.shape)
self.to_be_cropped_list.insert(0, x) if idx < self.layers - 1 else None
x = nn.MaxPool2d(kernel_size=2, stride=2)(x)
# print(f'Down {idx + 1}.3\t', x.shape) if not idx < self.layers - 1 else None
return x

def decode(self, x: torch.Tensor) -> torch.Tensor:
for idx, up in enumerate(self.ups):
# print(f'Up {idx + 1}.1\t\t', x.shape)
conv_trans = nn.ConvTranspose2d(
64 * (2 ** (self.layers - idx)),
64 * (2 ** (self.layers - idx - 1)),
kernel_size=2,
stride=2,
padding=0
).to(self.device)
x = conv_trans(x)
# print(f'Up {idx + 1}.2\t\t', x.shape)
cropped = self.crop_tensor(self.to_be_cropped_list[idx], x)
x = torch.cat([x, cropped], dim=1)
# print(f'Up {idx + 1}.3\t\t', x.shape)
x = up(x) # conv
# print(f'Up {idx + 1}.4\t\t', x.shape) if not idx < self.layers - 1 else None
return x

前向传播函数的实现如下:

1
2
3
4
5
6
7
8
9
10
def forward(self, x: torch.Tensor) -> torch.Tensor:
# print('Input Size\t', x.shape)
x = self.input_encoder(x)
self.to_be_cropped_list.insert(0, x)
# print('Encoded Input\t', x.shape)
x = self.encode(x)
x = self.decode(x)
x = self.output_decoder(x)
# print('Output Size\t', x.shape)
return x

主函数:

1
2
3
4
5
6
7
8
9
10
11
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1, out_channels=1, device=device, layers=4).to(device)
model.eval()
H = 512
W = 512
input_tensor = torch.randn(1, 1, H, W).to(device)
output_tensor = model.forward(input_tensor)

if __name__ == "__main__":
main()

注意到这里的代码将打印张量维度的语句都注释掉了。运行这些语句可以方便地查看张量的维度变化情况(#处为本文中手动添加的注释):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
Input Size       torch.Size([1, 1, 512, 512])
Encoded Input torch.Size([1, 64, 512, 512])
Down 1.1 torch.Size([1, 64, 512, 512])
# conv
Down 1.2 torch.Size([1, 128, 512, 512])
# maxpool
Down 2.1 torch.Size([1, 128, 256, 256])
# conv
Down 2.2 torch.Size([1, 256, 256, 256])
# maxpool
Down 3.1 torch.Size([1, 256, 128, 128])
# conv
Down 3.2 torch.Size([1, 512, 128, 128])
# maxpool
Down 4.1 torch.Size([1, 512, 64, 64])
# conv
Down 4.2 torch.Size([1, 1024, 64, 64])
# maxpool
Down 4.3 torch.Size([1, 1024, 32, 32])
Up 1.1 torch.Size([1, 1024, 32, 32])
# conv_trans
Up 1.2 torch.Size([1, 512, 64, 64])
# crop and cat
Up 1.3 torch.Size([1, 1024, 64, 64])
# conv
Up 2.1 torch.Size([1, 512, 64, 64])
# conv_trans
Up 2.2 torch.Size([1, 256, 128, 128])
# crop and cat
Up 2.3 torch.Size([1, 512, 128, 128])
# conv
Up 3.1 torch.Size([1, 256, 128, 128])
# conv_trans
Up 3.2 torch.Size([1, 128, 256, 256])
# crop and cat
Up 3.3 torch.Size([1, 256, 256, 256])
# conv
Up 4.1 torch.Size([1, 128, 256, 256])
# conv_trans
Up 4.2 torch.Size([1, 64, 512, 512])
# crop and cat
Up 4.3 torch.Size([1, 128, 512, 512])
# conv
Up 4.4 torch.Size([1, 64, 512, 512])
Output Size torch.Size([1, 1, 512, 512])

可以看到,在编码器部分,张量的尺寸逐渐减小,通道数逐渐增多;而在解码器部分,张量的尺寸逐渐增大,通道数逐渐减少。

  • Title: UNet结构介绍
  • Author: Jachin Zhang
  • Created at : 2025-02-17 19:48:58
  • Updated at : 2025-02-17 22:26:05
  • Link: https://jachinzhang1.github.io/2025/02/17/unet/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments
On this page
UNet结构介绍