当前位置: 首页 > article >正文

BERT的中文问答系统69

使用说明:
启动程序:运行脚本后,会启动一个图形界面。
输入问题:在“问题”输入框中输入您的问题,然后点击“获取回答”按钮,羲和将为您提供答案。
评价回答:如果您认为回答准确,请点击“准确”按钮;如果不准确,请点击“不准确”按钮。
查看历史记录:点击“查看历史记录”按钮可以查看之前的聊天记录。
保存历史记录:点击“保存历史记录”按钮可以将聊天记录保存到文件。
训练模型:点击“训练模型”或“重新训练模型”按钮可以对模型进行训练或重新训练。
评估模型:点击“评估模型”按钮可以评估模型的准确率。
使用说明:点击“使用说明”按钮可以查看详细的使用说明。
数据收集:点击“收集数据”按钮可以手动收集数据,输入问题和答案后,数据将保存到相应的数据文件中。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk, simpledialog
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup
import tkcalendar
import locale
import threading
import configparser

# 设置本地化为中文
locale.setlocale(locale.LC_ALL, 'zh_CN.UTF-8')

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        try:
            if file_path.endswith('.jsonl'):
                with jsonlines.open(file_path) as reader:
                    return [item for item in reader]
            elif file_path.endswith('.json'):
                with open(file_path, 'r') as f:
                    return json.load(f)
        except (json.JSONDecodeError, jsonlines.jsonlines.InvalidLineError) as e:
            logging.warning(f"加载数据失败: {
     e}")
            return []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item.get('question', '')
        human_answer = item.get('human_answers', [''])[0]
        chatgpt_answer = item.get('chatgpt_answers', [''])[0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     idx}: {
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {
     e}")

    return total_loss / len(data_loader)

# 模型评估函数
def evaluate_model(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            human_correct = (torch.sigmoid(human_logits) > 0.5).float() == human_labels
            chatgpt_correct = (torch.sigmoid(chatgpt_logits) > 0.5).float() == chatgpt_labels

            correct += human_correct.sum().item() + chatgpt_correct.sum().item()
            total += human_labels.size(0) + chatgpt_labels.size(0)

    accuracy = correct / total
    return accuracy

# 网络搜索函数
def search_baidu(query):
    try:
        url = f"https://www.baidu.com/s?wd={
     query}"
        headers = {
   

http://www.kler.cn/a/515664.html

相关文章:

  • 回归人文主义,探寻情感本质:从文艺复兴到AI时代,我的情感探索之旅
  • ChatGPT大模型极简应用开发-CH2-深入了解 GPT-4 和 ChatGPT 的 API
  • 当使用 npm 时,出现 `certificate has expired` 错误通常意味着请求的证书已过期。
  • IP协议格式
  • Android SystemUI——通知栏构建流程(十六)
  • 快速入门Flink
  • w-form-select.vue(自定义下拉框组件)(与后端字段直接相关性)
  • flask实现重启后需要重新输入用户名而避免浏览器使用之前已经记录的用户名
  • Objective-C语言的安全开发
  • web3py+flask+ganache的智能合约教育平台
  • TCP全连接队列
  • Lisp语言的物联网
  • Golang:使用DuckDB查询Parquet文件数据
  • Charles 4.6.7 浏览器网络调试指南:介绍与安装(一)
  • 【赵渝强老师】K8s中Pod探针的HTTPGetAction
  • 浅谈VPP与DPDK技术以及产业界应用实例
  • 【AI编程】记录一下windsurf中Write模式和Chat模式的区别以及 AI Rules的配置方法
  • Azure学生订阅上手实操:在Ubuntu VPS上利用Docker快速部署PostgreSQL数据库
  • 考研408笔记之数据结构(四)——树与二叉树
  • C++:利用二维数组打印杨辉三角形。
  • 基于Spring Boot3 + Vue3 + JDK17的现代化的Java应用开发框架
  • MATLAB中insertAfter函数用法
  • 自动化01
  • 【ElementPlus】在Vue3中实现表格组件封装
  • 超越 GPT-4o!从 HTML 到 Markdown,一键整理复杂网页;AI 对话不再冰冷,大模型对话微调数据集让响应更流畅
  • 使用 Aryn DocPrep、DocParse 和 Elasticsearch 向量数据库实现高质量 RAG