BERT的中文问答系统69
使用说明:
启动程序:运行脚本后,会启动一个图形界面。
输入问题:在“问题”输入框中输入您的问题,然后点击“获取回答”按钮,羲和将为您提供答案。
评价回答:如果您认为回答准确,请点击“准确”按钮;如果不准确,请点击“不准确”按钮。
查看历史记录:点击“查看历史记录”按钮可以查看之前的聊天记录。
保存历史记录:点击“保存历史记录”按钮可以将聊天记录保存到文件。
训练模型:点击“训练模型”或“重新训练模型”按钮可以对模型进行训练或重新训练。
评估模型:点击“评估模型”按钮可以评估模型的准确率。
使用说明:点击“使用说明”按钮可以查看详细的使用说明。
数据收集:点击“收集数据”按钮可以手动收集数据,输入问题和答案后,数据将保存到相应的数据文件中。
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk, simpledialog
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup
import tkcalendar
import locale
import threading
import configparser
# Prefer the Simplified Chinese locale, but keep starting up when the host
# does not have it installed (setlocale raises locale.Error in that case).
try:
    locale.setlocale(locale.LC_ALL, 'zh_CN.UTF-8')
except locale.Error as exc:
    # Use a named logger rather than the module-level logging.warning():
    # the latter implicitly calls basicConfig() when the root logger has no
    # handlers, which would turn the application's own later basicConfig()
    # call into a no-op.
    logging.getLogger(__name__).warning(
        "Locale zh_CN.UTF-8 is unavailable (%s); keeping the current locale.",
        exc,
    )
# Directory containing this script; other paths are built relative to it.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
# Directory that receives the log files; created if it does not exist yet.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)
def setup_logging():
    """Configure root logging to a timestamped file in LOGS_DIR and to the console.

    The log file is named ``<YYYY-mm-dd_HH-MM-SS>_羲和.txt``. Records at INFO
    level and above are written both to that file and to a stream handler.
    """
    # Only the ASCII timestamp goes through strftime; the non-ASCII suffix is
    # appended afterwards so the file name does not depend on how the
    # platform's strftime treats non-ASCII format text under the active locale.
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_file = os.path.join(LOGS_DIR, f'{timestamp}_羲和.txt')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Log messages contain Chinese text, so pin the file encoding
            # instead of relying on the platform default.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )


setup_logging()
# Dataset class
class XihuaDataset(Dataset):
    """Tokenised question / human-answer / ChatGPT-answer records from a file.

    Records are read from a ``.jsonl`` or ``.json`` file. Each record is
    accessed as a mapping with the keys ``question``, ``human_answers`` and
    ``chatgpt_answers``; only the first entry of each answer list is used.

    Args:
        file_path: Path of the data file; the extension selects the parser.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            padding='max_length', truncation=True, max_length=...)`` whose
            result provides ``input_ids`` and ``attention_mask`` tensors.
        max_length: Sequence length passed to the tokenizer.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Return the records stored in *file_path*.

        Returns ``[]`` (after logging a warning) when the content cannot be
        decoded as JSON / JSON Lines or when the extension is not supported.
        Other errors, such as a missing file, propagate to the caller.
        """
        try:
            if file_path.endswith('.jsonl'):
                with jsonlines.open(file_path) as reader:
                    return list(reader)
            if file_path.endswith('.json'):
                # JSON text is UTF-8; do not depend on the platform default
                # encoding, which breaks on non-ASCII content on some hosts.
                with open(file_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
        except (json.JSONDecodeError, jsonlines.jsonlines.InvalidLineError) as e:
            logging.warning(f"加载数据失败: {e}")
            return []
        # Unsupported extension: always hand back a list so __len__ works.
        logging.warning(f"不支持的数据文件格式: {file_path}")
        return []

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    @staticmethod
    def _first_answer(item, key):
        """Return the first entry of ``item[key]``, or ``''`` if missing or empty."""
        answers = item.get(key) or ['']
        return answers[0]

    def _encode(self, text):
        """Tokenise *text* to fixed-length PyTorch tensors."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

    def __getitem__(self, idx):
        """Return the tensors and raw answer strings for record *idx*.

        A record whose tokenisation fails is logged and replaced by the next
        record (wrapping around to the start of the data).

        Raises:
            IndexError: If *idx* is out of range.
            RuntimeError: If no record in the dataset can be tokenised.
        """
        total = len(self.data)
        candidate = idx
        # Try each record at most once, so a dataset in which every record
        # fails ends with a clear error instead of unbounded recursion.
        for _ in range(max(total, 1)):
            item = self.data[candidate]
            question = item.get('question', '')
            human_answer = self._first_answer(item, 'human_answers')
            chatgpt_answer = self._first_answer(item, 'chatgpt_answers')
            try:
                inputs = self._encode(question)
                human_inputs = self._encode(human_answer)
                chatgpt_inputs = self._encode(chatgpt_answer)
            except Exception as e:
                # Broad on purpose: the tokenizer is an injected callable and
                # the failure types it may raise are not known here.
                logging.warning(f"跳过无效项 {candidate}: {e}")
                candidate = (candidate + 1) % total
                continue
            # squeeze(0) drops only the leading axis, so a sequence axis of
            # length 1 survives (assumes the tokenizer returns tensors shaped
            # (1, max_length) for a single string — TODO confirm).
            return {
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': inputs['attention_mask'].squeeze(0),
                'human_input_ids': human_inputs['input_ids'].squeeze(0),
                'human_attention_mask': human_inputs['attention_mask'].squeeze(0),
                'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(0),
                'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(0),
                'human_answer': human_answer,
                'chatgpt_answer': chatgpt_answer,
            }
        raise RuntimeError("no record in the dataset could be tokenised")
# Build a data loader
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    """Return a shuffling ``DataLoader`` over a ``XihuaDataset`` of *file_path*.

    Args:
        file_path: Data file handed to ``XihuaDataset``.
        tokenizer: Tokenizer handed to ``XihuaDataset``.
        batch_size: Number of records per batch.
        max_length: Sequence length handed to ``XihuaDataset``.
    """
    return DataLoader(
        XihuaDataset(file_path, tokenizer, max_length),
        batch_size=batch_size,
        shuffle=True,
    )
# Model definition
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a linear head that emits one value per input.

    Args:
        pretrained_model_name: Identifier passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, pretrained_model_name):
        super().__init__()
        encoder = BertModel.from_pretrained(pretrained_model_name)
        self.bert = encoder
        # The head maps the encoder's hidden size to a single output feature.
        self.classifier = torch.nn.Linear(encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return the head's output for the encoder's ``pooler_output``."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)
# Training routine
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    """Run one pass over *data_loader*, updating *model* after every batch.

    For each batch the model scores the human answer and the ChatGPT answer;
    the loss is ``criterion(human, ones) + criterion(chatgpt, zeros)``. A batch
    that raises is logged and skipped so one bad batch does not end the pass.

    Args:
        model: Callable as ``model(input_ids, attention_mask)``.
        data_loader: Sized iterable of batches providing ``human_input_ids``,
            ``human_attention_mask``, ``chatgpt_input_ids`` and
            ``chatgpt_attention_mask`` tensors.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batch tensors and labels are placed on.
        progress_var: Optional object whose ``set`` method receives the
            percentage of batches completed.

    Returns:
        The mean loss over the batches that were processed successfully, or
        ``0.0`` when no batch was processed (including an empty loader).
    """
    model.train()
    total_loss = 0.0
    processed = 0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            # Only the answer tensors are scored; the question tensors in the
            # batch are not used for training, so they are not moved.
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)
            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
            # Human answers are the positive class, ChatGPT answers the negative.
            human_labels = torch.ones(human_logits.size(0), 1, device=device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1, device=device)
            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            processed += 1
            if progress_var is not None:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            # Deliberately broad: training is best-effort per batch.
            logging.warning(f"跳过无效批次: {e}")
    # Average over the batches that actually contributed to total_loss; this
    # also avoids dividing by zero when the loader is empty.
    if processed == 0:
        return 0.0
    return total_loss / processed
# Model evaluation routine
def evaluate_model(model, data_loader, device):
    """Return the fraction of answers in *data_loader* that *model* classifies correctly.

    A human answer counts as correct when ``sigmoid(logit) > 0.5``; a ChatGPT
    answer counts as correct when it is not. Gradients are disabled.

    Args:
        model: Callable as ``model(input_ids, attention_mask)``.
        data_loader: Iterable of batches providing ``human_input_ids``,
            ``human_attention_mask``, ``chatgpt_input_ids`` and
            ``chatgpt_attention_mask`` tensors.
        device: Device the batch tensors are placed on.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            # Only the answer tensors are scored; the question tensors in the
            # batch are not needed here, so they are not moved.
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)
            # A prediction of True means "human"; it is right for human
            # answers and wrong for ChatGPT answers.
            human_pred = torch.sigmoid(human_logits) > 0.5
            chatgpt_pred = torch.sigmoid(chatgpt_logits) > 0.5
            correct += human_pred.sum().item() + (~chatgpt_pred).sum().item()
            total += human_logits.size(0) + chatgpt_logits.size(0)
    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return correct / total
# 网络搜索函数
def search_baidu(query):
try:
url = f"https://www.baidu.com/s?wd={
query}"
headers = {