当前位置: 首页 > article >正文

Day26下 - BERT项目实战

BERT论文:https://arxiv.org/pdf/1810.04805

BERT架构:

 

BERT实战 

1. 读取数据

# pandas 适合表格类数据读取
import pandas as pd
import numpy as np

# sep: 分隔符
data = pd.read_csv(filepath_or_buffer="samples.tsv", sep="\t").to_numpy()

# 打乱样本顺序
np.random.shuffle(data)

2. 打包数据

# 深度学习框架
import torch
# 深度学习中的封装层
from torch import nn
# 引入数据集
from torch.utils.data import Dataset
# 引入数据集加载器
from torch.utils.data import DataLoader


class SentiDataset(Dataset):
    """
        自定义数据集
    """
    def __init__(self, data):
        """
            初始化
        """
        self.data = data
    
    def __getitem__(self, idx):
        """
            按索引获取单个样本
        """
        x, y = self.data[idx]
        return x, y
    
    def __len__(self):
        """
            返回数据集中的样本个数
        """
        return len(self.data)


# 训练集(前4500个作为训练集)
train_dataset = SentiDataset(data=data[:4500])
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
# 测试集(4500之后的作为测试集)
test_dataset = SentiDataset(data=data[4500:])
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=32)


for x, y in train_dataloader:
    
    # print(x)
    # print(y)
    break

3. 构建模型

# 用于加载 BERT 分词器
from transformers import BertTokenizer
# 用于加载 BERT 序列分类器
from transformers import BertForSequenceClassification


# 从 ModelScope 上下载 
# from modelscope import snapshot_download
# 设置 模型id model_id
# 设置 cache_dir 缓存目录
# model_dir = snapshot_download(model_id='tiansz/bert-base-chinese', cache_dir="./bert")


# 模型地址
model_dir = "bert-base-chinese"


# 加载分词器
tokenizer = BertTokenizer.from_pretrained(model_dir)


tokenizer


device = "cuda:0" if torch.cuda.is_available() else "cpu"
device


# 二分类分类器
model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)
model.to(device = device)


num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
num_params


# 冻结所有预训练层
# for param in model.parameters():
#     param.requires_grad=False
# 只训练分类层
# model.classifier.weight.requires_grad=True
# model.classifier.bias.requires_grad=True


num_params = 0
for param in model.parameters():
    if param.requires_grad:
        num_params += param.numel()
num_params


# 类别字典
label2idx = {"正面": 0, "负面": 1}
idx2label = {0: "正面", 1: "负面"}


4. 训练

from torch import nn
# 损失函数
loss_fn = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
# 定义训练轮次
epochs = 5


def get_acc(dataloader):
    """
        计算准确率
    """
    # 设置为评估模式
    model.eval()
    accs = []
    # 构建一个无梯度的环境
    with torch.no_grad():
        # 逐个批次计算
        for X, y in train_dataloader:
            # 编码
            X = tokenizer.batch_encode_plus(batch_text_or_text_pairs=X, 
                                            padding=True, 
                                            truncation=True,
                                            max_length=100,
                                            return_tensors="pt")
            # 转张量
            y = torch.tensor(data=[label2idx.get(label) for label in y], dtype=torch.long).cuda()
            # 1. 正向传播
            y_pred = model(input_ids=X["input_ids"].to(device=device), 
                           attention_mask=X["attention_mask"].to(device=device))
            # 2. 计算准确率
            acc = (y_pred.logits.argmax(dim=-1) == y).to(dtype=torch.float).mean().item()
            accs.append(acc)
    return sum(accs) / len(accs)


def train():
    """
        训练过程
    """
    # 训练之前:先看看准确率
    train_acc = get_acc(train_dataloader)
    test_acc = get_acc(test_dataloader)
    print(f"初始:Train_Acc: {train_acc}, Test_Acc: {test_acc}")
    # 遍历每一轮
    for epoch in range(epochs):
        model.train()
        # 遍历每个批次
        for X, y in train_dataloader:
            # 编码
            X = tokenizer.batch_encode_plus(batch_text_or_text_pairs=X, 
                                            padding=True, 
                                            truncation=True,
                                            max_length=100,
                                            return_tensors="pt")
            # 转张量
            y = torch.tensor(data=[label2idx.get(label) for label in y], dtype=torch.long).cuda()
            
            # 1. 正向传播
            y_pred = model(input_ids=X["input_ids"].to(device=device), 
                           attention_mask=X["attention_mask"].to(device=device))
                       
            # 2. 计算损失
            loss = loss_fn(y_pred.logits, y)
            
            # 3. 反向传播
            loss.backward()
            
            # 4. 优化一步
            optimizer.step()
            
            # 5. 清空梯度
            optimizer.zero_grad()
        # 每轮都计算一下准备率
        train_acc = get_acc(train_dataloader)
        test_acc = get_acc(test_dataloader)
        print(f"Epoch: {epoch +1}, Train_Acc: {train_acc}, Test_Acc: {test_acc}")



train()

5. 保存模型

# 保存训练好的模型
model.save_pretrained(save_directory="./sentiment_model")
# 保存分词器
tokenizer.save_pretrained(save_directory="./sentiment_model")

6. 预测

# 加载分词器
tokenizer = BertTokenizer.from_pretrained("./sentiment_model")
# 加载模型
model = BertForSequenceClassification.from_pretrained("./sentiment_model").cuda()


def predict(text="服务挺好的"):
    # 设置为评估模式
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(text=text,
                           padding=True, 
                           truncation=True,
                           max_length=100,
                           return_tensors="pt")
        y_pred = model(input_ids=inputs["input_ids"].to(device=device), 
                       attention_mask=inputs["attention_mask"].to(device=device))
        y_pred = y_pred.logits.argmax(dim=-1).cpu().numpy()
        result = idx2label.get(y_pred[0])
        return result


predict(text="有老鼠,厕所脏")


http://www.kler.cn/a/444123.html

相关文章:

  • 厦门凯酷全科技有限公司短视频带货可靠吗?
  • Pytorch | 从零构建Vgg对CIFAR10进行分类
  • 浏览器要求用户确认 Cookies Privacy(隐私相关内容)是基于隐私法规的要求,VUE 实现,html 代码
  • AI的进阶之路:从机器学习到深度学习的演变(二)
  • OpenCV圆形标定板检测算法findGrid原理详解
  • SpringBoot开发——整合JSONPath解析JSON信息
  • 2024 年的科技趋势
  • vue-cli 5接入模块联邦 module federation
  • 【GO环境安装】mac系统+GoLand使用
  • nginx 记录完整的 request 及 response
  • 使用JustAuth实现gittee登录
  • 中型项目下的 MySQL 挑战与应对
  • 利用Python爬虫实现数据收集与挖掘
  • 音视频入门基础:MPEG2-TS专题(18)——PES流简介
  • HTML基本标签详解
  • MySQL Explain 分析SQL语句性能
  • PostgreSQL: 事务年龄
  • python学opencv|读取图像(十四)BGR图像和HSV图像通道拆分
  • 医院与医疗设备供应商网络安全事故综述
  • ElasticSearch学习7
  • 粘包由应用层协议解决
  • django模型——ORM模型2
  • 校园点餐订餐外卖跑腿Java源码
  • 基于单片机的电能数据采集系统(论文+图纸)
  • 文件包含 0 1学习
  • 【Prompt Engineering】1.编写 Prompt 的原则