当前位置: 首页 > article >正文

Pytorch封装简单RNN模型,进行中文训练及文本预测

简述

使用pytorch封装简单RNN模型,使用单层nn.RNNnn.Linear等实现,然后做简单的文本预测。

数据集

代码参考李沐:https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html,但他使用的是一篇英文小说,
这里改为使用COIG-CQIA中文数据集中的:douban_book_introduce.jsonlruozhiba_ruozhiba_ruozhiba.jsonl两个文件,本文目的是为了学习rnn,所以数据集比较简单,不过这个数据集由于都是问答形式,不像小说那样有主题性,所以感觉学习效果不好。理想的应该还是找个中文长篇小说之类。

COIG-CQIA: https://huggingface.co/datasets/m-a-p/COIG-CQIA

另外由于COIG-CQIA的数据是指令问答形式的json文件,所以这里稍作处理,改为单个问题+答案为一行的纯文本txt格式, 去除其它json字段及各种符号。

代码如下:

def jsonl_to_txt(dir_path):  
    dict_list = []  
    jsonl_list = os.listdir(dir_path)  
  
    qa_list = list()  
  
    chars_to_remove = r'[,。?;、:“”:!~()『』「」【】\"\[\]➕〈〉/<>()‰\%《》\*\?\-\.…·○01234567890123456789•\n\t abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ—*]'  
  
    for jsonl in jsonl_list:  
        path = os.path.join(dir_path, jsonl)  
        print(path)  
        with open(path, 'r', encoding='utf-8') as f:  
            jsonl_data = f.readlines()  
            for line in jsonl_data:  
                line_dict = JSON.loads(line)  
                qa = line_dict['instruction'] + line_dict['output']  
                qa = re.sub(chars_to_remove, '', qa).strip()  
                qa_list.append(qa)  
  
    path = os.path.join(dir_path, 'chengyu_qa.txt')  
    with open(path, 'w', encoding='utf-8') as f:  
        f.write('\n'.join(qa_list))  
  
  
if __name__ == '__main__':  
    dir_path = '../data/COIG-CQIA'  
    jsonl_to_txt(dir_path)  
  
    print()

上面处理完毕后,还需要进行词元化、构建词典等步骤,参考:
python实现简单中文词元化、词典构造、时序数据集封装等-CSDN博客

模型封装

RNN — PyTorch 2.4 documentation

可以先观察一下tensorboard的add_graph函数对模型可视化后的结构:

在这里插入图片描述

这里使用单层的RNN(nn.RNN有默认参数num_layers=1),nn.functional.one_hot是为了实现单词的向量化表示,后续可以优化成nn.Embedding来做词向量。

nn.functional.one_hot前将x进行了转置,这里有点抽象,来关注一下nn.RNN的参数要求,便可理解。

先看x的初始shape为(batch_size, time_size),转置并向量化后为(time_size, batch_size, vocab_size)

若不转置直接向量化,则为(batch_size, time_size, vocab_size),实际上这两种格式的数据nn.RNN都支持。

但若为(batch_size, time_size, vocab_size)形式,则需在创建nn.RNN实例时指定参数batch_first=False。

在这里插入图片描述

另外,还需要提供一个初始的隐状态,这里用init_state函数实现。

在这里插入图片描述

class SimpleRNNModel(nn.Module):  
    def __init__(self, vocab_size, hidden_size):  
        super(SimpleRNNModel, self).__init__()  
        self.vocab_size = vocab_size  
        self.hidden_size = hidden_size  
  
        self.rnn = nn.RNN(vocab_size, hidden_size)  
        self.linear = nn.Linear(hidden_size, vocab_size)  
  
    def forward(self, x, hidden_state=None):  
        x = nn.functional.one_hot(x.T.long(), num_classes=self.vocab_size)  
        x = x.to(torch.float32)  
  
        outputs, hidden_state = self.rnn(x, hidden_state)  
        # rrn的outputs.shape(N, L, D*H)  
        outputs = outputs.reshape(-1, self.hidden_size)  
        outputs = self.linear(outputs)  
        return outputs, hidden_state  
  
    def init_state(self, device, batch_size=1):  
        return torch.zeros((self.rnn.num_layers, batch_size, self.hidden_size), device=device)  

梯度裁剪

源自李沐:https://zh-v2.d2l.ai/chapter_recurrent-neural-networks/rnn-scratch.html

def grad_clipping(net, max_norm):  
    if isinstance(net, nn.Module):  
        params = [p for p in net.parameters() if p.requires_grad]  
    else:  
        params = net.params  
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))  
    if norm > max_norm:  
        for param in params:  
            param.grad[:] *= max_norm / norm

模型训练

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
print(f'\ndevice: {device}')  
  
corpus, vocab = load_corpus("../data/COIG-CQIA/qa_list.txt")  
  
vocab_size = len(vocab)  
hidden_size = 256  
epochs = 5  
batch_size = 50  
learning_rate = 0.01  
time_size = 4  
max_grad_max_norm = 0.5  
  
dataset = make_dataset(corpus=corpus, time_size=time_size)  
data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)  
  
net = SimpleRNNModel(vocab_size, hidden_size)  
net.to(device)  
  
# print(net.state_dict())  
  
criterion = nn.CrossEntropyLoss()  
criterion.to(device)  
optimizer = optim.Adam(net.parameters(), lr=learning_rate)  
  
writer = SummaryWriter('./train_logs')  
# 随便定义个输入, 好使用add_graph  
tmp = torch.rand((batch_size, time_size)).to(device)  
writer.add_graph(net, tmp)  
  
loss_counter = 0  
total_loss = 0  
ppl_list = list()  
total_train_step = 0  
  
for epoch in range(epochs):  
    print('------------Epoch {}/{}'.format(epoch + 1, epochs))  
  
    for X, y in data_loader:  
        X, y = X.to(device), y.to(device)  
        # 如果各个批次间的时序是连续的,则可以把上次的hidden_state传入下个批次, 不然就要重置hidden_state  
        # 这里batch_size=X.shape[0]是因为在加载数据时, DataLoader没有设置丢弃不完整的批次, 所以存在实际批次不满足设定的batch_size  
        hidden_state = net.init_state(batch_size=X.shape[0], device=device)  
        outputs, hidden_state = net(X, hidden_state=hidden_state)  
  
        optimizer.zero_grad()  
        # y也变成 时间序列*批次大小的行数, 才和 outputs 一致  
        y = y.T.reshape(-1)  
        # 交叉熵的第二个参数需要LongTorch  
        loss = criterion(outputs, y.long())  
        loss.backward()  
        # 求完梯度之后可以考虑梯度裁剪, 再更新梯度  
        grad_clipping(net, max_grad_max_norm)  
        optimizer.step()  
  
        total_loss += loss.item()  
        loss_counter += 1  
        total_train_step += 1  
        if total_train_step % 10 == 0:  
            print(f'Epoch: {epoch + 1}, 累计训练次数: {total_train_step}, 本次loss: {loss.item():.4f}')  
            writer.add_scalar('train_loss', loss.item(), total_train_step)  
  
    ppl = np.exp(total_loss / loss_counter)  
    ppl_list.append(ppl)  
    print(f'Epoch {epoch + 1} 结束, batch_loss_average: {total_loss / loss_counter}, perplexity: {ppl}')  
    writer.add_scalar('ppl', ppl, epoch + 1)  
    total_loss = 0  
    loss_counter = 0  
  
    torch.save(net.state_dict(), './save/epoch_{}_ppl_{}.pth'.format(epoch + 1, ppl))  
  
writer.close()

tensorboard训练过程观察

横轴为训练epoch。

在这里插入图片描述

横轴为训练次数。

在这里插入图片描述

文本预测

这里首先完善模型的预测函数(该函数放到模型中):

def predict(self, prefix, num_preds, vocab, device):  
    state = self.init_state(batch_size=1, device=device)  
    # prefix为字符, 转成索引  
    outputs = [vocab.word2idx(prefix[0])]  
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))  
    # 一个字符一个字符跑一遍, 对用户输入进行预热, 即对输入的各个字符间建立联系  
    for y in prefix[1:]:  # 预热期  
        _, state = self.forward(get_input(), state)  
        outputs.append(vocab.word2idx(y))  
    # 刚好每次都用上一次的预测值做输入  
    for _ in range(num_preds):  # 预测num_preds步  
        y, state = self.forward(get_input(), state)  
        outputs.append(int(y.argmax(dim=1).reshape(1)))  
    return ''.join([vocab.idx2word(i) for i in outputs])

实现对提示词处理及预测函数的调用:

注意:这里的语料库应和训练使用的一致。

def predict(state_dict_path, vocab, prefix=None, num_preds=3):  
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
  
    vocab_size = len(vocab)  
    hidden_size = 256  
  
    net = SimpleRNNModel(vocab_size, hidden_size).to(device)  
    net.load_state_dict(torch.load(state_dict_path, map_location=device, weights_only=True))  
  
    net.eval()  
    with torch.no_grad():  
        outputs = net.predict(prefix=prefix, num_preds=num_preds, vocab=vocab, device=device)  
    return outputs  
  
  
if __name__ == '__main__':  
    corpus, vocab = load_corpus("../data/COIG-CQIA/qa_list.txt")  
    # corpus, vocab = load_corpus("../data/COIG-CQIA/chengyu_qa.txt")  
    # print(len(vocab))  
    # idx = [vocab.word2idx(ch) for ch in prefix]  
    path = "../save/Simple/新建文件夹/state_dict-time_size_30-ppl_1.pth"  
  
    prefix = "有什么超赞的诗句"  
    print(f'提示词: {prefix}')  
    outputs = predict(path, vocab, prefix=prefix, num_preds=22)  
    print(f'预测输出: {outputs}\n')

http://www.kler.cn/a/281618.html

相关文章:

  • 深度学习学习经验——目标检测及其应用
  • 【Spring Boot进阶】掌握Spring Boot框架核心注解:从入门到精通(实战指南)
  • win10配置adb环境变量
  • 使用QT开发一些特殊相机的思路:个人经验
  • React -TS学习—— useRef
  • 面试题(13)
  • Windows 10/11降级漏洞的工具包现已发布 仅供安全测试
  • UniApp 小程序
  • 八股总结-----C++、数据结构、算法
  • 美国高防服务器租用
  • OpenCV中使用金字塔LK光流法(上)
  • 【小沐学Rust】Rust实现TCP网络通信
  • IP-RDS-222、IP-PRZ-59-AM12、EG-TRZ-42-L、EG-TRZ-42-H比例减压阀放大器
  • 从 MLOps 到 MLOops:揭露机器学习平台的攻击面
  • RecyclerView嵌套RecyclerView,上下滑动的时候会出现item数据以及view的显示异常问题
  • 红黑树、B+Tree、B—Tree
  • 【XR】优化SLAM SDK的稳定性
  • Qt:玩转QPainter序列九
  • uni-app小程序当前页面刷新怎么实现
  • CSS中的align-content属性:实现垂直居中的新方式