深度学习基础:基于DDPM模型的生成模型demo笔记(二)

Jachin Zhang

前情提要

本系列上篇博客中我们已经实现了UNet神经网络用于DDPM模型反向过程中的噪声预测,并成功在MNIST和CIFAR-10数据集上训练和测试。其中训练于MNIST数据集的模型生成情况良好,但训练于CIFAR-10数据集上的模型无法生成有意义的图片。

推测原因在于网络结构还不足以拟合多通道的复杂图像情形下的任务。因此我们打算加入自注意力机制增强模型对图像特征的提取,并试图将类别标签的特征融入模型使其可以实现类别条件的控制生成。此外,我们可以在训练过程中使用一些策略调整学习率,优化模型的训练过程。

引入类别标签特征

将类别特征融入模型训练的方法有很多,不过受到最近读的某篇文章的启发,我打算将该特征融入时间步编码中,这样操作简单,无需改动整体的网络结构。

修改 `networks.py` 中的 `PositionalEncoding` 类,展示改动后的结果:

```python
class PositionalEncoding(nn.Module):
    def __init__(self,
                 max_seq_len: int,
                 d_model: int,
                 n_classes: int=None):
        super().__init__()

        # Assume d_model is an even number for convenience
        assert d_model % 2 == 0

        # time step encoding
        # ...existing code...

        self.t_embedding = nn.Embedding(max_seq_len, d_model)
        self.t_embedding.weight.data = pe
        self.t_embedding.requires_grad_(False)

        # label encoding
        self.use_condition = n_classes is not None
        if self.use_condition:
            self.label_embedding = nn.Embedding(n_classes, d_model)

    def forward(self, t, label=None):
        t_emb = self.t_embedding(t)
        if self.use_condition and label is not None:
            label_emb = self.label_embedding(label)
            return t_emb + label_emb
        return t_emb
```
为了区分时间步编码和类别标签编码,我们将时间步编码器更名为 `t_embedding`,类别编码器则是 `label_embedding`,这两个编码器的输出分别是 `t_emb` 和 `label_emb`。将所得的时间步编码和类别编码相加作为最终输出即可,该输出同时包含了两部分的特征(原理和残差网络有些类似)。
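
作为补充,下面是一个简单的形状检查示例(非原文代码,基于后文完整代码中的 `PositionalEncoding` 定义,`max_seq_len=1000`、`d_model=128` 等参数仅为演示用的假设值),可以确认加入类别编码后输出形状不变:

```python
import torch

# 假设性示例:检查条件编码与纯时间步编码的输出形状一致
pe = PositionalEncoding(max_seq_len=1000, d_model=128, n_classes=10)
t = torch.randint(0, 1000, (4, 1))      # (batch, 1) 的时间步索引
label = torch.randint(0, 10, (4, 1))    # (batch, 1) 的类别标签
print(pe(t, label).shape)               # torch.Size([4, 1, 128])
print(pe(t).shape)                      # 不传 label 时形状相同:torch.Size([4, 1, 128])
```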

为此,神经网络要作一些微小的改动,即将标签输入包括进去。

首先是UNet初始化函数的输入:

```python
class UNet(nn.Module):
    def __init__(self,
                 n_steps,
                 channels=[10, 20, 40, 80],
                 pe_dim=10,
                 residual=False,
                 n_classes: int=10):
```
也就是在参数中加上类别数量这一项。找到该类中 `self.pe` 的定义,改为:
```python
self.pe = PositionalEncoding(n_steps, pe_dim, n_classes)
```
然后是该类的forward方法:
```python
def forward(self, x, t, label=None):
    n = t.shape[0]
    t = self.pe(t, label)
    encoder_outs = []
    # ...existing code...
```

最后在该模块的build_network函数中做如下更改:

```python
def build_network(n_steps,
                  channels=None,
                  pe_dim=None,
                  residual=True,
                  n_classes=10):
    return UNet(n_steps, channels, pe_dim, residual, n_classes)
```

其实就是在各个输入的地方加上类别标签相关的参数。另外,别忘记在配置文件的network部分加一条:

```yaml
network:
  n_classes: 10
```

自注意力机制引入

在 `networks.py` 中添加 `SelfAttention` 类,实现如下:

```python
class SelfAttention(nn.Module):
    def __init__(self, channels: int, num_heads: int):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x: torch.Tensor):
        size = x.shape[-2:]
        x = x.view(x.shape[0], self.channels, -1).transpose(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.transpose(1, 2).view(x.shape[0], self.channels, *size)
```

该自注意力模块先将传入的特征图在空间维度上展平并做层归一化,再对其作多头注意力运算,将注意力值与原特征残差相加后,再用一个前馈网络作进一步处理。我们将这个模块应用于UNet的mid block中。在 `UNet` 类的初始化方法参数中增加一个 `use_attention: bool=True`,并将该方法中的mid block部分做如下改动:

```python
# mid block
self.pe_mid = nn.Linear(pe_dim, prev_channel)
channel = channels[-1]
if not use_attention:
    self.mid = nn.Sequential(
        UNetBlock((prev_channel, Hs[-1], Ws[-1]),
                  prev_channel,
                  channel,
                  residual=residual),
        UNetBlock((channel, Hs[-1], Ws[-1]),
                  channel,
                  channel,
                  residual=residual)
    )
else:
    self.mid = nn.Sequential(
        UNetBlock((prev_channel, Hs[-1], Ws[-1]),
                  prev_channel,
                  channel,
                  residual=residual),
        SelfAttention(channel, 4),
        UNetBlock((channel, Hs[-1], Ws[-1]),
                  channel,
                  channel,
                  residual=residual)
    )
prev_channel = channel
```
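
作为参考,下面是一个针对 `SelfAttention` 模块的简单形状检查(非原文代码,通道数、头数与特征图尺寸均为演示用的假设值)。它可以确认该模块不改变特征图形状,因此可以直接插入 mid block:

```python
import torch

# 假设性示例:SelfAttention 的输入输出形状一致
attn = SelfAttention(channels=256, num_heads=4)
feat = torch.randn(2, 256, 8, 8)    # (batch, C, H, W) 的中间特征图
out = attn(feat)
print(out.shape)                    # torch.Size([2, 256, 8, 8])
```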

相应地修改build_network函数的参数,并在配置文件添加:

```yaml
network:
  attention: true
```

学习率调整策略

尝试使用余弦退火调度策略动态调整学习率。计算公式为:

$$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)$$

其中 $\eta_t$ 为当前学习率,$\eta_{max}$ 和 $\eta_{min}$ 分别为最大(初始自定)和最小学习率;$T_{cur}$ 为当前轮次,$T_{max}$ 为总轮次。学习率变化遵循余弦函数,从最大值平稳下降到最小值,下降速度在开始和结束时较慢,中间阶段较快。这有助于模型在初期快速接近最优解,中期稳定下降,后期做细微调整。
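
为了直观感受这个公式,下面用一个小脚本手动算几个轮次对应的学习率(非原文代码,取 $\eta_{max}=2\times10^{-4}$、$\eta_{min}=10^{-6}$、$T_{max}=50$,与后文配置大致对应):

```python
import math

# 假设性示例:按余弦退火公式手动计算若干轮次的学习率
eta_max, eta_min, T_max = 2e-4, 1e-6, 50
for T_cur in [0, 12, 25, 37, 50]:
    eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(T_cur / T_max * math.pi))
    print(T_cur, f'{eta_t:.2e}')
# 输出依次约为 2.0e-4, 1.7e-4, 1.0e-4, 3.2e-5, 1.0e-6
```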

train.py中添加函数:

```python
def get_lr_scheduler(optimizer):
    if train_opts['lr_scheduler']['type'] == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=train_opts['n_epochs'],
            eta_min=1e-6
        )
    return None
```
对应地,在train()中找到定义optimizer处,添加一行语句:
```python
scheduler = get_lr_scheduler(optimizer)
```
在训练循环中作如下修改:
```python
for epoch in range(n_epochs):
    # ...existing code...
    for x, label in tqdm(dataloader, ncols=60):
        # ...existing code...
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step() if scheduler is not None else None  # added line
```
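
作为可选的补充(非原文代码),还可以顺手把当前学习率写入 TensorBoard,方便确认余弦退火曲线是否符合预期;这里假设沿用后文完整 `train.py` 中已有的 `writer`(SummaryWriter)和 `step` 计数变量:

```python
# 假设性示例:在训练循环内记录当前学习率
current_lr = optimizer.param_groups[0]['lr']
writer.add_scalar('lr/step', current_lr, step)
```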

以上就是本次的主要调整。因为重构的部分比较零碎,上面介绍的代码改动难免有疏漏之处、不易排查,所以在这里给出这部分修改完成后的完整代码。

`dataset.py`
```python
import torchvision
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

import yaml
with open('options.yml', 'r') as f:
    opt = yaml.safe_load(f)
dataset_opts = opt['dataset']


def get_dataset(batch_size):

    trainset = None
    dataset_name = dataset_opts['name']
    root = dataset_opts['root']
    if dataset_name == 'mnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        trainset = torchvision.datasets.MNIST(
            root=root,
            train=True,
            download=True,
            transform=transform
        )
    elif dataset_name == 'cifar10':
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        trainset = torchvision.datasets.CIFAR10(
            root=root,
            train=True,
            download=True,
            transform=transform
        )
    trainloader = DataLoader(trainset, batch_size, shuffle=True)
    return trainloader
```
`logger.py`
```python
import logging
import os

def get_logger(log_root, log_name='log.log'):
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    assert os.path.isdir(log_root), f"{log_root} is not a directory"
    handler = logging.FileHandler(os.path.join(log_root, log_name))
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    return logger

def dict2info(info_name: str, opt_dict: dict):
    info_str = f'{info_name}:'
    for key in opt_dict.keys():
        info_str = info_str + f'\n\t[{key}: {opt_dict[key]}]'
    return info_str
```
`modules/ddpm.py`
```python
import torch

class DDPM():
    def __init__(self,
                 device,
                 n_steps: int,
                 min_beta: float = 0.0001,
                 max_beta: float = 0.02):
        self.n_steps = n_steps
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
        self.alphas = 1 - self.betas
        self.alphas_bars = torch.empty_like(self.alphas)
        product = 1
        for i, alpha in enumerate(self.alphas):
            product *= alpha
            self.alphas_bars[i] = product

    def sample_forward(self, x, t, eps=None):
        alpha_bar = self.alphas_bars[t].reshape(-1, 1, 1, 1)
        if eps is None:
            eps = torch.randn_like(x)
        res = eps * torch.sqrt(1 - alpha_bar) + torch.sqrt(alpha_bar) * x
        return res

    def sample_backward(self, img_shape, net, device, simple_var=True, label=None):
        x = torch.randn(img_shape).to(device)
        net = net.to(device)

        if label is not None:
            if not isinstance(label, torch.Tensor):
                label = torch.tensor([label] * img_shape[0],
                                     dtype=torch.long).to(device)
            label = label.reshape(-1, 1)  # (batch_size, 1)

        for t in range(self.n_steps-1, -1, -1):
            x = self.sample_backward_step(x, t, net, simple_var, label)
        return x

    def sample_backward_step(self, x_t, t, net, simple_var=True, label=None):
        n = x_t.shape[0]  # batch size
        t_tensor = torch.tensor([t] * n, dtype=torch.long).to(x_t.device).unsqueeze(1)
        eps = net(x_t, t_tensor, label)

        if t == 0:
            noise = 0
        else:
            # simple_var 用于控制取值方式
            if simple_var:
                var = self.betas[t]
            else:
                var = (1-self.alphas_bars[t-1])/(1-self.alphas_bars[t])*self.betas[t]
            noise = torch.randn_like(x_t)
            noise *= torch.sqrt(var)

        mean = (x_t - (1-self.alphas[t])/torch.sqrt(1-self.alphas_bars[t])*eps) / \
            torch.sqrt(self.alphas[t])

        x_t = mean + noise
        return x_t  # updated image
```

`modules/networks.py`
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import yaml

with open('options.yml', 'r') as f:
    opt = yaml.safe_load(f)
img_shape = opt['dataset']['img_shape'][opt['dataset']['name']]


class PositionalEncoding(nn.Module):
    def __init__(self,
                 max_seq_len: int,
                 d_model: int,
                 n_classes: int=None):
        super().__init__()

        # Assume d_model is an even number for convenience
        assert d_model % 2 == 0

        # time step encoding
        pe = torch.zeros(max_seq_len, d_model)
        i_seq = torch.linspace(0, max_seq_len - 1, max_seq_len)
        j_seq = torch.linspace(0, d_model - 2, d_model // 2)
        pos, two_i = torch.meshgrid(i_seq, j_seq)
        pe_2i = torch.sin(pos / 1e4 ** (two_i / d_model))
        pe_2i_1 = torch.cos(pos / 1e4 ** (two_i / d_model))
        pe = torch.stack((pe_2i, pe_2i_1), 2).reshape(max_seq_len, d_model)

        self.t_embedding = nn.Embedding(max_seq_len, d_model)
        self.t_embedding.weight.data = pe
        self.t_embedding.requires_grad_(False)

        # label encoding
        self.use_condition = n_classes is not None
        if self.use_condition:
            self.label_embedding = nn.Embedding(n_classes, d_model)

    def forward(self, t, label=None):
        t_emb = self.t_embedding(t)
        if self.use_condition and label is not None:
            label_emb = self.label_embedding(label)
            return t_emb + label_emb
        return t_emb


class SelfAttention(nn.Module):
    def __init__(self, channels: int, num_heads: int):
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x: torch.Tensor):
        size = x.shape[-2:]
        x = x.view(x.shape[0], self.channels, -1).transpose(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.transpose(1, 2).view(x.shape[0], self.channels, *size)


class UNetBlock(nn.Module):
    def __init__(self,
                 shape,
                 in_c,
                 out_c,
                 residual=False):
        super().__init__()
        self.ln = nn.LayerNorm(shape)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.activation = nn.ReLU()
        self.residual = residual
        if residual:
            if in_c == out_c:
                self.residual_conv = nn.Identity()
            else:
                self.residual_conv = nn.Conv2d(in_c, out_c, kernel_size=1)

    def forward(self, x):
        out = self.activation(self.conv1(self.ln(x)))
        out = self.conv2(out)
        if self.residual:
            out += self.residual_conv(x)
        out = self.activation(out)
        return out


class UNet(nn.Module):
    def __init__(self,
                 n_steps: int,
                 channels: list=[10, 20, 40, 80],
                 pe_dim: int=10,
                 residual: bool=False,
                 n_classes: int=10,
                 use_attention: bool = True):
        super().__init__()
        C, H, W = img_shape[0], img_shape[1], img_shape[2]
        n_layers = len(channels)
        Hs = [H]
        Ws = [W]
        cH = H
        cW = W
        for _ in range(n_layers - 1):
            cH //= 2
            cW //= 2
            Hs.append(cH)
            Ws.append(cW)

        self.pe = PositionalEncoding(n_steps, pe_dim, n_classes)
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.pe_linears_en = nn.ModuleList()
        self.pe_linears_de = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        prev_channel = C

        # down blocks
        for channel, cH, cW in zip(channels[0:-1], Hs[0:-1], Ws[0:-1]):
            self.pe_linears_en.append(
                nn.Sequential(nn.Linear(pe_dim, prev_channel),
                              nn.ReLU(),
                              nn.Linear(prev_channel, prev_channel))
            )
            self.encoders.append(
                nn.Sequential(
                    UNetBlock((prev_channel, cH, cW),
                              prev_channel,
                              channel,
                              residual=residual),
                    UNetBlock((channel, cH, cW),
                              channel,
                              channel,
                              residual=residual)
                )
            )
            self.downs.append(nn.Conv2d(channel, channel, kernel_size=2, stride=2))
            prev_channel = channel

        # mid block
        self.pe_mid = nn.Linear(pe_dim, prev_channel)
        channel = channels[-1]
        if not use_attention:
            self.mid = nn.Sequential(
                UNetBlock((prev_channel, Hs[-1], Ws[-1]),
                          prev_channel,
                          channel,
                          residual=residual),
                UNetBlock((channel, Hs[-1], Ws[-1]),
                          channel,
                          channel,
                          residual=residual)
            )
        else:
            self.mid = nn.Sequential(
                UNetBlock((prev_channel, Hs[-1], Ws[-1]),
                          prev_channel,
                          channel,
                          residual=residual),
                SelfAttention(channel, 4),
                UNetBlock((channel, Hs[-1], Ws[-1]),
                          channel,
                          channel,
                          residual=residual)
            )
        prev_channel = channel

        # up blocks
        for channel, cH, cW in zip(channels[-2::-1], Hs[-2::-1], Ws[-2::-1]):
            self.pe_linears_de.append(nn.Linear(pe_dim, prev_channel))
            self.decoders.append(
                nn.Sequential(
                    UNetBlock((channel * 2, cH, cW),
                              channel * 2,
                              channel,
                              residual=residual),
                    UNetBlock((channel, cH, cW),
                              channel,
                              channel,
                              residual=residual)
                )
            )
            self.ups.append(nn.ConvTranspose2d(prev_channel, channel, kernel_size=2, stride=2))
            prev_channel = channel

        self.conv_out = nn.Conv2d(prev_channel, C, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t, label=None):
        n = t.shape[0]
        t = self.pe(t, label)
        encoder_outs = []
        for pe_linear, encoder, down in zip(self.pe_linears_en, self.encoders, self.downs):
            pe = pe_linear(t).reshape(n, -1, 1, 1)
            x = encoder(x + pe)
            encoder_outs.append(x)
            x = down(x)
        pe = self.pe_mid(t).reshape(n, -1, 1, 1)
        x = self.mid(x + pe)
        for pe_linear, decoder, up, encoder_out in zip(self.pe_linears_de, self.decoders,
                                                       self.ups, encoder_outs[::-1]):
            pe = pe_linear(t).reshape(n, -1, 1, 1)
            x = up(x)
            pad_x = encoder_out.shape[2] - x.shape[2]
            pad_y = encoder_out.shape[3] - x.shape[3]
            x = F.pad(x,
                      (pad_x // 2, pad_x - pad_x//2, pad_y // 2, pad_y - pad_y//2))
            x = torch.cat((encoder_out, x), dim=1)
            x = decoder(x + pe)
        x = self.conv_out(x)
        return x


def build_network(n_steps: int,
                  channels: list,
                  pe_dim: int=None,
                  residual: bool=True,
                  n_classes: int=10,
                  use_attention: bool=True):
    return UNet(n_steps, channels, pe_dim, residual, n_classes, use_attention)
```

`train.py`
```python
import torch
import torch.nn as nn
from dataset import get_dataset
from modules.networks import build_network
from modules.ddpm import DDPM
from tqdm import tqdm

from tensorboardX import SummaryWriter
import os
import yaml
from datetime import datetime
from logger import get_logger, dict2info

## initialization
with open('options.yml', 'r') as f:
    opt = yaml.safe_load(f)
f.close()
curr_time = datetime.now().strftime("%Y%m%d-%H%M%S")
save_root = f"models/{curr_time}-{opt['dataset']['name']}"
os.makedirs(os.path.join(save_root, 'ckpts'))
logger = get_logger(save_root, 'train.log')
train_opts = opt['train']
writer = SummaryWriter()


def get_lr_scheduler(optimizer):
    if train_opts['lr_scheduler']['type'] == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=train_opts['n_epochs'],
            eta_min=1e-6
        )
    return None

def train(ddpm: DDPM, dataloader, net: nn.Module, device):
    n_epochs = train_opts['n_epochs']
    n_steps = ddpm.n_steps
    net = net.to(device)
    if train_opts['loss'] == 'MSE':
        loss_fn = nn.MSELoss()
        # TODO: more loss functions
    else:
        raise ValueError(f"Unknown loss function: {train_opts['loss']}")
    optimizer = torch.optim.Adam(net.parameters(), float(train_opts['lr']))
    scheduler = get_lr_scheduler(optimizer)

    logger.info("Start training.")
    step = 0
    resume_epochs_passed = 0

    ckpt_path: str = train_opts['resume']
    if ckpt_path is not None:
        if os.path.exists(ckpt_path):
            net.load_state_dict(torch.load(ckpt_path))
            print(f'Load model from {ckpt_path}.')
            logger.info(f'Load model from {ckpt_path}.')
            resume_epochs_passed = int(ckpt_path.split('/')[-1].split('.')[0].split('_')[-1]) + 1
            step += resume_epochs_passed * len(dataloader)

    for epoch in range(n_epochs):
        total_loss = 0
        truth_epoch = epoch + resume_epochs_passed
        for x, label in tqdm(dataloader, ncols=60):
            batch_size = x.shape[0]
            x, label = x.to(device), label.to(device)

            t = torch.randint(0, n_steps, (batch_size, )).to(device)
            eps = torch.randn_like(x).to(device)
            x_t = ddpm.sample_forward(x, t, eps)
            eps_theta = net(x_t, t.reshape(batch_size, 1), label.reshape(batch_size, 1))
            loss = loss_fn(eps_theta, eps)
            writer.add_scalar('loss/step', loss, step)
            total_loss += loss
            step += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step() if scheduler is not None else None

        epoch_loss = total_loss / len(dataloader)
        logger.info(f"epoch {truth_epoch} | loss {epoch_loss}")
        writer.add_scalar('loss/epoch', epoch_loss, truth_epoch)
        print(f'epoch {truth_epoch} | loss {epoch_loss}')
        save_path = os.path.join(save_root, 'ckpts', f'epoch_{truth_epoch}.pth')
        torch.save(net.state_dict(), save_path)
        print(f"Model checkpoint has been saved into {save_path}.")

    logger.info("Training stage finished.")

if __name__ == '__main__':
    # get dataloader
    dataloader = get_dataset(batch_size=train_opts['batch_size'])
    logger.info(f"dataset: {opt['dataset']['name']}")

    # set device
    device = opt['device']
    logger.info(f'device: {device}')

    # DDPM settings
    ddpm_opts = opt['ddpm']
    ddpm_info = dict2info("DDPM", dict(ddpm_opts))
    logger.info(ddpm_info)
    ddpm = DDPM(device,
                ddpm_opts['n_steps'],
                float(ddpm_opts['min_beta']),
                float(ddpm_opts['max_beta'])
                )

    # network settings
    network_opts = opt['network']
    net_info = dict2info("network", dict(network_opts))
    logger.info(net_info)
    net = build_network(n_steps=ddpm_opts['n_steps'],
                        channels=network_opts['channels'],
                        pe_dim=network_opts['pe_dim'],
                        residual=network_opts['residual'],
                        n_classes=network_opts['n_classes'],
                        use_attention=network_opts['attention'])

    # start training
    train_info = dict2info("training options", dict(train_opts))
    logger.info(train_info)
    train(ddpm, dataloader, net, device)
```

`test.py`
```python
import torch
import torch.nn as nn
import cv2
import einops
import numpy as np
import yaml
import os
from tqdm import tqdm

from modules.ddpm import DDPM
from modules.networks import build_network
from logger import get_logger, dict2info

from datetime import datetime
curr_time = datetime.now().strftime("%Y%m%d-%H%M%S")

with open('options.yml', 'r') as f:
    opt = yaml.safe_load(f)
f.close()
test_opts = opt['test']


def generate(ddpm: DDPM,
             net: nn.Module,
             output_path: str,
             n_sample: int,
             device,
             img_shape,
             simple_var=True,
             label=None):
    net = net.to(device)
    net = net.eval()
    C, H, W = img_shape[0], img_shape[1], img_shape[2]

    with torch.no_grad():
        shape = (n_sample, C, H, W)
        imgs = ddpm.sample_backward(shape,
                                    net,
                                    device,
                                    simple_var,
                                    label).detach().to(device)
        imgs = (imgs + 1) / 2 * 255
        imgs = imgs.clamp(0, 255)
        imgs = einops.rearrange(imgs,
                                '(b1 b2) c h w -> (b1 h) (b2 w) c',
                                b1=int(n_sample**0.5))
        imgs = imgs.cpu()
        imgs = imgs.numpy().astype(np.uint8)
        cv2.imwrite(output_path, imgs)


if __name__ == '__main__':
    # DDPM settings
    ddpm_opts = opt['ddpm']
    ddpm = DDPM(opt['device'],
                int(ddpm_opts['n_steps']),
                float(ddpm_opts['min_beta']),
                float(ddpm_opts['max_beta']))

    # network settings
    network_opts = opt['network']
    network_opts_dict = dict(network_opts)
    net = build_network(n_steps=ddpm_opts['n_steps'],
                        channels=network_opts['channels'],
                        pe_dim=network_opts['pe_dim'],
                        residual=network_opts['residual'])
    # load network checkpoint
    ckpt_path = test_opts['ckpt_path']
    assert os.path.exists(ckpt_path), f'{ckpt_path} is an invalid path.'
    net.load_state_dict(torch.load(ckpt_path))

    # output path settings
    output_dir = os.path.join(test_opts['output_dir'], f'{curr_time}')
    os.makedirs(output_dir) if not os.path.exists(output_dir) else None

    # logger settings
    logger = get_logger(output_dir, 'test.log')
    logger.info(dict2info("DDPM", dict(ddpm_opts)))
    logger.info(dict2info("network", dict(network_opts)))
    logger.info(dict2info("test options", dict(test_opts)))

    # other settings
    n_sample = test_opts['n_samples']
    device = opt['device']
    img_shape = opt['dataset']['img_shape'][opt['dataset']['name']]

    # sample images from noise
    print(f'Output Directory: {output_dir}')
    print(f"Classes Tested: {test_opts['classes']}")

    for label in tqdm(test_opts['classes'], ncols=60):
        output_path = os.path.join(output_dir, f'{label}.png')
        generate(ddpm=ddpm,
                 net=net,
                 output_path=output_path,
                 n_sample=n_sample,
                 device=device,
                 img_shape=img_shape,
                 simple_var=True,
                 label=label)
        logger.info(f"Save result to {output_path}")
```
`options.yml`
```yaml
device: 'cuda:1'

dataset:
  name: 'cifar10'
  root: './cache'
  img_shape:
    mnist: [1, 28, 28]
    cifar10: [3, 32, 32]

train:
  n_epochs: 50
  batch_size: 64
  loss: 'MSE'
  lr: 2e-4
  lr_scheduler:
    type: 'cosine'
    warmup_epochs: 5
  resume: ~

test:
  output_dir: './results'
  ckpt_path: ~
  n_samples: 81
  classes: [0,1,2,3,4,5,6,7,8,9]

network:
  channels: [64, 128, 256, 512, 1024]
  pe_dim: 256
  residual: true
  attention: true
  n_classes: 10

ddpm:
  n_steps: 1000
  min_beta: 1e-4
  max_beta: 2e-2
```

结果展示与分析

首先看看改进后的模型在MNIST数据集上的生成结果(部分):

训练于MNIST数据集的模型在标签0,6,9条件下的生成结果

可以看到,这种标签特征融合的方式还是非常有效果的,模型针对不同类别特征顺利地生成了明显可辨识的阿拉伯数字,准确率还算不错。接下来看看CIFAR-10数据集的生成结果(部分):

训练于CIFAR-10数据集的模型在标签0(飞机),6(青蛙),9(卡车)条件下的生成结果

这个结果是模型在训练20个周期后生成的,之所以图片有明显的黑框是因为处理该数据集时使用了transforms.RandomCrop(32, padding=4)。与上次的色块不同,这次能明显看出来模型已经在试图生成具有现实语义的图案了,但对应的类别特征还十分不明显,基本无法分辨。如果我们尝试继续训练一段时间然后看看效果呢?

过拟合(疑似?)的训练于CIFAR-10上的模型

很不幸,我们并不能通过简单延长训练来提升模型效果。模型对噪声的预测逐渐发生了偏移,因此集中产生了纯黑或纯白的图像(这两种情况不会同时出现在同一个模型中:当某个模型的预测发生较大偏移时,它要么集中生成纯黑图像,要么集中生成纯白图像)。为什么会产生这种现象?

我的分析是:网络结构还不够强大,现有架构缺乏更妥善的注意力机制、更深的网络层次以及针对RGB图像的特殊处理;对数据集进行适当的增强和准确的归一化也很重要(我当时简单地将数据集标准化的均值和标准差都设为0.5,但实际上应根据数据集本身的统计属性来设定)。
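
关于"根据数据集本身的属性设定归一化参数",下面给出一个计算真实统计量的小示例(非原文代码,数据集路径 './cache' 沿用前文 `dataset.py` 的设定,仅作示意),可以把得到的结果填入 `transforms.Normalize`,代替简单的 0.5:

```python
import torch
import torchvision
from torchvision.transforms import transforms

# 假设性示例:统计 CIFAR-10 训练集每个通道的均值与标准差
dataset = torchvision.datasets.CIFAR10(root='./cache', train=True, download=True,
                                       transform=transforms.ToTensor())
data = torch.stack([img for img, _ in dataset])   # (50000, 3, 32, 32)
mean = data.mean(dim=(0, 2, 3))                   # 约为 [0.4914, 0.4822, 0.4465]
std = data.std(dim=(0, 2, 3))                     # 约为 [0.2470, 0.2435, 0.2616]
print(mean, std)
```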

下一步的改进计划是:

  1. 改良UNet网络结构

  2. 引入EMA方法

  3. 改进损失函数的衡量

事实上,由于该项目是我用来熟悉项目构建流程和相关方法的,很多改进策略并没有相关论文材料支撑,而是主要来自我的"一拍脑袋"和LLM的建议。因此这些改进并非都有效,实验过程中也确实出现过负优化的情况(模型复杂度增加了但效果并没有变好),所以我的方法仅供参考!该系列博客只是记录我的该demo构建历程,因此我将其放在【学习笔记】的类别而非【经验分享】。希望我的思考能够帮到读者!

该系列未完待续!

  • Title: 深度学习基础:基于DDPM模型的生成模型demo笔记(二)
  • Author: Jachin Zhang
  • Created at : 2025-03-06 22:26:07
  • Updated at : 2025-03-07 22:07:17
  • Link: https://jachinzhang1.github.io/2025/03/06/ddpm-project-2/
  • License: This work is licensed under CC BY-NC-SA 4.0.