当前位置: 首页 > article >正文

BERT的中文问答系统23

为了确保日志目录问题不会影响模型训练,我们可以在代码中增加更健壮的日志目录处理逻辑。具体来说,我们可以确保日志目录在代码执行的早期阶段就被创建,并且在遇到任何问题时记录详细的错误信息,但不中断整个程序的执行。

以下是修改后的代码,增加了对日志目录的更健壮处理:

python

import json
import logging
import os
import tkinter as tk
from datetime import datetime
from difflib import SequenceMatcher
from tkinter import filedialog, messagebox, scrolledtext, ttk

import jsonlines
import psutil
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
# Mixed-precision helpers live in torch.cuda.amp ("torch.cuda.mp" does not exist).
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import BertModel, BertTokenizer

# Project root: the directory that contains this source file.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]

# Root directory for log files.
LOGS_DIR = os.path.join(PROJECT_ROOT, "logs")

def ensure_directory_exists(directory):
    """Create *directory* (and any missing parents) if it does not exist yet.

    Failures are logged instead of raised so callers can keep running.

    Args:
        directory: path of the directory to create. An empty string means the
            current directory (e.g. ``os.path.dirname('file.txt')``).

    Returns:
        True if the directory exists afterwards, False if it could not be created.
    """
    if not directory:
        return True
    # Log through a named logger rather than the module-level logging.* helpers:
    # those call logging.basicConfig() implicitly when the root logger has no
    # handlers yet, which would silently turn a later basicConfig() call
    # (with the real file handler) into a no-op.
    logger = logging.getLogger(__name__)
    try:
        os.makedirs(directory, exist_ok=True)
    except OSError as e:
        logger.error(f"创建目录 {directory} 失败: {e}")
        return False
    logger.info(f"确保目录存在: {directory}")
    return True

def setup_logging():
    """Configure root logging to write to the console and to a per-run log file.

    The log file is ``LOGS_DIR/<YYYY-MM-DD>/<HH-MM-SS>/羲和.txt``. If that file
    cannot be opened, logging falls back to console-only and the failure is
    logged instead of aborting the program.
    """
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))
    log_dir = os.path.dirname(log_file)
    ensure_directory_exists(log_dir)

    handlers = [logging.StreamHandler()]
    file_error = None
    try:
        # UTF-8 so Chinese messages are written correctly regardless of the OS locale.
        handlers.insert(0, logging.FileHandler(log_file, encoding='utf-8'))
    except OSError as e:
        file_error = e

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
        # A module-level logging.* call made before this point (e.g. while creating
        # the log directory) implicitly installs a default root handler; without
        # force=True this configuration would then be silently ignored.
        force=True,
    )
    if file_error is not None:
        logging.error(f"无法创建日志文件 {log_file}: {file_error}")


setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    """Question/answer records for the human-vs-ChatGPT answer classifier.

    Each record is expected to provide ``question``, ``human_answers`` and
    ``chatgpt_answers``; only the first entry of each answer list is used.
    Malformed records are skipped rather than aborting training.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        """
        Args:
            file_path: ``.jsonl`` file (one JSON object per line) or ``.json``
                file (a top-level list of objects).
            tokenizer: callable invoked as ``tokenizer(text, return_tensors='pt', ...)``
                (a Hugging Face ``BertTokenizer`` in this file).
            max_length: every text is padded/truncated to this many tokens.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Load records from *file_path*, skipping malformed content.

        Returns:
            A list of records. Unsupported extensions, unparsable ``.json`` files
            and ``.json`` files whose top level is not a list yield ``[]``.
        """
        data = []
        if file_path.endswith('.jsonl'):
            # Parse line by line: a bad line raises while *reading*, so it must be
            # caught around the parse, not around the append. utf-8-sig tolerates a BOM.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for i, line in enumerate(f):
                    if not line.strip():
                        continue
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
                    data = []
            if not isinstance(data, list):
                logging.warning(f"JSON 文件顶层不是列表,已忽略: {file_path}")
                data = []
        else:
            logging.warning(f"不支持的文件类型,未加载数据: {file_path}")
        return data

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize *text* into a batch of one, padded/truncated to ``max_length``."""
        return self.tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)

    def __getitem__(self, idx):
        """Tokenize record *idx*; if it is malformed, fall forward to the next usable one.

        Raises:
            IndexError: if *idx* is out of range (keeps the sequence protocol intact).
            RuntimeError: if no record in the dataset can be tokenized.
        """
        n = len(self.data)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        # Bounded scan instead of recursion: one pass over the data at most.
        for offset in range(n):
            pos = (idx + offset) % n
            item = self.data[pos]
            try:
                question = item['question']
                human_answer = item['human_answers'][0]
                chatgpt_answer = item['chatgpt_answers'][0]
                inputs = self._encode(question)
                human_inputs = self._encode(human_answer)
                chatgpt_inputs = self._encode(chatgpt_answer)
            except Exception as e:
                logging.warning(f"跳过无效项 {pos}: {e}")
                continue

            # squeeze(0) drops only the batch dimension, even when max_length == 1.
            return {
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': inputs['attention_mask'].squeeze(0),
                'human_input_ids': human_inputs['input_ids'].squeeze(0),
                'human_attention_mask': human_inputs['attention_mask'].squeeze(0),
                'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(0),
                'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(0),
                'human_answer': human_answer,
                'chatgpt_answer': chatgpt_answer
            }

        raise RuntimeError("数据集中没有可用的样本")

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128, distributed=False, num_workers=4):
    """Build a DataLoader over an XihuaDataset read from *file_path*.

    With ``distributed=True`` the data is sharded across ranks by a
    DistributedSampler (which shuffles by default); otherwise the loader
    shuffles the whole dataset itself.
    """
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}
    if not distributed:
        return DataLoader(dataset, shuffle=True, **loader_kwargs)
    return DataLoader(dataset, sampler=DistributedSampler(dataset), **loader_kwargs)

# 模型定义
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a one-logit linear head on BERT's ``pooler_output``."""

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        # The attribute names ``bert`` and ``classifier`` determine the state_dict
        # keys of saved checkpoints, so they must not change.
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.bert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits, one per input sequence."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)

# 训练函数
def train(model, data_loader, optimizer, criterion, device, scaler=None, gradient_accumulation_steps=1):
    """Run one training epoch and return the mean per-batch loss.

    For every batch, human answers are labelled 1 and ChatGPT answers 0; the
    two ``criterion`` losses are summed.

    Args:
        model: callable as ``model(input_ids, attention_mask)`` returning logits
            of shape (batch, 1).
        data_loader: iterable of dict batches holding ``human_input_ids``,
            ``human_attention_mask``, ``chatgpt_input_ids`` and
            ``chatgpt_attention_mask`` tensors.
        optimizer: optimizer over the model's parameters.
        criterion: loss called as ``criterion(logits, targets)``.
        device: device (or device string) batches are moved to; autocast is
            enabled only for CUDA devices.
        scaler: optional ``GradScaler``; when None, plain ``backward()`` and
            ``optimizer.step()`` are used.
        gradient_accumulation_steps: batches accumulated per optimizer step
            (values below 1 are treated as 1).

    Returns:
        Mean unscaled loss over the batches that trained successfully, or NaN if
        none did (so a fully failed epoch can never look like an improvement).
    """
    accum_steps = max(1, int(gradient_accumulation_steps))
    device = torch.device(device)
    use_amp = device.type == 'cuda'

    def optimizer_step():
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()

    model.train()
    optimizer.zero_grad()
    total_loss = 0.0
    num_batches = 0
    pending = 0  # successful batches whose gradients have not been applied yet
    for batch in data_loader:
        try:
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            with torch.autocast(device_type=device.type, enabled=use_amp):  # mixed precision on CUDA only
                human_logits = model(human_input_ids, human_attention_mask)
                chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

                human_labels = torch.ones(human_logits.size(0), 1, device=device)
                chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1, device=device)

                loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)

            # Scale only the gradient contribution; the reported loss stays unscaled.
            step_loss = loss / accum_steps
            if scaler is not None:
                scaler.scale(step_loss).backward()
            else:
                step_loss.backward()
            pending += 1
            total_loss += loss.item()
            num_batches += 1

            if pending == accum_steps:
                optimizer_step()
                pending = 0
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")
            # Drop partially accumulated gradients so a failed batch cannot leak into the next step.
            optimizer.zero_grad()
            pending = 0

    # Apply gradients from a trailing, incomplete accumulation window instead of discarding them.
    if pending:
        optimizer_step()

    return total_loss / num_batches if num_batches else float('nan')

# 主训练函数
def main_train(rank, world_size, retrain=False, multi_gpu=False):
    """Train XihuaModel on ``data/train_data.jsonl``, optionally as one DDP rank.

    Args:
        rank: index of this process; also its CUDA device index when CUDA is available.
        world_size: total number of training processes.
        retrain: start from the checkpoint at ``models/xihua_model.pth``.
        multi_gpu: join a process group, wrap the model in DistributedDataParallel
            and shard the data with a DistributedSampler.

    Rank 0 saves the best (lowest average loss) weights without the DDP
    ``module.`` prefix, so they load directly into a bare XihuaModel.
    """
    backend = None
    if multi_gpu:
        # env:// rendezvous needs these variables; mp.spawn does not set them.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29500')
        # NCCL needs CUDA and is not built on every platform; Gloo works everywhere.
        backend = 'nccl' if torch.cuda.is_available() and dist.is_nccl_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://', world_size=world_size, rank=rank)
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)

    writer = None
    try:
        device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
        logging.info(f'Using device: {device}')

        tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

        best_model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if retrain:
            # Load into the bare model (before DDP wrapping) so the keys line up;
            # also accept older checkpoints saved with the DDP "module." prefix.
            state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
            state_dict = {
                (k[len('module.'):] if k.startswith('module.') else k): v
                for k, v in state_dict.items()
            }
            model.load_state_dict(state_dict)

        # XihuaModel (and DDP) have no gradient_checkpointing_enable(); the BERT submodule does.
        model.bert.gradient_checkpointing_enable()

        # Probe the batch size on the bare model, then release the probe's cached memory.
        max_memory = torch.cuda.get_device_properties(device).total_memory * 0.9 if torch.cuda.is_available() else float('inf')
        batch_size = max(1, get_max_batch_size(model, device, max_memory))
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info(f'Using batch size: {batch_size}')

        bare_model = model
        if multi_gpu:
            model = DDP(model, device_ids=[rank] if device.type == 'cuda' else None)

        optimizer = optim.Adam(model.parameters(), lr=1e-5)
        criterion = torch.nn.BCEWithLogitsLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
        scaler = torch.amp.GradScaler('cuda', enabled=device.type == 'cuda')

        train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=batch_size, max_length=128, distributed=multi_gpu, num_workers=4)

        num_epochs = 3
        gradient_accumulation_steps = 2  # batches accumulated per optimizer step
        best_loss = float('inf')
        ensure_directory_exists(os.path.dirname(best_model_path))  # make sure the model directory exists

        if rank == 0:
            writer = SummaryWriter(log_dir=os.path.join(PROJECT_ROOT, 'logs/tensorboard'))

        sampler = train_data_loader.sampler
        for epoch in range(num_epochs):
            if isinstance(sampler, DistributedSampler):
                # Without this every epoch sees the same shuffle order.
                sampler.set_epoch(epoch)
            train_loss = train(model, train_data_loader, optimizer, criterion, device, scaler, gradient_accumulation_steps)
            if multi_gpu:
                # Average across ranks so every scheduler sees the same value.
                loss_tensor = torch.tensor([train_loss], device=device if backend == 'nccl' else 'cpu')
                dist.all_reduce(loss_tensor)
                train_loss = loss_tensor.item() / world_size
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')
            if writer is not None:
                writer.add_scalar('Training Loss', train_loss, epoch)
            scheduler.step(train_loss)

            if rank == 0 and train_loss < best_loss:
                best_loss = train_loss
                torch.save(bare_model.state_dict(), best_model_path)
                logging.info(f"模型在 Epoch {epoch+1} 更新,Loss: {train_loss:.8f}")

        if rank == 0:
            logging.info("模型训练完成并保存")
    finally:
        if writer is not None:
            writer.close()
        if multi_gpu and dist.is_initialized():
            dist.destroy_process_group()

# 动态调整批大小
def get_max_batch_size(model, device, max_memory=1024 * 1024 * 1024, seq_len=128, max_batch_size=1024):
    """Estimate the largest power-of-two batch size *model* can handle on *device*.

    Probes forward passes under ``torch.no_grad()`` with random token ids of
    length *seq_len*, doubling the batch size until one of these happens: a
    forward pass raises RuntimeError (e.g. CUDA out-of-memory), the peak CUDA
    allocation exceeds *max_memory* bytes, or *max_batch_size* would be
    exceeded. The cap also stops the search on CPU, where *max_memory* is not
    measured and doubling would otherwise run until system RAM is exhausted.

    Note: the probe runs without gradients, so training needs more memory than measured.

    Args:
        model: callable as ``model(input_ids, attention_mask)``.
        device: device (or device string) to probe on.
        max_memory: peak-allocation budget in bytes (only checked on CUDA). Default 1 GiB.
        seq_len: token length of the probe inputs.
        max_batch_size: upper bound on the returned batch size.

    Returns:
        The largest successful batch size, and never less than 1.
    """
    device = torch.device(device)
    on_cuda = device.type == 'cuda'
    best = 0
    batch_size = 1
    while batch_size <= max_batch_size:
        input_ids = attention_mask = None
        try:
            if on_cuda:
                torch.cuda.reset_peak_memory_stats(device)
            input_ids = torch.randint(0, 100, (batch_size, seq_len), device=device)
            attention_mask = torch.ones(batch_size, seq_len, device=device)
            with torch.no_grad():
                model(input_ids, attention_mask)
        except RuntimeError as e:  # CUDA OOM errors subclass RuntimeError
            logging.info(f"Batch-size probe stopped at {batch_size}: {e}")
            break
        finally:
            # Free the probe tensors (and cached blocks) before the next, larger attempt.
            input_ids = attention_mask = None
            if on_cuda:
                torch.cuda.empty_cache()
        if on_cuda and torch.cuda.max_memory_allocated(device) > max_memory:
            break
        best = batch_size
        batch_size *= 2
    return max(1, best)

# 启动多GPU训练
def launch_training(retrain=False, multi_gpu=False):
    """Run main_train, spawning one process per GPU when multi-GPU is requested and available.

    With a single process the call is made directly as rank 0 of a world of size 1.
    """
    spawn_per_gpu = multi_gpu and torch.cuda.device_count() > 1
    if not spawn_per_gpu:
        main_train(0, 1, retrain, multi_gpu)
        return
    gpu_count = torch.cuda.device_count()
    mp.spawn(main_train, args=(gpu_count, retrain, multi_gpu), nprocs=gpu_count, join=True)

# GUI界面
class XihuaChatbotGUI:
    """Tkinter front end that answers questions with XihuaModel and can (re)train it.

    The sign of the model's logit for the question selects which stored answer
    list (``human_answers`` vs ``chatgpt_answers``) is shown for the most
    similar question in the training data.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # The training data doubles as the answer bank for get_specific_answer().
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        self.create_widgets()

    def create_widgets(self):
        """Build and pack all widgets of the main window."""
        self.question_label = tk.Label(self.root, text="问题:")
        self.question_label.pack()

        self.question_entry = tk.Entry(self.root, width=50)
        self.question_entry.pack()

        self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)
        self.answer_button.pack()

        self.answer_label = tk.Label(self.root, text="回答:")
        self.answer_label.pack()

        self.answer_text = scrolledtext.ScrolledText(self.root, height=10, width=50)
        self.answer_text.pack()

        self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)
        self.train_button.pack()

        self.retrain_button = tk.Button(self.root, text="重新训练模型", command=lambda: self.train_model(retrain=True))
        self.retrain_button.pack()

        self.multi_gpu_var = tk.BooleanVar()
        self.multi_gpu_checkbox = tk.Checkbutton(self.root, text="使用多GPU", variable=self.multi_gpu_var)
        self.multi_gpu_checkbox.pack()

        self.log_text = scrolledtext.ScrolledText(self.root, height=10, width=50)
        self.log_text.pack()

        self.progress_bar = ttk.Progressbar(self.root, orient='horizontal', length=300, mode='determinate')
        self.progress_bar.pack()

    def get_answer(self):
        """Classify the entered question and display the matching stored answer."""
        question = self.question_entry.get()
        if not question.strip():
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)

        # Positive logit -> the model leans towards the "human" class.
        answer_type = "羲和回答" if logits.item() > 0 else "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

    def get_specific_answer(self, question, answer_type):
        """Return an answer from the stored question most similar to *question*.

        ``"羲和回答"`` selects from ``human_answers``; any other type from
        ``chatgpt_answers``. Records without a string ``question`` or without a
        non-empty answer list of the requested kind are skipped. Similarity is
        difflib's ``SequenceMatcher.ratio()``; a best ratio of 0 counts as no match.
        """
        key = 'human_answers' if answer_type == "羲和回答" else 'chatgpt_answers'
        best_answer = None
        best_ratio = 0.0
        for item in self.data:
            if not isinstance(item, dict):
                continue
            candidate = item.get('question')
            answers = item.get(key)
            if not isinstance(candidate, str) or not isinstance(answers, list) or not answers:
                continue
            ratio = SequenceMatcher(None, question, candidate).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_answer = answers[0]

        if best_answer is not None:
            return best_answer
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        """Load answer-bank records, skipping malformed content instead of failing.

        Returns ``[]`` (with a warning) for a missing file, an unsupported
        extension, an unparsable ``.json`` file, or a ``.json`` file whose top
        level is not a list.
        """
        if not os.path.exists(file_path):
            logging.warning(f"数据文件不存在: {file_path}")
            return []

        data = []
        if file_path.endswith('.jsonl'):
            # Parse per line so one bad line is skipped; utf-8-sig tolerates a BOM.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for i, line in enumerate(f):
                    if not line.strip():
                        continue
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
                    data = []
            if not isinstance(data, list):
                logging.warning(f"JSON 文件顶层不是列表,已忽略: {file_path}")
                data = []
        else:
            logging.warning(f"不支持的文件类型,未加载数据: {file_path}")
        return data

    def _load_checkpoint(self, model_path):
        """Load weights from *model_path* into ``self.model``.

        Checkpoints saved from a DistributedDataParallel wrapper carry a
        ``module.`` key prefix; it is stripped so both layouts load.
        """
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        state_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }
        self.model.load_state_dict(state_dict)

    def load_model(self):
        """Load the saved checkpoint if present; otherwise keep the pretrained weights."""
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        ensure_directory_exists(os.path.dirname(model_path))  # make sure the model directory exists
        if os.path.exists(model_path):
            self._load_checkpoint(model_path)
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        """Train on a user-selected data file for 3 epochs, saving the best weights.

        Training runs in this (GUI) process, so the multi-GPU checkbox is not
        honoured here. Whatever happens, the model is returned to eval mode so
        later answers are deterministic.
        """
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        if self.multi_gpu_var.get():
            # A DistributedSampler needs an initialised process group, which this
            # in-process path never creates; use launch_training(multi_gpu=True) instead.
            logging.warning("GUI 内训练以单进程运行,已忽略多GPU选项 (多GPU请使用 launch_training(multi_gpu=True))")

        best_model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        num_epochs = 3
        writer = None
        try:
            # Load the previously trained weights.
            if retrain:
                self._load_checkpoint(best_model_path)
            self.model.to(self.device)

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
            scaler = torch.amp.GradScaler('cuda', enabled=self.device.type == 'cuda')

            # Enable gradient checkpointing to save memory.
            self.model.bert.gradient_checkpointing_enable()

            max_memory = torch.cuda.get_device_properties(self.device).total_memory * 0.9 if torch.cuda.is_available() else float('inf')
            batch_size = max(1, get_max_batch_size(self.model, self.device, max_memory))
            logging.info(f'Using batch size: {batch_size}')
            # Load the dataset once, with the final batch size.
            data_loader = get_data_loader(file_path, self.tokenizer, batch_size=batch_size, max_length=128, distributed=False, num_workers=4)

            writer = SummaryWriter(log_dir=os.path.join(PROJECT_ROOT, 'logs/tensorboard'))
            best_loss = float('inf')
            ensure_directory_exists(os.path.dirname(best_model_path))  # make sure the model directory exists

            self.progress_bar['value'] = 0
            self.root.update_idletasks()
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, scaler, 2)
                message = f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}'
                logging.info(message)
                self.log_text.insert(tk.END, message + '\n')
                self.log_text.yview(tk.END)
                writer.add_scalar('Training Loss', train_loss, epoch)
                scheduler.step(train_loss)

                if train_loss < best_loss:
                    best_loss = train_loss
                    torch.save(self.model.state_dict(), best_model_path)
                    logging.info(f"模型在 Epoch {epoch+1} 更新,Loss: {train_loss:.4f}")

                # Report progress after the epoch has actually finished.
                self.progress_bar['value'] = (epoch + 1) / num_epochs * 100
                self.root.update_idletasks()

            logging.info("模型训练完成并保存")
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            messagebox.showerror("训练失败", f"模型训练失败: {e}")
        finally:
            if writer is not None:
                writer.close()
            # train() switches the model to train mode; restore eval for inference.
            self.model.eval()

# 主函数
if __name__ == "__main__":
    # Build the chatbot window and hand control to Tk's event loop.
    main_window = tk.Tk()
    app = XihuaChatbotGUI(main_window)
    main_window.mainloop()

关键修改点
确保目录存在:新增 ensure_directory_exists 辅助函数,使用 os.makedirs(..., exist_ok=True) 创建目录(目录已存在时不报错);创建失败时只记录错误日志,而不抛出异常。日志目录和模型目录都在使用前通过它创建。
日志配置:在 setup_logging 函数中,调用 ensure_directory_exists 确保日志目录存在。
模型加载:在 load_model 和 train_model 函数中,调用 ensure_directory_exists 确保模型目录存在。
通过这些修改,目录创建失败本身不会中断程序:ensure_directory_exists 会记录详细的错误信息并继续执行。但要注意,之后真正依赖该目录的操作(例如用 torch.save 保存模型)在目录不可用时仍会抛出异常,需要调用方自行处理。


http://www.kler.cn/a/372794.html

相关文章:

  • GraphRAG: Auto Prompt Tuning 实践
  • k8s集群安装
  • 51c嵌入式~单片机~合集6
  • 1.6 从 GPT-1 到 GPT-3.5:一路的风云变幻
  • ipmitool设置带外账号权限
  • Typora + PowerShell 在终端打开文件
  • 使用Virtual Audio Cable捕获系统音频输出并使用Python处理
  • pc电脑屏幕分辨率尺寸
  • MetaArena推出《Final Glory》:引领Web3游戏技术新风向
  • Flutter鸿蒙next 封装对话框详解
  • 实现链式结构的二叉树
  • C++ 之 VS2010 和MySQL数据库的链接问题
  • 【K8S系列】Kubernetes 中 Service 无法访问及解决方案【已解决】
  • 单片机如何入门?
  • Android adb命令获取设备id
  • CentOS 9 Stream 上安装 Visual Studio Code
  • Go语言八股(Ⅲ)
  • C#与C++交互开发系列(十四):C++中STL容器与C#集合传递的形式
  • lvm故障小贴士
  • 响应报文时间
  • CPB数据集:由斯坦福大学发布,一个新的视频问题回答任务基准,能够连续且全面处理视频数据
  • Golang | Leetcode Golang题解之第521题最长特殊序列I
  • C#与C++交互开发系列(十七):线程安全
  • 查询windows或者linux上 支持的所有字体
  • 100种算法【Python版】第9篇——二分法
  • 香港海洋投资引领海洋牧场新一轮融资热潮