
Transformer Code Walkthrough 2 - Model Training (PyTorch Implementation)

1. Model Initialization Module

Reference: project source code

1.1 Parameter Counting Function

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

Flow: iterate over model parameters → keep only trainable parameters → count each tensor's elements → return the total

Technical notes:

  • numel() returns the number of elements in a tensor
  • requires_grad filters for parameters that actually receive gradient updates
  • The resulting count reflects model complexity; a typical Transformer-base has roughly 65M parameters (see the usage sketch below)
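
A quick usage sketch (assuming the model from section 1.3 below has already been instantiated):

total = sum(p.numel() for p in model.parameters())               # all parameters
print(f'total: {total:,} | trainable: {count_parameters(model):,}')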

1.2 Weight Initialization

def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform_(m.weight.data)

Flow: inspect each module → does it expose a multi-dimensional weight? → if so, apply Kaiming initialization; otherwise skip it

Initialization rationale:

  • Kaiming initialization is designed for ReLU-family activations
  • It keeps the activation variance stable during the forward pass
  • Formula: $W \sim U\left(-\sqrt{6/n_{in}},\ \sqrt{6/n_{in}}\right)$ (verified in the sketch below)
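
Once the model in section 1.3 is built, the hook is applied recursively with model.apply(initialize_weights), as in the appendix source. A minimal check of the uniform bound on a standalone tensor, assuming the PyTorch defaults:

import math
import torch

w = torch.empty(512, 512)
torch.nn.init.kaiming_uniform_(w)       # defaults: a=0, mode='fan_in', nonlinearity='leaky_relu'
bound = math.sqrt(6.0 / w.shape[1])     # sqrt(6 / n_in), with n_in = fan_in = 512
assert w.abs().max().item() <= bound    # every sample lies inside U(-bound, bound)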

1.3 Model Instantiation

model = Transformer(
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    trg_sos_idx=trg_sos_idx,
    d_model=d_model,
    enc_voc_size=enc_voc_size,
    dec_voc_size=dec_voc_size,
    max_len=max_len,
    ffn_hidden=ffn_hidden,
    n_head=n_heads,
    n_layers=n_layers,
    drop_prob=drop_prob,
    device=device).to(device)

Key hyperparameters (an illustrative config sketch follows the table):

| Parameter | Typical value | Role |
| --- | --- | --- |
| d_model | 512 | Token representation (embedding) dimension |
| n_head | 8 | Number of attention heads |
| ffn_hidden | 2048 | Feed-forward hidden dimension |
| n_layers | 6 | Number of stacked encoder/decoder layers |
| drop_prob | 0.1 | Dropout probability |
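
The table lists typical Transformer-base values. The snippets in this post import their hyperparameters from a shared config module (conf.py in the reference repo); a purely illustrative stand-in might look like this, with the exact values treated as assumptions:

# illustrative configuration values, not necessarily those of the reference repo
d_model = 512
n_heads = 8
ffn_hidden = 2048
n_layers = 6
drop_prob = 0.1
max_len = 256          # hypothetical maximum sequence length
init_lr = 1e-5         # hypothetical initial learning rate
weight_decay = 5e-4    # hypothetical L2 penalty
adam_eps = 5e-9        # hypothetical numerical-stability epsilon
clip = 1.0             # hypothetical gradient-clipping norm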

2. Training Preparation Module

2.1 Optimizer Configuration

optimizer = Adam(
    params=model.parameters(),
    lr=init_lr,
    weight_decay=weight_decay,
    eps=adam_eps)

Adam update rule:
$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon}\,\hat{m}_t$
where $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected first- and second-moment estimates (a step-by-step sketch follows).
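
A minimal sketch of a single Adam update for one parameter tensor, following the formula above (the weight_decay term passed to the real optimizer is omitted; the beta values are the PyTorch defaults):

import torch

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * grad            # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2       # second-moment (uncentered variance) estimate
    m_hat = m / (1 - beta1 ** t)                  # bias corrections for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v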

2.2 Learning-Rate Scheduler

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    verbose=True,
    factor=factor,
    patience=patience)

Scheduling policy (usage shown below):

  • Monitors the validation loss
  • When the loss plateaus, the learning rate is multiplied by factor (typically 0.5)
  • patience=5 would mean that five consecutive evaluations without improvement trigger a decay
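
Usage sketch: unlike step-based schedulers, ReduceLROnPlateau is driven by a metric, so it is stepped with the validation loss (exactly what run() does below), and the effective learning rate can be read back from the optimizer:

scheduler.step(valid_loss)                 # assumes valid_loss holds the latest validation loss
print(optimizer.param_groups[0]['lr'])     # inspect the (possibly decayed) learning rate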

2.3 Loss Function

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)

Padding handling:

  • ignore_index excludes padded positions from the loss and therefore from the gradient computation (the snippet above passes src_pad_idx; conceptually it is the padding index of the target vocabulary that must be ignored here)
  • With the padding mask made explicit, the loss can be written as:
    $\mathcal{L} = -\sum_{i=1}^{n} y_i \log p_i \cdot \mathbb{I}(y_i \neq \text{pad})$
    (a toy illustration follows)
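
A toy illustration of the masking effect of ignore_index (padding index assumed to be 1, vocabulary size 5):

import torch
from torch import nn

logits = torch.randn(4, 5)                 # 4 token positions, 5-word vocabulary
targets = torch.tensor([2, 3, 1, 1])       # the last two positions are padding
crit = nn.CrossEntropyLoss(ignore_index=1)
print(crit(logits, targets))               # averaged over the two non-pad positions only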

3. Training and Evaluation Module

3.1 Training Function

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()
        output = model(src, trg[:, :-1])
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output_reshape, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())

    return epoch_loss / len(iterator) 

Flow: set training mode → iterate over batches → zero gradients → forward pass → reshape tensors → compute loss → backpropagate → clip gradients → update parameters → accumulate the loss

Key techniques:

  1. Teacher forcing: the ground-truth target sequence is fed to the decoder
  2. Sequence slicing: trg[:, :-1] drops the final token to form the decoder input, while trg[:, 1:] drops the leading start token to form the prediction target (see the sketch after this list)
  3. Gradient clipping guards against exploding gradients
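
A concrete view of the shift-by-one slicing used for teacher forcing (hypothetical token indices: 2 = <sos>, 3 = <eos>):

import torch

trg = torch.tensor([[2, 11, 12, 13, 3]])   # <sos> w1 w2 w3 <eos>
decoder_input = trg[:, :-1]                # [[2, 11, 12, 13]] -> fed to the decoder
target = trg[:, 1:]                        # [[11, 12, 13, 3]] -> compared with the model output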

3.2 Evaluation Function

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except:
                    pass

            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)

    batch_bleu = sum(batch_bleu) / len(batch_bleu)
    return epoch_loss / len(iterator), batch_bleu

Flow: set evaluation mode → disable gradient tracking → iterate over batches → run the model → compute loss → decode output text → compute BLEU → aggregate the metrics

BLEU computation (a numeric sketch follows):
$BLEU = BP \cdot \exp\left(\sum_{n=1}^{N} w_n \log p_n\right)$
where:

  • BP is the brevity penalty
  • $p_n$ is the n-gram precision
  • $w_n$ are the per-order weights (usually uniform)
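
A minimal numeric sketch of the formula (illustrative precisions and lengths, not values produced by the repo's get_bleu):

import math

p = [0.7, 0.5, 0.3, 0.2]                    # 1- to 4-gram precisions
w = [0.25] * 4                              # uniform weights
c, r = 18, 20                               # candidate and reference lengths
bp = 1.0 if c > r else math.exp(1 - r / c)  # brevity penalty
bleu = bp * math.exp(sum(wi * math.log(pi) for wi, pi in zip(w, p)))
print(round(bleu, 4))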

4. Run Control Module

4.1 Training Loop

def run(total_epoch, best_loss):
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        end_time = time.time()

        if step > warmup:
            scheduler.step(valid_loss)

        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))

        f = open('result/train_loss.txt', 'w')
        f.write(str(train_losses))
        f.close()

        f = open('result/bleu.txt', 'w')
        f.write(str(bleus))
        f.close()

        f = open('result/test_loss.txt', 'w')
        f.write(str(test_losses))
        f.close()

        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} |  Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')

Flow: initialize the metric lists → iterate over epochs → start the timer → train → validate → adjust the learning rate → save the best model → record metrics → print progress

Checkpointing strategy (a loading sketch follows):

  • The validation loss is the saving criterion
  • model.state_dict() stores a snapshot of the parameters only
  • Embedding the validation loss in the file name makes checkpoints easy to compare
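
Restoring a checkpoint later is symmetric (the file name here is hypothetical):

state = torch.load('saved/model-3.217.pt', map_location=device)
model.load_state_dict(state)
model.eval()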

5. Engineering Practice Notes

5.1 Training Tricks

  1. Warm-up: the learning-rate decay is not applied during the first warmup epochs
  2. Mixed-precision training: torch.cuda.amp can be layered onto the loop to speed up training (see the sketch after this list)
  3. Gradient accumulation: accumulating gradients over several small batches approximates a larger effective batch size
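
A sketch of points 2 and 3 layered onto the train() loop above (accum_steps is a hypothetical setting; torch.cuda.amp is the pre-2.x mixed-precision API):

scaler = torch.cuda.amp.GradScaler()
accum_steps = 4

for i, batch in enumerate(iterator):
    with torch.cuda.amp.autocast():                          # run the forward pass in mixed precision
        output = model(batch.src, batch.trg[:, :-1])
        loss = criterion(output.contiguous().view(-1, output.shape[-1]),
                         batch.trg[:, 1:].contiguous().view(-1))
    scaler.scale(loss / accum_steps).backward()              # scaled loss, averaged over the accumulation window
    if (i + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)                           # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()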

5.2 Performance Optimization

torch.backends.cudnn.benchmark = True  # let cuDNN auto-tune kernels for fixed input shapes
torch.autograd.set_detect_anomaly(False)  # keep anomaly detection off for speed

5.3 Extensions

A multi-GPU wrapper sketch (note that nn.DataParallel provides data parallelism, replicating the module and splitting each batch, rather than true model parallelism):
class ParallelTransformer(Transformer):
    def __init__(self, ...):
        super().__init__(...)
        self.encoder = nn.DataParallel(self.encoder)
        self.decoder = nn.DataParallel(self.decoder)
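
In current PyTorch, torch.nn.parallel.DistributedDataParallel is generally preferred over nn.DataParallel for multi-GPU training; the simple wrapper above is shown only because it requires no change to the training loop.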

This section has examined the training script from the code level down to the underlying theory, keeping the original code structure intact while summarizing each module's control flow. In practice, the hyperparameters should be scaled to the task; for large-scale pretraining, an 8x V100 GPU setup combined with mixed-precision training is suggested to improve throughput.


Appendix: full source

"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""
import math
import time

import torch
from torch import nn, optim
from torch.optim import Adam

from data import *
from models.model.transformer import Transformer
from util.bleu import idx_to_word, get_bleu
from util.epoch_timer import epoch_time


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform_(m.weight.data)


model = Transformer(src_pad_idx=src_pad_idx,
                    trg_pad_idx=trg_pad_idx,
                    trg_sos_idx=trg_sos_idx,
                    d_model=d_model,
                    enc_voc_size=enc_voc_size,
                    dec_voc_size=dec_voc_size,
                    max_len=max_len,
                    ffn_hidden=ffn_hidden,
                    n_head=n_heads,
                    n_layers=n_layers,
                    drop_prob=drop_prob,
                    device=device).to(device)

print(f'The model has {count_parameters(model):,} trainable parameters')
model.apply(initialize_weights)
optimizer = Adam(params=model.parameters(),
                 lr=init_lr,
                 weight_decay=weight_decay,
                 eps=adam_eps)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 verbose=True,
                                                 factor=factor,
                                                 patience=patience)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)


def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()
        output = model(src, trg[:, :-1])
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output_reshape, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())

    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except:
                    pass

            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)

    batch_bleu = sum(batch_bleu) / len(batch_bleu)
    return epoch_loss / len(iterator), batch_bleu


def run(total_epoch, best_loss):
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        end_time = time.time()

        if step > warmup:
            scheduler.step(valid_loss)

        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))

        f = open('result/train_loss.txt', 'w')
        f.write(str(train_losses))
        f.close()

        f = open('result/bleu.txt', 'w')
        f.write(str(bleus))
        f.close()

        f = open('result/test_loss.txt', 'w')
        f.write(str(test_losses))
        f.close()

        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} |  Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')


if __name__ == '__main__':
    run(total_epoch=epoch, best_loss=inf)

