【NLP9-Transformer经典案例】

Transformer经典案例

1、语言模型

以一个符合语言规律的序列为输入,模型将利用序列间关系等特征,输出在一个在所有词汇上的概率分布,这样的模型称为语言模型。

2、语言模型能解决的问题

根据语言模型定义,可以在它的基础上完成机器翻译,文本生成等任务,因为我们通过最后输出的概率分布来预测下一个词汇是什么

语言模型可以判断输入的序列是否为一句完整的话,因为我们可以根据输出的概率分布查看最大概率是否落在句子结束符上,来判断完整性

语言模型本身的训练目标是预测下一个词,因为它的特征提取部分会抽象很多语言序列之间的关系,这些关系可能同样对其它语言类任务有效果。因此可以作为预训练模型进行迁移学习

3、模型实现步骤

1、导包

2、导入wikiText-2数据集并作基本处理

3、构建用于模型输入的批次化数据

4、构建训练和评估函数

5、进行训练和评估(包括验证以及测试)

4、数据准备

wikiText-2数据集体量中等,训练集共有600篇文章,共208万左右词汇,在33278个不重复词汇,OVV(有多少正常英文词汇不在该数据集中的占比)为2.6%。验证集和测试集均为60篇。

torchtext重要功能

对文本数据进行处理,比如文本语料加载,文本迭代器构建等。

包含很多经典文本语料的预加载方法。其中包括的语料有:用于情感分析的SST和IMDB,用于问题分类TREC,用于及其翻译的WMT14,IWSLT,以及用于语言模型任务wikiText-2

# 1、导包
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
#英文文本数据集工具包
import torchtext

#导入英文分词工具
#from torchtext.legacy.data import Field

from torchtext.data.utils import get_tokenizer
#from torchtext.legacy.data import *
#导入已经构建完成的Transformer包

from pyitcast.transformer import TransformerModel

import torch
import torchtext
from torchtext.legacy.data import Field,TabularDataset,Iterator,BucketIterator


TEXT=torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                          init_token ='<sos>',
                          eos_token ='<eos>',
                          lower = True)
print(TEXT)
# 2、导入wikiText-2数据集并作基本处理
train_txt,val_txt,test_txt = torchtext.datasets.WikiText2.splits(TEXT)
print(test_txt.examples[0].text[:10])

#将训练集文本数据构建一个vocab对象,可以使用vocab对象的stoi方法统计文本共包含的不重复的词汇总数
TEXT.build_vocab(train_txt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 3、构建用于模型输入的批次化数据
def batchify(data,batch_size):
    #data,之前得到的文本数据
    # bsz,批次的样本量
    #第一步使用TEXT的numericalize方法将单词映射成对应的联系数字
    data = TEXT.numericalize([data.examples[0].text])
    #取得需要经过多少次的batch_size后能够遍历完所有的数据
    nbatch = data.size(0) // batch_size
    #利用narrow方法对数据进行切割
    #第一参数代表横轴切割还是纵轴切割,0 代表横轴,1代表纵轴
    #第二个参数,第三个参数分别代表切割的起始位置和终止位置
    data=data.narrow(0,0,nbatch * batch_size)
    #对data的形状进行转变
    data = data.view(batch_size,-1).t().contiguous()
    return data.to(device)

# x=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
# print(x.narrow(0,0,2))
# print(x.narrow(1,1,2))

#设置训练数据、验证数据、测试数据的批次大小
batch_size =20
eval_batch_size =10
train_data = batchify(train_txt,batch_size)
val_data = batchify(val_txt,eval_batch_size)
test_data = batchify(test_txt,eval_batch_size)

#设定句子的最大长度
bptt =35
def get_batch(source,i):
    seq_len = min(bptt,len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data,target

# source = test_data
# i =1
# x,y = get_batch(source,i)
# print(x)
# print(y)


# 4、构建训练和评估函数
#通过TEXT.vocab.stoi方法获取不重复的词汇总数
ntokens = len(TEXT.vocab.stoi)
#设置词嵌入维度的值等于200
emsize =200
#设置前馈全连接层的节点数等于200
nhid =200
#设置编码层层数等于2
nlayers=2
#设置多头注意力中的头数等于2
nhead =2
#设置置零比率
dropout =0.2
#将参数传入TransformerModel 中实例化模型
model = TransformerModel(ntokens,emsize,nhead,nhid,nlayers,dropout).to(device)

#设定损失函数,采用交叉熵损失函数
criterion= nn.CrossEntropyLoss()

#设置学习率
lr =5.0

#设置优化器
optimizer = torch.optim.SGD(model.parameters(),lr=lr)

criterion = nn.CrossEntropyLoss()
lr =5.0
optimizer = torch.optim.SGD(model.parameters(),lr=lr)
#定义学习率调整器,使用torch自带的lr_scheduler,将优化器传入
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.95)


# 5、进行训练和评估(包括验证以及测试)

#训练函数
import time
def train():
    #首先开启训练模式
    model.train()
    #定义初始损失值
    total_loss =0

    start_time = time.time()
    #遍历训练数据进行模型的训练
    for batch,i in enumerate(range(0,train_data.size(0) -1 ,bptt)):
        #通过前面的get_batch函数获取源数据和目标数据
        data,targets = get_batch(train_data,i)
        #设置梯度归零
        optimizer.zero_grad()
        #通过模型预测输出
        output = model(data)
        #计算损失值
        loss = criterion(output.view(-1,ntokens),targets)
        #进行反向传播
        loss.backward()
        #进行梯度规范化,防止出现梯度爆炸或者梯度消失
        torch.nn.utils.clip_grad_norm(model.parameters(),0.5)
        #进行参数更新
        optimizer.step()
        #将损失值进行累加
        total_loss +=loss.item()
        # 获取当前开始时间
        log_interval =200
        #打印日志信息
        if batch % log_interval ==0 and batch>0:
            #计算平均损失
            cur_loss = total_loss/log_interval
            #计算训练到目前的耗时
            elapsed = time.time() - start_time
            #打印日志信息
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                   epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                   elapsed * 1000 / log_interval,
                   cur_loss, math.exp(cur_loss)))
            #每个打印批次借结束后,将总损失值清零
            total_loss =0
            start_time= time.time()

#评估函数
def evaluate(eval_model,data_source):
    #eval_model,代表每轮训练后产生的模型
    # data_source,代表验证集数据或者测试集数据
    #首先开启评估模式
    eval_model.eval()
    #初始化总损失值
    total_loss =0
    #模型开启评估模式,不进行反向传播求梯度
    with torch.no_grad():
        #遍历验证数据
        for i in range(0,data_source.size(0)-1,bptt):
            #首先通过get_batch函数获取源数据和目标数据
            data,targets = get_batch(data_source,i)
            #将源数据放入评估模型中,进行预测
            output = eval_model(data)
            #对输出张量进行变形
            output_flat = output.view(-1,ntokens)
            #累加损失值
            total_loss +=criterion(output_flat,targets).item()
        #返回评估的总损失值
        return total_loss

best_val_loss = float("inf")

epochs =3
best_model =None

for epoch in range(1,epochs +1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model,val_data)
    print('-'*50)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-'*50)
    #通过比较当前轮次的损失值,获取最佳模型
    if val_loss<best_val_loss:
        best_val_loss=val_loss
        best_model=model
    #每个轮次后调整优化器的学习率
    scheduler.step()

#添加测试的流程代码
test_loss = evaluate(best_model,test_data)
print('-'*90)
print('|End of training |test loss {:5,2f}'.format(test_loss))
print('-'*50)


没有运行成功

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/273431.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

放慢音频速度的三个方法 享受慢音乐

如何让音频慢速播放&#xff1f;我们都知道&#xff0c;在观看视频时&#xff0c;我们可以选择快进播放&#xff0c;但是很少有软件支持慢速播放。然而&#xff0c;将音频慢速播放在某些情况下是非常必要的。例如&#xff0c;当我们学习一门新语言时&#xff0c;我们可以将音频…

【数据挖掘】实验3:常用的数据管理

实验3&#xff1a;常用的数据管理 一&#xff1a;实验目的与要求 1&#xff1a;熟悉和掌握常用的数据管理方法&#xff0c;包括变量重命名、缺失值分析、数据排序、随机抽样、字符串处理、文本分词。 二&#xff1a;实验内容 【创建新变量】 方法1&#xff1a; mydata <…

还原wps纯粹的编辑功能

1.关闭稻壳模板&#xff1a; 1.1. 启动wps(注意不要乱击稻壳模板&#xff0c;点了就找不到右键菜单了) 1.2. 在稻壳模板选项卡右击&#xff1a;选不再默认展示 2.关闭托盘中wps云盘图标&#xff1a;右击云盘图标/同步与设置&#xff1a; 2.1.关闭云文档同步 2.2.窗口选桌面应用…

VSCode下使用github初步

由于各种需要&#xff0c;现在需要统一将一些代码提交搞github&#xff0c;于是有了在VSCode下使用github的需求。之前只是简单的使用git clone&#xff0c;代码提交这些用的是其他源代码工具&#xff0c;于是得学习实操下&#xff0c;并做一记录以备后用。 安装 VSCode安装 …

java的成员变量和局部变量

1、什么是成员变量和局部变量&#xff1f; 2、成员变量和局部变量区别 区别 成员变量 局部变量 类中位置不同 类中方法外 方法内或者方法声明上 内存中位置不同 堆内存 栈内存 生命周期不同 随着对象的存在而存在&#xff0c;随着对象的消失而消失 随着方法的调用而…

基础:TCP三次握手做了什么,为什么要握手?

1. TCP 三次握手在做些什么 1. 第一次握手 &#xff1a; 1&#xff09;握手作用&#xff1a;客户端发出建立连接请求。 2&#xff09;数据处理&#xff1a;客户端发送连接请求报文段&#xff0c;将SYN位置为1&#xff0c;Sequence Number为x;然后&#xff0c;客户端进入SYN_S…

【DataWhale学习笔记-蝴蝶书共读】大语言模型背后

从图灵测试到ChatGPT 1950年&#xff0c;艾伦•图灵(Alan Turing)发表论文《计算机器与智能》&#xff08; Computing Machinery and Intelligence&#xff09;&#xff0c;提出并尝试回答“机器能否思考”这一关键问题。在论文中&#xff0c;图灵提出了“模仿游戏”&#xff…

CTF题型 Http请求走私总结Burp靶场例题

CTF题型 Http请求走私总结&靶场例题 文章目录 CTF题型 Http请求走私总结&靶场例题HTTP请求走私HTTP请求走私漏洞原理分析为什么用前端服务器漏洞原理界定标准界定长度 重要!!!实验环境前提POST数据包结构必要结构快速判断Http请求走私类型时间延迟CL-TETE-CL 练习例题C…

EI Scopus检索 | 第二届大数据、物联网与云计算国际会议(ICBICC 2024) |

会议简介 Brief Introduction 2024年第二届大数据、物联网与云计算国际会议(ICBICC 2024) 会议时间&#xff1a;2024年12月29日-2025年1月1日 召开地点&#xff1a;中国西双版纳 大会官网&#xff1a;ICBICC 2024-2024 International Conference on Big data, IoT, and Cloud C…

苍穹外卖-day09:用户端历史订单模块(理解业务逻辑),商家端订单管理模块(理解业务逻辑),校验收货地址是否超出配送范围(相关API)

用户端历史订单模块 1. 查询历史订单&#xff08;分页查询&#xff09; 1.1 需求分析和设计 产品原型&#xff1a; 业务规则 分页查询历史订单可以根据订单状态查询展示订单数据时&#xff0c;需要展示的数据包括&#xff1a;下单时间、订单状态、订单金额、订单明细&#…

springboot280基于WEB的旅游推荐系统设计与实现

旅游推荐系统设计与实现 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff0c;在计算机上安装旅游推荐系统软件来发挥其高效地信息处理…

android 顺滑滑动嵌套布局

1. 背景 最近项目中用到了上面的布局&#xff0c;于是使用了scrollviewrecycleview&#xff0c;为了自适应高度&#xff0c;重写了recycleview&#xff0c;实现了高度自适应&#xff1a; public class CustomRecyclerView extends RecyclerView {public CustomRecyclerView(Non…

Mac玩《幻兽帕鲁》为什么打不开D3DMetal?d3d错误怎么办 d3dxl error

我之前发了一篇讲Mac电脑玩Steam热门新游《幻兽帕鲁》的文章&#xff08;没看过的点这里&#xff09;&#xff0c;后来也看到很多朋友去尝试了&#xff0c;遇到了一些问题&#xff0c;无法进入《幻兽帕鲁》游戏&#xff0c;或者是玩的时候卡顿以及出现黑屏&#xff0c;通过我的…

基于YOLOv8深度学习的橙子病害智能诊断与防治系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战、目标分类

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…

mac清除dns缓存指令 mac清除缓存怎么清理

你是否曾经被要求清理dns缓存并刷新&#xff1f;清理dns缓存一般是由于修改了主机文件&#xff0c;或者想排除一些网络上的故障。在Mac上清除dns缓存需要使用命令行来实现。在本文中&#xff0c;软妹子将向大家介绍mac清除dns缓存指令&#xff0c;并展示mac清除缓存怎么清理。 …

【vscode】vscode重命名变量后多了很多空白行

这种情况&#xff0c;一般出现在重新安装 vscode 后出现。 原因大概率是语言服务器没设置好或设置对。 以 Python 为例&#xff0c;到设置里搜索 "python.languageServer"&#xff0c;将 Python 的语言服务器设置为 Pylance 即可。

计算地球圆盘负荷产生的位移

1.研究背景 计算受表面载荷影响的弹性体变形问题有着悠久的历史&#xff0c;涉及到许多著名的数学家和物理学家&#xff08;Boussinesq 1885&#xff1b;Lamb 1901&#xff1b;Love 1911&#xff0c;1929&#xff1b;Shida 1912&#xff1b;Terazawa 1916&#xff1b;Munk &…

Ubuntu 搭建gitlab服务器,及使用repo管理

一、GitLab安装与配置 GitLab 是一个用于仓库管理系统的开源项目&#xff0c;使用Git作为代码管理工具&#xff0c;并在此基础上搭建起来的Web服务。 1、安装Ubuntu系统&#xff08;这个教程很多&#xff0c;就不展开了&#xff09;。 2、安装gitlab社区版本&#xff0c;有需…

利用textarea和white-space实现最简单的文章编辑器 支持缩进和换行

当你遇到一个非常基础的文章发布和展示的需求&#xff0c;只需要保留换行和空格缩进&#xff0c;你是否会犹豫要使用富文本编辑器&#xff1f;实际上这个用原生的标签两步就能搞定&#xff01; 1.直接用textarea当编辑器 textarea本身就可以保存空格和换行符&#xff0c;示例如…

vue使用element-ui 实现自定义分页

可以通过插槽实现自定义的分页。在layout里面进行配置。 全部代码 export default { name:Cuspage, props:{total:Number, }, data(){return {currentPage:1,pageSize:10,} } methods: {setslot (h) {return(<div class"cusPage"›<span on-click{this.toBe…
最新文章