Transformer Code Walkthrough, Part 2: Model Training (PyTorch Implementation)
1. Model Initialization Module
Reference: project code
1.1 Parameter Counting Function
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
Technical notes:
- numel() returns the total number of elements in a tensor
- the requires_grad check keeps only the parameters that receive gradient updates
- the resulting count is a rough measure of model complexity; a typical Transformer-base has about 65M parameters (a quick sanity check follows)
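As a quick sanity check (a toy example, not part of the project code), the counter can be verified on a single nn.Linear layer, whose parameter count is in_features * out_features + out_features:

import torch.nn as nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

layer = nn.Linear(512, 2048)        # weight: 2048 x 512, bias: 2048
print(count_parameters(layer))      # 1050624 = 2048 * 512 + 2048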
1.2 Weight Initialization
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform_(m.weight.data)
How the initialization works:
- Kaiming initialization is designed for the ReLU family of activations
- it keeps the activation variance roughly constant across layers in the forward pass
- with PyTorch's default arguments, the weights are drawn as $W \sim U\left(-\sqrt{6/n_{in}},\ \sqrt{6/n_{in}}\right)$ (a small sketch follows)
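A minimal sketch (illustrative, not from the project) of applying the initializer with Module.apply and checking the resulting uniform bound on one layer:

import math
import torch.nn as nn

def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform_(m.weight.data)

layer = nn.Linear(512, 512)
layer.apply(initialize_weights)                  # apply() visits the module and all submodules

bound = math.sqrt(6 / 512)                       # sqrt(6 / n_in) with kaiming_uniform_'s defaults
print(layer.weight.abs().max().item() <= bound)  # True: all weights lie within the bound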
1.3 Model Instantiation
model = Transformer(
src_pad_idx=src_pad_idx,
trg_pad_idx=trg_pad_idx,
trg_sos_idx=trg_sos_idx,
d_model=d_model,
enc_voc_size=enc_voc_size,
dec_voc_size=dec_voc_size,
max_len=max_len,
ffn_hidden=ffn_hidden,
n_head=n_heads,
n_layers=n_layers,
drop_prob=drop_prob,
device=device).to(device)
Key hyperparameters:

| Parameter | Typical Value | Role |
|---|---|---|
| d_model | 512 | Dimension of token representations |
| n_head | 8 | Number of attention heads |
| ffn_hidden | 2048 | Hidden size of the feed-forward network |
| n_layers | 6 | Number of stacked encoder/decoder layers |
| drop_prob | 0.1 | Dropout probability |
2. Training Setup Module
2.1 Optimizer Configuration
optimizer = Adam(
params=model.parameters(),
lr=init_lr,
weight_decay=weight_decay,
eps=adam_eps)
Adam update rule:

$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}\,\hat{m}_t$$

where $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected first- and second-moment estimates of the gradient. A self-contained sketch of a single step follows.
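The update can be written out directly; the following is a didactic reimplementation of one Adam step following the formula above (the hyperparameter defaults here are illustrative, not the project's init_lr / adam_eps values):

import torch

def adam_step(param, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    # exponential moving averages of the gradient and its square
    m = betas[0] * m + (1 - betas[0]) * grad
    v = betas[1] * v + (1 - betas[1]) * grad ** 2
    # bias correction
    m_hat = m / (1 - betas[0] ** t)
    v_hat = v / (1 - betas[1] ** t)
    # parameter update
    param = param - lr * m_hat / (v_hat.sqrt() + eps)
    return param, m, v

w, m, v = torch.zeros(3), torch.zeros(3), torch.zeros(3)
w, m, v = adam_step(w, torch.ones(3), m, v, t=1)
print(w)   # each entry moves by roughly -lr on the first step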
2.2 Learning Rate Scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer=optimizer,
verbose=True,
factor=factor,
patience=patience)
Scheduling strategy:
- the scheduler monitors the validation loss
- when the loss plateaus, the learning rate is multiplied by factor (typically 0.5)
- patience=5 means the decay triggers only after 5 consecutive epochs without improvement (a minimal usage sketch follows)
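A minimal standalone sketch (dummy optimizer, made-up loss values) showing how ReduceLROnPlateau reacts to a stalled validation loss:

import torch
from torch import nn, optim

model = nn.Linear(4, 4)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

for epoch, val_loss in enumerate([3.0, 2.5, 2.4, 2.4, 2.4, 2.4, 2.4, 2.4, 2.4]):
    scheduler.step(val_loss)                       # pass the monitored metric explicitly
    print(epoch, optimizer.param_groups[0]['lr'])  # lr is halved once the loss stalls for more than `patience` epochs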
2.3 Loss Function
criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
Padding handling:
- ignore_index removes padding positions from both the loss and the gradient computation
- the loss can be written as $\mathcal{L} = -\sum_{i=1}^{n} y_i \log p_i \cdot \mathbb{I}(y_i \neq \text{pad})$
- note that the loss is computed over target tokens, so the ignored index must correspond to the target-side pad token; passing src_pad_idx is only equivalent when the source and target vocabularies share the same pad index (a tiny sketch follows)
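A tiny sketch (made-up values) showing that positions equal to ignore_index contribute nothing to the averaged loss:

import torch
from torch import nn

pad_idx = 1
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

logits = torch.randn(4, 10)                           # 4 target positions, vocabulary of size 10
targets = torch.tensor([5, 2, pad_idx, pad_idx])

loss_all = criterion(logits, targets)                 # averaged over the 2 non-pad positions only
loss_non_pad = nn.CrossEntropyLoss()(logits[:2], targets[:2])
print(torch.allclose(loss_all, loss_non_pad))         # True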
3. Training and Evaluation Module
3.1 Training Function
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg
        optimizer.zero_grad()
        output = model(src, trg[:, :-1])
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)
        loss = criterion(output_reshape, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())
    return epoch_loss / len(iterator)
Key techniques:
- teacher forcing: the ground-truth target sequence, rather than the model's own predictions, is fed to the decoder
- sequence slicing: trg[:, :-1] drops the last token to form the decoder input, while trg[:, 1:] drops the start token to form the prediction target, so the two are offset by one position (see the sketch after this list)
- gradient clipping bounds the gradient norm and prevents exploding gradients
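The one-token offset between decoder input and loss target can be seen on a toy batch (the token ids, including <sos>=2 and <eos>=3, are made up for illustration and are not the project's actual indices):

import torch

# one target sentence: <sos> w1 w2 w3 <eos>
trg = torch.tensor([[2, 11, 12, 13, 3]])

decoder_input = trg[:, :-1]   # [<sos>, w1, w2, w3]  -> fed to the decoder (teacher forcing)
loss_target   = trg[:, 1:]    # [w1, w2, w3, <eos>]  -> what the predictions are scored against
print(decoder_input, loss_target)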
3.2 Evaluation Function
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)
            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()
            total_bleu = []
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except:
                    pass
            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)
    batch_bleu = sum(batch_bleu) / len(batch_bleu)
    return epoch_loss / len(iterator), batch_bleu
How BLEU is computed:

$$\mathrm{BLEU} = BP \cdot \exp\left(\sum_{n=1}^{N} w_n \log p_n\right)$$

where:
- $BP$ is the brevity penalty
- $p_n$ is the n-gram precision
- $w_n$ is the weight of each order, usually uniform ($w_n = 1/N$); a small sketch of the formula follows
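The formula can be evaluated directly from the n-gram precisions; the sketch below implements the equation above with uniform weights (it is not the project's get_bleu implementation, and the precision values are made up):

import math

def bleu_from_precisions(precisions, hyp_len, ref_len):
    # BLEU = BP * exp(sum_n w_n * log p_n) with uniform weights w_n = 1/N
    if min(precisions) == 0:
        return 0.0
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)  # brevity penalty
    n = len(precisions)
    return bp * math.exp(sum(math.log(p) / n for p in precisions))

# 1- to 4-gram precisions of a 20-token hypothesis against a 22-token reference
print(bleu_from_precisions([0.8, 0.6, 0.4, 0.3], hyp_len=20, ref_len=22))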
4. Run Control Module
4.1 Training Loop
def run(total_epoch, best_loss):
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        end_time = time.time()
        if step > warmup:
            scheduler.step(valid_loss)
        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))
        f = open('result/train_loss.txt', 'w')
        f.write(str(train_losses))
        f.close()
        f = open('result/bleu.txt', 'w')
        f.write(str(bleus))
        f.close()
        f = open('result/test_loss.txt', 'w')
        f.write(str(test_losses))
        f.close()
        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} | Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')
Model checkpointing strategy:
- the validation loss is used as the saving criterion
- model.state_dict() stores a snapshot of the parameters only, not the full module
- the file name embeds the validation loss, which makes it easy to track versions (a loading sketch follows)
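To evaluate or resume from a saved checkpoint, the snapshot is loaded back into an identically constructed model; the file name below is a placeholder for whatever run() actually saved, and model / device are the objects defined earlier in this post:

checkpoint_path = 'saved/model-2.345.pt'                      # placeholder file name
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()                                                  # switch to inference mode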
5. Engineering Practice Notes
5.1 Training Tips
- Warm-up: learning-rate decay is not applied during the first warmup epochs
- mixed-precision training: torch.cuda.amp can be used to speed up training
- gradient accumulation: gradients from several small batches are accumulated to emulate a large batch (a combined sketch follows this list)
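A hedged sketch combining the last two points on top of the train() loop above: mixed precision via torch.cuda.amp plus gradient accumulation (accumulation_steps is a hypothetical hyperparameter, and batch.src / batch.trg follow the iterator format used elsewhere in this post):

import torch

scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 4                                     # hypothetical: emulate a 4x larger batch

def train_amp(model, iterator, optimizer, criterion, clip):
    model.train()
    optimizer.zero_grad()
    for i, batch in enumerate(iterator):
        src, trg = batch.src, batch.trg
        with torch.cuda.amp.autocast():                    # forward pass in mixed precision
            output = model(src, trg[:, :-1])
            loss = criterion(output.contiguous().view(-1, output.shape[-1]),
                             trg[:, 1:].contiguous().view(-1))
        scaler.scale(loss / accumulation_steps).backward() # accumulate scaled gradients
        if (i + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)                     # unscale before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()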
5.2 Performance Optimization
torch.backends.cudnn.benchmark = True        # let cuDNN benchmark and cache the fastest kernels for fixed-size inputs
torch.autograd.set_detect_anomaly(False)     # keep anomaly detection off to avoid its debugging overhead
5.3 Extension Example
Multi-GPU example: wrapping the encoder and decoder with nn.DataParallel (note that this is data parallelism over the batch, not model parallelism):
class ParallelTransformer(Transformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = nn.DataParallel(self.encoder)
        self.decoder = nn.DataParallel(self.decoder)
This section walked through the training pipeline from the code down to the underlying theory, keeping the original code structure intact while explaining how each module works. In practice the hyperparameters should be scaled to the task; for large-scale pretraining, a multi-GPU setup such as 8×V100 combined with mixed-precision training is recommended to improve training efficiency.
Appendix: full source
"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""
import math
import time
import torch  # used below via torch.no_grad, torch.save and torch.nn.utils.clip_grad_norm_
from torch import nn, optim
from torch.optim import Adam
from data import *
from models.model.transformer import Transformer
from util.bleu import idx_to_word, get_bleu
from util.epoch_timer import epoch_time
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform_(m.weight.data)
model = Transformer(src_pad_idx=src_pad_idx,
trg_pad_idx=trg_pad_idx,
trg_sos_idx=trg_sos_idx,
d_model=d_model,
enc_voc_size=enc_voc_size,
dec_voc_size=dec_voc_size,
max_len=max_len,
ffn_hidden=ffn_hidden,
n_head=n_heads,
n_layers=n_layers,
drop_prob=drop_prob,
device=device).to(device)
print(f'The model has {count_parameters(model):,} trainable parameters')
model.apply(initialize_weights)
optimizer = Adam(params=model.parameters(),
lr=init_lr,
weight_decay=weight_decay,
eps=adam_eps)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
verbose=True,
factor=factor,
patience=patience)
criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()
        output = model(src, trg[:, :-1])           # teacher forcing: decoder input is the right-shifted target
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)     # predictions are scored against the left-shifted target
        loss = criterion(output_reshape, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())

    return epoch_loss / len(iterator)
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)
            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]      # greedy decoding of the j-th sentence
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except:
                    pass

            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)

    batch_bleu = sum(batch_bleu) / len(batch_bleu)
    return epoch_loss / len(iterator), batch_bleu
def run(total_epoch, best_loss):
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        end_time = time.time()

        if step > warmup:
            scheduler.step(valid_loss)

        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))

        f = open('result/train_loss.txt', 'w')
        f.write(str(train_losses))
        f.close()

        f = open('result/bleu.txt', 'w')
        f.write(str(bleus))
        f.close()

        f = open('result/test_loss.txt', 'w')
        f.write(str(test_losses))
        f.close()

        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} | Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')
if __name__ == '__main__':
    run(total_epoch=epoch, best_loss=inf)