BERT的中文问答系统61
改进和完善后的BERT的中文问答系统60代码，涵盖了错误处理、性能优化、用户体验、功能增强、安全性、可扩展性和模块化，以及文档和注释等方面：
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk, simpledialog
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup
import tkcalendar
import locale
import threading
import configparser
# Set the process-wide locale to Chinese.
# 'zh_CN.UTF-8' is not installed on every machine (it is missing from many
# minimal Linux images, and Windows uses different locale names), and
# locale.setlocale raises locale.Error in that case. Fall back to the
# environment's default locale instead of crashing at import time.
# Logging is not configured yet at this point, so the fallback is silent.
try:
    locale.setlocale(locale.LC_ALL, 'zh_CN.UTF-8')
except locale.Error:
    try:
        locale.setlocale(locale.LC_ALL, '')
    except locale.Error:
        pass  # Keep whatever locale the interpreter started with.

# Absolute directory containing this source file; used as the project root.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Directory for the log files written by setup_logging(); created eagerly so
# the log FileHandler can open a file inside it.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)
def setup_logging():
    """Configure the root logger to log to a new timestamped file and the console.

    The file is created in ``LOGS_DIR`` and named
    ``YYYY-mm-dd_HH-MM-SS_羲和.txt``. Configuration goes through
    ``logging.basicConfig``, which leaves the root logger unchanged if it
    already has handlers.

    Returns:
        str: Path of the log file.
    """
    # Format only the timestamp with strftime and append the non-ASCII suffix
    # separately, so the file name does not depend on how the platform's
    # strftime (and the active locale) handle non-ASCII format characters.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_file = os.path.join(LOGS_DIR, f'{timestamp}_羲和.txt')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Log messages in this project are Chinese; pin UTF-8 so the file
            # does not depend on the platform's default text encoding.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return log_file
# Configure logging once, when this module runs, so that the logging.warning
# calls in the dataset and training code below are written to both the log
# file and the console.
setup_logging()
# Dataset class
class XihuaDataset(Dataset):
    """Dataset pairing a question with one human answer and one ChatGPT answer.

    Each record is expected to be a JSON object with a ``question`` string and
    ``human_answers`` / ``chatgpt_answers`` lists; only the first entry of each
    list is used, and missing or empty fields fall back to ``''``.

    Args:
        file_path: Path to a ``.jsonl`` file (one JSON object per line) or a
            ``.json`` file (a JSON array of objects), decoded as UTF-8.
        tokenizer: Callable with the Hugging Face tokenizer call signature
            returning ``input_ids`` and ``attention_mask`` tensors.
        max_length: Padding/truncation length passed to the tokenizer.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Load the records in *file_path*, logging and skipping invalid entries.

        Returns:
            list[dict]: The successfully parsed records (possibly empty).
        """
        if file_path.endswith('.jsonl'):
            return self._load_jsonl(file_path)
        if file_path.endswith('.json'):
            return self._load_json(file_path)
        logging.warning(f"不支持的文件格式: {file_path}")
        return []

    @staticmethod
    def _load_jsonl(file_path):
        """Parse a JSON Lines file, skipping blank, malformed and non-object lines."""
        records = []
        # Decode each line here so one bad line is skipped instead of aborting
        # the load: a try/except around list.append can never see a decode error.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效行 {line_no}: {e}")
                    continue
                if isinstance(item, dict):
                    records.append(item)
                else:
                    logging.warning(f"跳过无效行 {line_no}: 不是JSON对象")
        return records

    @staticmethod
    def _load_json(file_path):
        """Parse a JSON file whose top level is an array of record objects."""
        with open(file_path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError as e:
                logging.warning(f"跳过无效文件 {file_path}: {e}")
                return []
        if not isinstance(data, list):
            logging.warning(f"跳过无效文件 {file_path}: 顶层不是JSON数组")
            return []
        records = [item for item in data if isinstance(item, dict)]
        skipped = len(data) - len(records)
        if skipped:
            logging.warning(f"跳过 {skipped} 个非JSON对象的项: {file_path}")
        return records

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _first_answer(item, key):
        """Return the first entry of ``item[key]``, or ``''`` if it is missing/empty/None."""
        answers = item.get(key) or ['']
        return answers[0]

    def _tokenize(self, text):
        """Tokenize *text* to fixed-length PyTorch tensors."""
        return self.tokenizer(text, return_tensors='pt', padding='max_length',
                              truncation=True, max_length=self.max_length)

    def _encode(self, item, idx):
        """Build the feature dict for one record, or log and return None on failure."""
        question = item.get('question', '')
        human_answer = self._first_answer(item, 'human_answers')
        chatgpt_answer = self._first_answer(item, 'chatgpt_answers')
        try:
            inputs = self._tokenize(question)
            human_inputs = self._tokenize(human_answer)
            chatgpt_inputs = self._tokenize(chatgpt_answer)
        except Exception as e:
            logging.warning(f"跳过无效项 {idx}: {e}")
            return None
        # squeeze(0) removes only the leading batch dimension of the tokenizer
        # output; a bare squeeze() would also collapse the sequence dimension
        # when max_length == 1.
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'human_input_ids': human_inputs['input_ids'].squeeze(0),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(0),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(0),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(0),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer,
        }

    def __getitem__(self, idx):
        """Return the tokenized features of record *idx*.

        If record *idx* cannot be tokenized, the following records (wrapping
        around) are tried instead, each at most once.

        Raises:
            IndexError: if *idx* is out of range.
            RuntimeError: if no record in the dataset can be tokenized.
        """
        # Index directly first so an out-of-range idx still raises IndexError.
        encoded = self._encode(self.data[idx], idx)
        if encoded is not None:
            return encoded
        # Iterate instead of recursing, so a dataset in which every record
        # fails ends with a clear error rather than RecursionError.
        n = len(self.data)
        for offset in range(1, n):
            fallback_idx = (idx + offset) % n
            encoded = self._encode(self.data[fallback_idx], fallback_idx)
            if encoded is not None:
                return encoded
        raise RuntimeError("数据集中没有可以成功编码的样本")
# Data-loader factory
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128, shuffle=True):
    """Build a DataLoader over an XihuaDataset read from *file_path*.

    Args:
        file_path: ``.jsonl`` or ``.json`` data file passed to XihuaDataset.
        tokenizer: Tokenizer passed to XihuaDataset.
        batch_size: Number of samples per batch.
        max_length: Tokenizer padding/truncation length.
        shuffle: Reshuffle the samples every epoch. Pass ``False`` for a
            deterministic order, e.g. during evaluation.

    Returns:
        torch.utils.data.DataLoader: Loader over the dataset.

    Raises:
        ValueError: if no usable records could be loaded from *file_path*.
    """
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    # Fail early with a clear message; otherwise an empty dataset surfaces
    # later as an opaque error (e.g. from the random sampler when shuffling).
    if len(dataset) == 0:
        raise ValueError(f"数据集为空或无法解析: {file_path}")
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
# Model definition
class XihuaModel(torch.nn.Module):
    """BERT encoder topped with a linear head that emits one score per input.

    Args:
        pretrained_model_name: Name or path handed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, pretrained_model_name):
        super().__init__()
        # The attribute names `bert` and `classifier` determine the state_dict
        # keys, so they must stay stable for saved checkpoints to load.
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Score a batch: BERT's pooled output passed through the linear head.

        No activation is applied, so the result is an unnormalized logit per
        example.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)
# Training function
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    """Run one training epoch: human answers are targeted 1, ChatGPT answers 0.

    Batches that raise are logged and skipped (best effort); they do not count
    toward the returned average.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``; its
            output is compared against targets of the same shape.
        data_loader: Sized iterable of batches providing ``human_input_ids``,
            ``human_attention_mask``, ``chatgpt_input_ids`` and
            ``chatgpt_attention_mask`` tensors.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device the batch tensors are moved to.
        progress_var: Optional object with a ``set(percent)`` method (e.g. a
            Tk variable), updated after every batch.

    Returns:
        float: Mean loss over the successfully trained batches, or ``0.0`` if
        no batch was trained.
    """
    model.train()
    total_loss = 0.0
    trained_batches = 0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            # The question tensors (input_ids / attention_mask) are not used by
            # this loss, so only the answer tensors are moved to the device.
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
            # Targets share the logits' shape, dtype and device, so no CPU
            # allocation + copy is needed.
            loss = (criterion(human_logits, torch.ones_like(human_logits))
                    + criterion(chatgpt_logits, torch.zeros_like(chatgpt_logits)))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            trained_batches += 1
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")
        # Advance progress for skipped batches too, so the bar still reaches 100%.
        if progress_var is not None:
            progress_var.set((batch_idx + 1) / num_batches * 100)
    if trained_batches == 0:
        logging.warning("没有成功训练的批次")
        return 0.0
    # Average only over batches that contributed to total_loss.
    return total_loss / trained_batches
# 模型评估函数
def evaluate_model(model, data_loader, device):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in data_loader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
human_input_ids = batch['human_input_ids'].to(device)
human_attention_mask = batch['human_attention_mask'].to(device)
chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)
human_logits = model(human_input_ids, human_attention_mask)
chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
human_labels = torch.ones(human_logits.size(0), 1).to(device)
chatgpt_labels = torch.zeros(chatgpt_logits.size(<