
Fine-Tuning BERT for a Binary Classification Task

Setup

Packages: torch 2.6.0+cu126, transformers 4.49.0.dev0
Environment: Ubuntu 20.04 + GeForce RTX 3090
Fine-tuning dataset: IDEA-CCNL/AFQMC
BERT model: google-bert/bert-base-chinese
Time to reproduce: about 1 hour
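
An optional sanity check to confirm your environment matches the versions listed above before starting:

import torch
import transformers

print(torch.__version__)          # expect 2.6.0+cu126
print(transformers.__version__)   # expect 4.49.0.dev0
print(torch.cuda.is_available())  # should be True on the RTX 3090 box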

Dataset class

Here we subclass Dataset to implement data loading; the code is as follows:

from torch.utils.data import Dataset
import json

_TRAIN_DATASET_PATH = "datasets--IDEA-CCNL--AFQMC/snapshots/fd907148d4cfaaadad98cd8d39b967ecf95bd094/afqmc_public/train.json"
_TEST_DATASET_PATH = "datasets--IDEA-CCNL--AFQMC/snapshots/fd907148d4cfaaadad98cd8d39b967ecf95bd094/afqmc_public/dev.json"

class _AFQMC_dataset(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        # Each line of the AFQMC file is one JSON object:
        # {"sentence1": ..., "sentence2": ..., "label": ...}
        data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                data[idx] = json.loads(line.strip())
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_dataset = _AFQMC_dataset(_TRAIN_DATASET_PATH)
test_dataset = _AFQMC_dataset(_TEST_DATASET_PATH)
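
As a quick sanity check, print the dataset size and the first sample; each record should carry the sentence1/sentence2/label fields that the collate function in the next section relies on:

print(len(train_dataset))  # number of sentence pairs in the train split
print(train_dataset[0])    # e.g. {'sentence1': ..., 'sentence2': ..., 'label': '0'}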

DataLoader class (batch loading and data preprocessing)

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
from AFQMC_dataset import train_dataset, test_dataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # tokenizer

def collate_fn(batch_samples):
    # Collect the two sentences and the label from each sample,
    # then tokenize the sentence pairs as one padded batch.
    batch_sentence1 = []
    batch_sentence2 = []
    batch_label = []
    for sample in batch_samples:
        batch_sentence1.append(sample['sentence1'])
        batch_sentence2.append(sample['sentence2'])
        batch_label.append(int(sample['label']))
    X = tokenizer(
        batch_sentence1,
        batch_sentence2,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    y = torch.tensor(batch_label)
    return X, y


train_dataLoader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)

val_dataLoader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,  # no need to shuffle the validation set
    collate_fn=collate_fn
)
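
To verify the collate function, pull a single batch and inspect its shapes (an optional check, not part of the training script; the sequence length varies per batch because of dynamic padding):

batch_X, batch_y = next(iter(train_dataLoader))
print(batch_X['input_ids'].shape)       # torch.Size([8, max_len_in_batch])
print(batch_X['token_type_ids'].shape)  # same shape; 0 marks sentence1, 1 marks sentence2
print(batch_y.shape)                    # torch.Size([8])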

Subclassing the BERT class (adding a linear layer for classification)

import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel
from transformers import AutoConfig

class BertClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Skip the built-in pooling layer; we take the [CLS] hidden state directly.
        self.bert_encoder = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 2)  # two classes
        self.post_init()

    def forward(self, x):
        bert_output = self.bert_encoder(**x)
        # [CLS] token vector of every sequence in the batch
        cls_vectors = bert_output.last_hidden_state[:, 0, :]
        cls_vectors = self.dropout(cls_vectors)
        logits = self.classifier(cls_vectors)
        return logits


config = AutoConfig.from_pretrained("bert-base-chinese")
model = BertClassification.from_pretrained("bert-base-chinese", config=config).to("cuda")
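
A minimal smoke test, assuming the dataloaders from the previous section: push one batch through the freshly built model and check the logits shape. The classifier head is randomly initialized at this point, so the values themselves are meaningless until training.

import torch
from AFQMC_dataloader import val_dataLoader

with torch.no_grad():
    X, y = next(iter(val_dataLoader))
    logits = model(X.to("cuda"))
print(logits.shape)  # torch.Size([8, 2]): one logit per class for each pair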

Training and validation

from bert_classification import BertClassification
from tqdm.auto import tqdm
def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f"loss: {0:>7f}")
    finish_step_num = (epoch - 1) * len(dataloader)

    model.train()
    for step, (X, y) in enumerate(dataloader, start=1):
        X, y = X.to("cuda"), y.to("cuda")
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()  # reset gradients
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        # Report the running mean loss over all steps so far, across epochs.
        progress_bar.set_description(f"loss: {total_loss/(step+finish_step_num):>7f}")
        progress_bar.update(1)
    return total_loss


import torch

def val_loop(data_loader, model, mode='Valid'):
    assert mode in ['Test', 'Valid']
    size = len(data_loader.dataset)
    correct = 0
    model.eval()
    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to("cuda"), y.to("cuda")
            pred = model(X)
            # Count predictions whose argmax matches the gold label.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= size
    print(f"{mode}: Accuracy: {(100*correct):>0.1f}%")
    return correct


import torch
import torch.nn as nn
from torch.optim import AdamW  # transformers' own AdamW is deprecated; use PyTorch's
from transformers import get_scheduler
from bert_classification import model
from AFQMC_dataloader import train_dataLoader, val_dataLoader

epoch = 3
num_training_steps = epoch * len(train_dataLoader)
print(num_training_steps)

optimizer = AdamW(model.parameters(), lr=5e-5)

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

loss_fn = nn.CrossEntropyLoss()
total_loss = 0
best_acc = 0
for epoch_idx in range(epoch):
    print(f"epoch {epoch_idx+1}/{epoch}\n=============================")
    total_loss = train_loop(train_dataLoader, model, loss_fn, optimizer, lr_scheduler, epoch_idx+1, total_loss)
    valid_acc = val_loop(val_dataLoader, model, mode="Valid")
    if valid_acc > best_acc:
        best_acc = valid_acc
        print("saving new weights...\n")
        torch.save(model.state_dict(), f"epoch_{epoch_idx+1}_valid_acc_{(100*valid_acc):0.1f}_model_weights.bin")
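
After training, the best checkpoint can be reloaded for inference. A minimal sketch: the checkpoint filename and the example sentence pair below are illustrative, so substitute whatever your run actually saved.

import torch
from bert_classification import model
from AFQMC_dataloader import tokenizer

# Hypothetical filename; use the weights file your run produced.
model.load_state_dict(torch.load("epoch_3_valid_acc_74.0_model_weights.bin"))
model.eval()

X = tokenizer("花呗如何还款", "花呗怎么还钱", return_tensors='pt').to("cuda")
with torch.no_grad():
    pred = model(X).argmax(1).item()
print(pred)  # 1 = same intent, 0 = different intent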

Run status

[screenshots of the training run omitted]

