当前位置: 首页 > article >正文

BERT的中文问答系统42

我们将对现有的代码进行扩展,以支持360百科的功能。这包括修改XihuaChatbotGUI类中的相关方法,以及添加一个新的搜索360百科的函数。此外,我们还需要更新历史记录的保存格式,以包含360百科的结果。

项目结构
code
project_root/

├── data/
│ └── train_data.jsonl

├── logs/
│ └── [log_files]

├── models/
│ └── xihua_model.pth

├── main.py
└── README.md
README.md
markdown

羲和聊天机器人

项目介绍

羲和聊天机器人是一个基于BERT模型的问答系统。它可以从训练数据中学习,并能够回答用户提出的问题。此外,用户可以通过界面评价机器人的回答是否准确,并提供百度百科和360百科的参考答案。

目录结构

project_root/

├── data/
│ └── train_data.jsonl

├── logs/
│ └── [log_files]

├── models/
│ └── xihua_model.pth

├── main.py
└── README.md

code

依赖

  • Python 3.7+
  • PyTorch
  • Transformers
  • Tkinter
  • Requests
  • BeautifulSoup

安装

pip install torch transformers requests beautifulsoup4

运行

python main.py

功能
用户输入问题,机器人给出回答。
用户可以评价回答是否准确。
如果回答不准确,可以选择查看百度百科或360百科的结果。
训练和重新训练模型。
查看和保存历史记录。

main.py

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {
     i + 1}: {
     e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {
     file_path}: {
     e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {
     idx}: {
     e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
   
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def 

http://www.kler.cn/a/414836.html

相关文章:

  • C++ 【异步日志模块和std::cout << 一样使用习惯替代性好】 使用示例,后续加上远程日志
  • 非关系型数据库有哪些特点?
  • 微信小程序蓝牙writeBLECharacteristicValue写入数据返回成功后,实际硬件内信息查询未存储?
  • CTF之密码学(键盘加密)
  • 20241124 Typecho 视频插入插件
  • 【H2O2|全栈】Node.js(1)
  • 基于Springboot的网上商城系统【附源码】
  • P8723 [蓝桥杯 2020 省 AB3] 乘法表
  • 02-Linux系统权限维持
  • 力扣hot100-->排序
  • 23种设计模式-原型(Prototype)设计模式
  • 【自适应和反应式机器人控制】编程练习 1.1:计算最优轨迹 + 编程练习 1.3:基于三阶多项式的闭式时变轨迹发生器
  • Redis - ⭐常用命令
  • BC-Linux8.6设置静态IP
  • Ubuntu FTP服务器的权限设置
  • 设计模式---单例模式
  • 使用R语言绘制简单地图的教程
  • 【知识科普】Restful架构风格
  • 16 go语言(golang) - 并发编程select和workerpool
  • Kafka面试题(三)-- 内含面试重点
  • Navicat 预览变更sql
  • AI潮汐日报1128期:Sora泄露引发争议、百度早期研究对AI领域Scaling Law的贡献、Meta发布系列AI开源项目
  • 【linux014】文件操作命令篇 - head 命令
  • 镜舟科技积极参与北京市开源项目产融对接会,共谋开源新未来
  • HarmonyOS(60)性能优化之状态管理最佳实践
  • 【ArcGIS Pro实操第11期】经纬度数据转化成平面坐标数据