BERT训练环节(代码实现)
1.代码实现
#导包
import torch
from torch import nn
import dltools
#加载数据需要用到的声明变量
batch_size, max_len = 1, 64
#获取训练数据迭代器、词汇表
train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)
#其余都是二维数组
#tokens, segments, valid_lens(一维), pred_position, mlm_weights, mlm, nsp(一维)对应每条数据i中包含的数据
for i in train_iter: #遍历迭代器
break #只遍历一条数据
[tensor([[ 3, 25, 0, 4993, 0, 24, 4, 26, 13, 2, 158, 20, 5, 73, 1399, 2, 9, 813, 9, 987, 45, 26, 52, 46, 53, 158, 2, 5, 3140, 5880, 9, 543, 6, 6974, 2, 2, 315, 6, 8, 5, 8698, 8, 17229, 9, 308, 2, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), tensor([47.]), tensor([[ 9, 15, 26, 32, 34, 35, 45, 0, 0, 0]]), tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0., 0.]]), tensor([[ 484, 1288, 20, 6, 2808, 9, 18, 0, 0, 0]]), tensor([0])]
#创建BERT网络模型
net = dltools.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],
ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,
num_layers=2, dropout=0.2, key_size=128, query_size=128,
value_size=128, hid_in_features=128, mlm_in_features=128,
nsp_in_features=128)
#调用设备上的GPU
devices = dltools.try_all_gpus()
#损失函数对象
loss = nn.CrossEntropyLoss() #多分类问题,使用交叉熵
#@save #表示用于指示某些代码应该被保存或导出,以便于管理和重用
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):
#前向传播
#获取遮蔽词元的预测结果、下一个句子的预测结果
_, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)
#计算遮蔽语言模型的损失
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1,1)
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8) #MLM损失函数的归一化版本 #加一个很小的数1e-8,防止分母为0,抵消上一行代码乘以的数值
#计算下一个句子预测任务的损失
nsp_l = loss(nsp_Y_hat, nsp_y)
l = mlm_l + nsp_l
return mlm_l, nsp_l, l
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps): #文本词元样本量太多,全跑完花费的时间太多,若num_steps=1在BERT中表示,跑了1个batch_size
net = nn.DataParallel(net, device_ids=devices).to(devices[0]) #调用设备的GPU
trainer = torch.optim.Adam(net.parameters(), lr=0.01) #梯度下降的优化算法Adam
step, timer = 0, dltools.Timer() #设置计时器
#调用画图工具
animator = dltools.Animator(xlabel='step', ylabel='loss', xlim=[1, num_steps], legend=['mlm', 'nsp'])
#遮蔽语言模型损失的和, 下一句预测任务损失的和, 句子对的数量, 计数
metric = dltools.Accumulator(4) #Accumulator类被设计用来收集和累加各种指标(metric)
num_steps_reached = False #设置一个判断标志, 训练步数是否达到预设的步数
while step < num_steps and not num_steps_reached:
for tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y in train_iter:
#将遍历的数据发送到设备上
tokens_X = tokens_X.to(devices[0])
segments_X = segments_X.to(devices[0])
valid_lens_x = valid_lens_x.to(devices[0])
pred_positions_X = pred_positions_X.to(devices[0])
mlm_weights_X = mlm_weights_X.to(devices[0])
mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
#梯度清零
trainer.zero_grad()
timer.start() #开始计时
mlm_l, nsp_l, l = _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
l.backward() #反向传播
trainer.step() #梯度更新
metric.add(mlm_l, nsp_l, tokens_X.shape[0], l) #累积的参数指标
timer.stop() #计时停止
animator.add(step + 1, (metric[0] / metric[3], metric[1] / metric[3])) #画图的
step += 1 #训练完一个batch_size,就+1
if step == num_steps: #若步数与预设的训练步数相等
num_steps_reached = True #判断标志改为True
break #退出while循环
print(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')
print(f'{metric[2]/ timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')
train_bert(train_iter, net, loss, len(vocab), devices, 500)
def get_bert_encoding(net, tokens_a, tokens_b=None):
tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)
token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0) #unsqueeze(0)增加一个维度
segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)
valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
endoced_X, _, _ = net(token_ids, segments, valid_len)
return endoced_X
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
(torch.Size([1, 6, 128]), torch.Size([1, 128]), tensor([-0.5872, -0.0510, -0.7376], device='cuda:0', grad_fn=<SliceBackward0>))
encoded_text_crane
tensor([[-5.8725e-01, -5.0994e-02, -7.3764e-01, -4.3832e-02, 9.2467e-02, 1.2745e+00, 2.7062e-01, 6.0271e-01, -5.5055e-02, 7.5122e-02, 4.4872e-01, 7.5821e-01, -6.1558e-02, -1.2549e+00, 2.4479e-01, 1.3132e+00, -1.0382e+00, -4.7851e-03, -6.3590e-01, -1.3180e+00, 5.2245e-02, 5.0982e-01, 7.4168e-02, -2.2352e+00, 7.4425e-02, 5.0371e-01, 7.2120e-02, -4.6384e-01, -1.6588e+00, 6.3987e-01, -6.4567e-01, 1.7187e+00, -6.9696e-01, 5.6788e-01, 3.2628e-01, -1.0486e+00, -7.2610e-01, 5.7909e-02, -1.6380e-01, -1.2834e+00, 1.6431e+00, -1.5972e+00, -4.5678e-03, 8.8022e-02, 5.5931e-02, -7.2332e-02, -4.9313e-01, -4.2971e+00, 6.9757e-01, 7.0690e-02, -1.8613e+00, 2.0366e-01, 8.9868e-01, -3.4565e-01, 9.6776e-02, 1.3699e-02, 7.1410e-01, 5.4820e-01, 9.7358e-01, -8.1038e-01, 2.6216e-01, -5.7850e-01, -1.1969e-01, -2.5277e-01, -2.0046e-01, -1.6718e-01, 5.5540e-01, -1.8172e-01, -2.5639e-02, -6.0961e-01, -1.1521e-03, -9.2973e-02, 9.5226e-01, -2.4453e-01, 9.7340e-01, -1.7908e+00, -2.9840e-02, 2.3087e+00, 2.4889e-01, -7.2734e-01, 2.1827e+00, -1.1172e+00, -7.0915e-02, 2.5138e+00, -1.0356e+00, -3.7332e-02, -5.6668e-01, 5.2251e-01, -5.0058e-01, 1.7354e+00, 4.0760e-01, -1.2982e-01, -7.0230e-01, 3.1563e+00, 1.8754e-01, 2.0220e-01, 1.4500e-01, 2.3296e+00, 4.5522e-02, 1.1762e-01, 1.0662e+00, -4.0858e+00, 1.6024e-01, 1.7885e+00, -2.7034e-01, -1.6869e-01, -8.7018e-02, -4.2451e-01, 1.1446e-01, -1.5761e+00, 7.6947e-02, 2.4336e+00, 4.5346e-02, -6.5078e-02, 1.4203e+00, 3.7165e-01, -7.9571e-01, -1.3515e+00, 4.1511e-02, 1.3561e-01, -3.3006e+00, 1.4821e-01, 1.3024e-01, 1.9966e-01, -8.5910e-01, 1.4505e+00, 7.6774e-02, 9.3771e-01]], device='cuda:0', grad_fn=<SliceBackward0>)
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just', 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
(torch.Size([1, 10, 128]), torch.Size([1, 128]), tensor([-0.4637, -0.0569, -0.6119], device='cuda:0', grad_fn=<SliceBackward0>))