
Chinese Sentiment Classification with bert-base-chinese on Hugging Face

Approach 1: plain PyTorch

First, some data preprocessing. This article mainly uses the lansinuote/ChnSentiCorp dataset.

from transformers import BertTokenizer
token = BertTokenizer.from_pretrained('bert-base-chinese')
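
As a quick sanity check (a minimal sketch; the example sentence is made up), you can inspect what the tokenizer produces:

out = token.batch_encode_plus(['这家酒店的服务很好'],
                              truncation=True,
                              padding='max_length',
                              max_length=16,
                              return_tensors='pt')
print(out['input_ids'].shape)             # torch.Size([1, 16])
print(token.decode(out['input_ids'][0]))  # [CLS] 这 家 酒 店 ... [SEP] [PAD] ...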

import torch
from datasets import load_dataset

dataset = load_dataset('lansinuote/ChnSentiCorp')
print(type(dataset))
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        label = self.dataset[idx]['label']
        return text, label
dataset = Dataset(dataset['train'])
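
A quick look at the wrapped dataset (a sketch; outputs shown as comments):

print(len(dataset))  # number of training examples
print(dataset[0])    # a (text, label) tuple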

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    # encode the batch of sentences
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    # input_ids: the encoded token ids
    # attention_mask: 0 at padded positions, 1 elsewhere
    input_ids = data['input_ids'].to(device)
    attention_mask = data['attention_mask'].to(device)
    token_type_ids = data['token_type_ids'].to(device)
    labels = torch.LongTensor(labels).to(device)

    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels

loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True, drop_last=True)
len(loader)  # number of batches per epoch
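
To confirm the collate function works, pull one batch and check the tensor shapes (a minimal sketch):

input_ids, attention_mask, token_type_ids, labels = next(iter(loader))
print(input_ids.shape)       # torch.Size([32, 500]): batch_size x max_length
print(attention_mask.shape)  # torch.Size([32, 500])
print(labels.shape)          # torch.Size([32])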

Load the bert-base-chinese model

from transformers import BertModel

pretrained = BertModel.from_pretrained('bert-base-chinese').to(device)
sum(i.numel() for i in pretrained.parameters())/1e6  # total parameter count in millions (roughly 102M for bert-base-chinese)

for param in pretrained.parameters():
    param.requires_grad = False  # freeze the backbone parameters
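
Since no new layers have been added yet, the trainable-parameter count should now be zero, which confirms the whole backbone is frozen (a minimal check):

trainable = sum(p.numel() for p in pretrained.parameters() if p.requires_grad)
print(trainable)  # 0: every backbone parameter is frozen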

Add a few layers on top of the model

class Model(torch.nn.Module):
    def __init__(self, pretrained):
        super(Model, self).__init__()
        self.bert = pretrained
        self.fn1 = torch.nn.Linear(768, 256)
        self.relu = torch.nn.ReLU()
        self.fn2 = torch.nn.Linear(256, 768)
        self.classifier = torch.nn.Linear(768, 2)  # 768 is BERT's hidden size, 2 is the number of classes

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # two linear layers with a ReLU activation in between
        output = self.fn1(output.last_hidden_state[:,0])
        output = self.relu(output)
        output = self.fn2(output)
        out = self.classifier(output)
        return out

Define the training function

from torch.optim import AdamW  # transformers' AdamW is deprecated; torch's is a drop-in replacement here
from transformers.optimization import get_scheduler

def train():
    optimizer = AdamW(model.parameters(), lr=1e-5)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = get_scheduler("linear", optimizer=optimizer,
                              num_training_steps=len(loader)*3,  # schedule sized for 3 epochs, though the loop below makes one pass
                              num_warmup_steps=0)
    model.train()
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 10 == 0:
            out = outputs.argmax(dim=1)
            accuracy = (out == labels).sum().item() / len(labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), accuracy, lr)

Start training

model = Model(pretrained).to(device)  # instantiate the model before training
train()

Testing

def test():
    loader_test = torch.utils.data.DataLoader(
        Dataset(load_dataset('lansinuote/ChnSentiCorp')['test']),
        batch_size=32,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True
    )
    model.eval()
    correct = 0
    total = 0
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 5: break  # only evaluate the first 5 batches
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        out = outputs.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
    print('Accuracy:', correct / total)
test()  # run the test
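
With the trained model you can also classify a new sentence directly. A minimal sketch (the review text is made up, and label 1 is assumed to mean positive in ChnSentiCorp):

def predict(text):
    enc = token.batch_encode_plus([text],
                                  truncation=True,
                                  padding='max_length',
                                  max_length=500,
                                  return_tensors='pt')
    model.eval()
    with torch.no_grad():
        out = model(input_ids=enc['input_ids'].to(device),
                    attention_mask=enc['attention_mask'].to(device),
                    token_type_ids=enc['token_type_ids'].to(device))
    return out.argmax(dim=1).item()

print(predict('房间干净,服务态度也很好'))  # expected: 1 (positive)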

Approach 2: using the Transformers tooling

Since the dataset is downloaded from Hugging Face, there is no need to wrap it in a custom Dataset class; only some light preprocessing is required.

import torch
from datasets import load_dataset
from transformers import BertTokenizer

token = BertTokenizer.from_pretrained('bert-base-chinese')  # reload the tokenizer so this part runs standalone
dataset = load_dataset('lansinuote/ChnSentiCorp')
dataset['train'] = dataset['train'].shuffle().select(range(2000))
dataset['test'] = dataset['test'].shuffle().select(range(100))
def f(data):
    # tokenize in batches; the raw text column is dropped via remove_columns below
    return token.batch_encode_plus(data['text'], truncation=True, max_length=512)
dataset = dataset.map(f, batched=True, remove_columns=['text'], batch_size=1000, num_proc=3)
def f(data):
    # keep only sequences within the length limit (a safety check, since truncation already enforces it)
    return [len(i) <= 512 for i in data['input_ids']]
dataset = dataset.filter(f, batched=True, num_proc=3, batch_size=1000)
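
After map and filter, each example carries the encoded fields instead of the raw text; a quick way to verify (a sketch):

print(dataset)                     # splits now expose input_ids / token_type_ids / attention_mask / label
print(dataset['train'][0].keys())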

Load the model and add layers on top

from transformers import BertModel

pretrained = BertModel.from_pretrained('bert-base-chinese')
sum(i.numel() for i in pretrained.parameters())/1e6  # total parameter count in millions
for param in pretrained.parameters():
    param.requires_grad = False  # freeze the backbone parameters

class Model(torch.nn.Module):
    def __init__(self, pretrained):
        super(Model, self).__init__()
        self.bert = pretrained
        self.fn1 = torch.nn.Linear(768, 256)
        self.relu = torch.nn.ReLU()
        self.fn2 = torch.nn.Linear(256, 768)
        self.classifier = torch.nn.Linear(768, 2)  # 768 is BERT's hidden size, 2 is the number of classes

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        with torch.no_grad():
            output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        cls_output = output.last_hidden_state[:, 0]  # take the [CLS] token's hidden state
        output = self.fn1(cls_output)
        output = self.relu(output)
        output = self.fn2(output)
        logits = self.classifier(output)  # classification logits

        loss = None
        if labels is not None:
            loss_fn = torch.nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)  # compute the loss

        return (loss, logits) if loss is not None else logits

Note that forward takes an extra labels parameter, because the dataset carries labels and Hugging Face's task-specific models also accept a labels argument; without it, the model may not work with Hugging Face's Trainer.
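To see why the tuple layout matters: Trainer calls the model with a labels keyword and reads the loss from the first element of the returned tuple. A minimal sketch with fabricated tensors (21128 is bert-base-chinese's vocabulary size):

dummy_ids = torch.randint(0, 21128, (2, 8))     # fabricated token ids
dummy_mask = torch.ones(2, 8, dtype=torch.long)
dummy_types = torch.zeros(2, 8, dtype=torch.long)
dummy_labels = torch.tensor([0, 1])

m = Model(pretrained)
loss, logits = m(dummy_ids, dummy_mask, dummy_types, labels=dummy_labels)
print(loss.item(), logits.shape)  # a scalar loss and torch.Size([2, 2])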

Evaluation function and trainer

import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    acc = metric.compute(predictions=predictions, references=labels)
    return acc
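
compute_metrics can be sanity-checked with fabricated logits (the values below are made up):

dummy_logits = np.array([[0.1, 0.9], [2.0, -1.0]])
dummy_labels = np.array([1, 0])
print(compute_metrics((dummy_logits, dummy_labels)))  # {'accuracy': 1.0}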

# define the trainer
from transformers import Trainer, TrainingArguments

# training arguments
training_args = TrainingArguments(
    output_dir="./output_dir",
    evaluation_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_steps=20,
    no_cuda=True,
    report_to='none',
)
# trainer
from transformers import DataCollatorWithPadding

model = Model(pretrained)  # instantiate the model; Trainer needs a model instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    data_collator=DataCollatorWithPadding(token),
    compute_metrics=compute_metrics,
)

Train and evaluate

trainer.train()
trainer.evaluate()
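
Because the model here is a custom nn.Module rather than a BertForSequenceClassification, the simplest way to run inference afterwards is a manual forward pass. A sketch (the review text is made up, and label 0 is assumed to mean negative):

enc = token('这本书内容很无聊', truncation=True, max_length=512, return_tensors='pt')
model.eval()
with torch.no_grad():
    logits = model(**enc)
print(logits.argmax(dim=1).item())  # expected: 0 (negative)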
