Python PyTorch Transformers model training script
Environment setup
https://www.doubao.com/thread/w5e26d6401c003bb2
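The script below only needs the torch, transformers, and numpy packages (for example, pip install torch transformers numpy); a GPU is optional, since the script falls back to CPU automatically.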
Run the training script
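The script fine-tunes DistilBertForSequenceClassification as a toy closed-set QA classifier: each of five fixed questions gets its own class label, and at inference time the predicted label is mapped back to a canned answer.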
import torch
from torch.optim import AdamW  # transformers.AdamW was deprecated and later removed; use the torch implementation
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import numpy as np
# Custom dataset class: tokenizes each text to a fixed length and pairs it with its label
class SentimentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            # encode_plus returns tensors of shape (1, max_length); flatten to (max_length,)
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
# QA data: five fixed questions, each treated as its own class
qa_texts = ["李白是哪个朝代的诗人?", "地球的卫星是什么?", "中国的首都是哪里?", "水的化学式是什么?", "苹果公司的创始人是谁?"]
qa_labels = [1, 2, 3, 4, 5]  # assign each question a unique label (ids start at 1, so num_labels must be at least 6)
# Merge the data (here the QA set is the whole training set)
all_texts = qa_texts
all_labels = qa_labels
# Initialize the tokenizer
# Note: 'distilbert-base-uncased' is an English checkpoint, so Chinese text tokenizes
# poorly; a multilingual checkpoint such as 'distilbert-base-multilingual-cased'
# would represent these questions much better.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# Create the dataset and data loader
dataset = SentimentDataset(all_texts, all_labels, tokenizer, max_length=128)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
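# Optional sanity check (illustrative, not part of the original script): each encoded
# sample should be a fixed-length tensor of max_length token ids plus its label.
sample = dataset[0]
print(sample['input_ids'].shape, sample['labels'])  # expected: torch.Size([128]) tensor(1)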
# Initialize the model; num_labels must exceed the largest label id (labels run 1-5, so any value >= 6 works)
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=8)
# Define the optimizer, with weight_decay for regularization
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Train the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Early-stopping parameters
best_loss = float('inf')
patience = 3
counter = 0
for epoch in range(20):  # more epochs than needed, bounded by early stopping
    total_loss = 0
    model.train()
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch + 1}, Loss: {avg_loss}')
    # Early stopping: stop once the average training loss fails to improve for `patience` epochs
    if avg_loss < best_loss:
        best_loss = avg_loss
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered!")
            break
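# Optionally persist the fine-tuned weights for later reuse (the './qa_model'
# path is just an example, not part of the original script):
# model.save_pretrained('./qa_model')
# tokenizer.save_pretrained('./qa_model')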
# Inference helper: tokenize the queries and return the argmax class id for each
def handle_query(model, tokenizer, query_text, device):
    model.eval()
    with torch.no_grad():
        encoding = tokenizer(query_text, return_tensors='pt', padding=True, truncation=True, max_length=128).to(device)
        logits = model(**encoding).logits
        predictions = np.argmax(logits.cpu().numpy(), axis=1)
    return predictions
# Label-to-answer mapping
label_mapping = {
    1: "唐朝",  # Tang Dynasty
    2: "月球",  # the Moon
    3: "北京",  # Beijing
    4: "H2O",
    5: "史蒂夫·乔布斯、史蒂夫·沃兹尼亚克和罗恩·韦恩"  # Steve Jobs, Steve Wozniak, and Ronald Wayne
}
# Query the model with the same five questions used for training
query_text = ["李白是哪个朝代的诗人?", "地球的卫星是什么?", "中国的首都是哪里?", "水的化学式是什么?", "苹果公司的创始人是谁?"]
result = handle_query(model, tokenizer, query_text, device)
# Note: label_mapping only covers labels 1-5; a prediction outside that range would raise KeyError
readable_result = [label_mapping[pred] for pred in result]
print("Query result:", readable_result)