Fine-tuning BERT for a Binary Classification Task
Setup
Toolkit: torch 2.6.0+cu126 + transformers 4.49.0.dev0
Environment: Ubuntu 20.04 + GeForce RTX 3090
Fine-tuning dataset: IDEA-CCNL/AFQMC
BERT model: google-bert/bert-base-chinese
Time to reproduce: about 1 h
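The dataset paths used below point into a local Hugging Face cache snapshot, so both the dataset and the model need to be available locally first. A minimal sketch, assuming they are fetched with huggingface_hub (the exact snapshot directory layout is whatever snapshot_download produces on your machine):

from huggingface_hub import snapshot_download

# Download the AFQMC dataset repo and the Chinese BERT checkpoint into the local HF cache.
# The return values are the local snapshot directories; point _TRAIN_DATASET_PATH /
# _TEST_DATASET_PATH (defined below) into the dataset snapshot.
dataset_dir = snapshot_download(repo_id="IDEA-CCNL/AFQMC", repo_type="dataset")
model_dir = snapshot_download(repo_id="google-bert/bert-base-chinese")
print(dataset_dir)
print(model_dir)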
Dataset class
Here we subclass torch.utils.data.Dataset to load the data; the code is as follows:
from torch.utils.data import Dataset
import json

_TRAIN_DATASET_PATH = "datasets--IDEA-CCNL--AFQMC/snapshots/fd907148d4cfaaadad98cd8d39b967ecf95bd094/afqmc_public/train.json"
_TEST_DATASET_PATH = "datasets--IDEA-CCNL--AFQMC/snapshots/fd907148d4cfaaadad98cd8d39b967ecf95bd094/afqmc_public/dev.json"

class _AFQMC_dataset(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        # Each line of the AFQMC files is a JSON object with
        # 'sentence1', 'sentence2' and 'label' fields.
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                sample = json.loads(line.strip())
                Data[idx] = sample
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_dataset = _AFQMC_dataset(_TRAIN_DATASET_PATH)
test_dataset = _AFQMC_dataset(_TEST_DATASET_PATH)
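As a quick sanity check, continuing in the same file, you can print the dataset sizes and inspect one sample (the exact values depend on the AFQMC release you downloaded):

print(len(train_dataset), len(test_dataset))
print(train_dataset[0])  # a dict with 'sentence1', 'sentence2' and 'label' fields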
Data loader (batched loading and preprocessing)
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch

from AFQMC_dataset import train_dataset, test_dataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")  # tokenizer

def collate_fn(batch_samples):
    # Collect the two sentences and the label of every sample in the batch,
    # then tokenize the sentence pairs jointly so BERT sees
    # "[CLS] sentence1 [SEP] sentence2 [SEP]".
    batch_sentence1 = []
    batch_sentence2 = []
    batch_label = []
    for sample in batch_samples:
        batch_sentence1.append(sample['sentence1'])
        batch_sentence2.append(sample['sentence2'])
        batch_label.append(int(sample['label']))
    X = tokenizer(
        batch_sentence1,
        batch_sentence2,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    y = torch.tensor(batch_label)
    return X, y

train_dataLoader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)
val_dataLoader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)
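To verify the collate function, pull one batch and print the tensor shapes (the sequence length varies with the longest sentence pair in the batch):

batch_X, batch_y = next(iter(train_dataLoader))
print(batch_X['input_ids'].shape)       # torch.Size([8, max_len_in_batch])
print(batch_X['token_type_ids'].shape)
print(batch_X['attention_mask'].shape)
print(batch_y.shape)                    # torch.Size([8])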
Subclassing the BERT class (adding a linear layer for classification)
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel
from transformers import AutoConfig

class BertClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Name the encoder 'bert' to match BertPreTrainedModel.base_model_prefix,
        # so that from_pretrained maps the pretrained encoder weights onto it.
        # The pooling layer is skipped because we use the [CLS] hidden state directly.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(768, 2)  # 768 = bert-base hidden size, 2 classes
        self.post_init()

    def forward(self, x):
        bert_output = self.bert(**x)
        cls_vectors = bert_output.last_hidden_state[:, 0, :]  # [CLS] token
        cls_vectors = self.dropout(cls_vectors)
        logits = self.classifier(cls_vectors)
        return logits

config = AutoConfig.from_pretrained("bert-base-chinese")
model = BertClassification.from_pretrained("bert-base-chinese", config=config).to("cuda")
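Before training, a quick forward pass over one batch confirms the wiring (a small check that assumes the loader and model modules above are importable; from_pretrained will also warn that the new classifier head is randomly initialized, which is expected before fine-tuning):

import torch
from AFQMC_dataloader import train_dataLoader
from bert_classification import model

batch_X, batch_y = next(iter(train_dataLoader))
batch_X = batch_X.to("cuda")
with torch.no_grad():
    logits = model(batch_X)
print(logits.shape)  # expected: torch.Size([8, 2])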
Training and validation
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.optim import AdamW  # the AdamW re-exported by transformers is deprecated
from transformers import get_scheduler

from AFQMC_dataloader import train_dataLoader, val_dataLoader
from bert_classification import model

def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f"loss: {0:>7f}")
    finish_step_num = (epoch - 1) * len(dataloader)  # steps completed in previous epochs

    model.train()
    for step, (X, y) in enumerate(dataloader, start=1):
        X, y = X.to("cuda"), y.to("cuda")
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()  # clear gradients
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        # running average loss over all training steps so far (across epochs)
        progress_bar.set_description(f"loss: {total_loss/(step + finish_step_num):>7f}")
        progress_bar.update(1)
    return total_loss
def val_loop(data_loader, model, mode='Valid'):
    assert mode in ['Test', 'Valid']
    size = len(data_loader.dataset)
    correct = 0

    model.eval()
    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to("cuda"), y.to("cuda")
            pred = model(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    correct /= size
    print(f"{mode}: Accuracy: {(100*correct):>0.1f}%")
    return correct
epoch = 3
num_training_steps = epoch * len(train_dataLoader)
print(num_training_steps)

optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
loss_fn = nn.CrossEntropyLoss()

total_loss = 0
best_acc = 0
for epoch_idx in range(epoch):
    print(f"epoch {epoch_idx+1}/{epoch}\n=============================")
    total_loss = train_loop(train_dataLoader, model, loss_fn, optimizer, lr_scheduler, epoch_idx+1, total_loss)
    valid_acc = val_loop(val_dataLoader, model, mode="Valid")
    if valid_acc > best_acc:
        best_acc = valid_acc
        print("saving new weights...\n")
        torch.save(model.state_dict(), f"epoch_{epoch_idx+1}_valid_acc_{(100*valid_acc):0.1f}_model_weights.bin")
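After training, the best checkpoint can be loaded back to score new sentence pairs. A minimal inference sketch (the checkpoint filename and the example sentences are placeholders; substitute the file actually saved by the loop above):

import torch
from transformers import AutoTokenizer
from bert_classification import model

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

weights_path = "epoch_3_valid_acc_XX.X_model_weights.bin"  # replace with your saved checkpoint
model.load_state_dict(torch.load(weights_path))
model.eval()

inputs = tokenizer("花呗如何还款", "花呗怎么还钱", return_tensors='pt').to("cuda")
with torch.no_grad():
    logits = model(inputs)
print(logits.argmax(dim=-1).item())  # AFQMC labels: 1 = same intent, 0 = different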