当前位置: 首页 > article >正文

《AI大模型专家之路》No.2:用三个模型洞察大模型NLP的基础能力

用三个模型洞察大模型NLP的基础能力

一、项目概述

在这个基于AI构建AI的思维探索项目中,我们实现了一个基于BERT的中文AI助手系统。该系统集成了文本分类、命名实体识别和知识库管理等功能,深入了解本项目可以让读者充分了解AI大模型训练和推理的基本原理,该项目使用了三个基础大模型:bert-base-chinese,ckiplab/bert-base-chinese-ner``,spacy.lang.zh.Chinese 旨在展示如何构建一个基础的AI应用系统。

在这里插入图片描述

1.1 技术栈

  • 后端:Python + Flask
  • 前端:HTML + JavaScript
  • 模型:BERT中文预训练模型
  • 框架:Transformers + PyTorch

1.2 核心功能

  • 智能问答
  • 知识管理
  • 模型训练
  • 实体识别

1.3 功能分析

这个应用程序试图实现一个简单的AI助手系统,包含以下功能:

  1. 知识管理

    • 通过 KnowledgeBase 类管理知识库
    • 支持添加和检索知识
    • 知识以JSON格式持久化存储
  2. 文本分类

    • 使用 bert-base-chinese 模型
    • 只能进行简单的二分类(0或1)
    • 主要用于判断问题是否与AI相关
  3. 命名实体识别

    • 使用 ckiplab/bert-base-chinese-ner 模型
    • 识别文本中的人名、地名等实体
    • 结果并未充分利用
  4. Web界面

    • 提供聊天界面
    • 支持知识添加
    • 支持模型训练

1.4 对大模型技术的帮助

这个项目对学习大模型技术有以下帮助:

  1. 基础概念学习

    • 了解如何加载和使用预训练模型
    • 理解模型训练的基本流程
    • 掌握文本分类和NER的基本概念
  2. 实践价值

    • 展示了如何构建一个完整的AI应用
    • 演示了模型训练和预测的基本流程
    • 展示了如何处理中文文本
  3. 改进方向

    • 可以学习如何优化模型训练
    • 了解如何更好地整合多个模型
    • 学习如何设计更好的系统架构

二、系统架构

基于Python3.11.5虚拟环境,完整源代码见3.4。

2.1 核心模块

class KnowledgeBase:
    """知识库管理模块"""
    def __init__(self):
        self.knowledge = {}
        self.load_knowledge()

class LearningModule:
    """模型学习模块"""
    def __init__(self):
        self.model_name = "bert-base-chinese"
        self.model_path = "./trained_model"

class ThinkingFramework:
    """思维框架模块"""
    def __init__(self):
        self.cognitive_model = Chinese()
        self.ner_tagger = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")

2.2 模块职责

  1. 知识库模块:管理系统的知识存储和检索
  2. 学习模块:处理模型训练和预测
  3. 思维框架:处理文本分析和实体识别
  4. 用户界面:提供Web交互界面

三、关键技术实现

3.1 模型加载与训练

def load_or_train_model(self):
    """加载或训练模型"""
    try:
        if os.path.exists(self.model_path):
            self.model = AutoModelForTokenClassification.from_pretrained(self.model_path)
        else:
            self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
        self.train_initial_data()
    except Exception as e:
        print(f"加载模型时出错: {str(e)}")
        raise

3.2 知识检索实现

def retrieve(self, query):
    """检索与查询相关的知识"""
    results = []
    query_lower = query.lower()
    
    for k, v in self.knowledge.items():
        k_lower = k.lower()
        if query_lower == k_lower:
            results.append((k, v, 1.0))  # 完全匹配
        elif query_lower in k_lower or k_lower in query_lower:
            overlap = len(set(query_lower) & set(k_lower)) / max(len(query_lower), len(k_lower))
            results.append((k, v, overlap))  # 部分匹配

3.3 实体识别实现

def get_ner(self, tokens):
    """获取命名实体识别结果"""
    try:
        tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
        inputs = tokenizer(text, return_tensors="pt")
        outputs = self.ner_tagger(**inputs)
        
        predictions = outputs.logits.argmax(dim=2).tolist()[0]
        tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        
        # 提取实体
        entities = []
        current_entity = []
        current_type = None
        
        for token, pred in zip(tokens, predictions):
            if pred > 0:  # 如果不是O标签
                if not current_entity:
                    current_entity = [token]
                    current_type = pred
                else:
                    current_entity.append(token)

3.4 完整源代码

系统运行截图:
在这里插入图片描述

代码文件1:main.py

#运行成功
import os
import json
from flask import Flask, request, jsonify, render_template
from transformers import AutoTokenizer, AutoModelForTokenClassification
from spacy.lang.zh import Chinese
import numpy as np
import torch
import time
import torch.nn.functional as F
# 设置环境变量,使用镜像源
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

class KnowledgeBase:
    def __init__(self):
        self.knowledge = {}
        self.load_knowledge()  # 启动时加载知识
    
    def add(self, key, value):
        self.knowledge[key] = value
    
    def retrieve(self, query):
        #return [k for k, v in self.knowledge.items() if query.lower() in k.lower()]

        """检索与查询相关的知识
        Args:
        query: 用户查询字符串            
        Returns:
            包含(key, value)元组的列表,按相关性排序
        """
        results = []
        query_lower = query.lower()

        # 首先检查完全匹配
        for k, v in self.knowledge.items():
            k_lower = k.lower()
            if query_lower == k_lower:
                results.append((k, v, 1.0))  # 完全匹配,相关性为1.0
            elif query_lower in k_lower or k_lower in query_lower:
                # 部分匹配,计算相关性分数
                overlap = len(set(query_lower) & set(k_lower)) / max(len(query_lower), len(k_lower))
                results.append((k, v, overlap))
            else:
                # 检查内容中是否包含查询词
                if query_lower in v.lower():
                    results.append((k, v, 0.5))  # 内容匹配,相关性为0.5

        # 按相关性排序
        results.sort(key=lambda x: x[2], reverse=True)

        # 只返回相关性大于0.3的结果
        filtered_results = [(k, v) for k, v, score in results if score > 0.3]

        return filtered_results

    def save_knowledge(self):
        with open("knowledge_base.json", "w", encoding="utf-8") as f:
            json.dump(self.knowledge, f, ensure_ascii=False, indent=4)
    
    def load_knowledge(self):
        try:
            with open("knowledge_base.json", "r", encoding="utf-8") as f:
                self.knowledge = json.load(f)
        except FileNotFoundError:
            self.knowledge = {}



class LearningModule:
    def __init__(self):
        self.model_name = "bert-base-chinese"
        self.model_path = "./trained_model"  # 修改为本地路径
        self.optimizer = None  # 添加优化器初始化
        self.initial_training_data = [  # 添加初始训练数据
            ("我是一个AI助手", 1),
            ("Python是编程语言", 0),
            ("机器学习是人工智能的分支", 1),
            ("深度学习使用神经网络", 1),
            ("自然语言处理很重要", 1),
            ("计算机视觉处理图像", 1),
            ("强化学习用于决策", 1),
            ("数据科学使用Python", 0),
            ("Web开发需要HTML", 0),
            ("数据库存储数据", 0)
        ]
        self.training_data = self.initial_training_data.copy()  # 添加训练数据存储
        self.load_or_train_model()
    
    def save_model(self):
        """保存模型到本地目录"""
        try:
            # 确保目录存在
            os.makedirs(self.model_path, exist_ok=True)
            # 保存模型和分词器
            self.model.save_pretrained(self.model_path)
            self.tokenizer.save_pretrained(self.model_path)
            print(f"模型已保存到: {self.model_path}")
        except Exception as e:
            print(f"保存模型时出错: {str(e)}")
    
    def load_or_train_model(self):
        """加载或训练模型"""
        try:
            # 首先尝试从本地加载模型
            if os.path.exists(self.model_path):
                print(f"从本地加载模型: {self.model_path}")
                self.model = AutoModelForTokenClassification.from_pretrained(self.model_path)
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            else:
                # 如果本地没有模型,从预训练模型加载
                print(f"从预训练模型加载: {self.model_name}")
                self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            
            # 初始化优化器
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
            
            # 训练初始数据
            self.train_initial_data()
            
        except Exception as e:
            print(f"加载模型时出错: {str(e)}")
            raise

    def train_initial_data(self):
        """使用初始数据进行训练"""
        sentences, labels = zip(*self.initial_training_data)
        self.train(list(sentences), list(labels))

    def train(self, sentences, labels):
        """训练模型"""
        try:
            # 对所有句子一起进行tokenize,确保填充到相同长度
            inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]
            labels_ids = torch.tensor(labels).long()
            
            # 设置模型为训练模式
            self.model.train()
            
            # 前向传播
            outputs = self.model(input_ids, attention_mask=attention_mask)
            
            # 调整输出和标签的形状以匹配交叉熵损失函数的要求
            # 假设我们只关心[CLS]标记的输出(序列的第一个标记)
            logits = outputs.logits[:, 0, :]  # 取每个序列的第一个标记的输出
            loss = F.cross_entropy(logits, labels_ids)
            
            # 反向传播和优化
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            print(f"训练完成,损失: {loss.item():.4f}")
            return loss.item()
        except Exception as e:
            print(f"训练过程中出错: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def predict(self, sentence):
        """对输入句子进行预测"""
        try:
            # 对输入进行tokenize
            tokenized = self.tokenizer(sentence, return_tensors="pt", padding=True)
            input_ids = tokenized["input_ids"]
            attention_mask = tokenized["attention_mask"]
            
            # 获取模型输出
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)
            
            # 处理输出
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=2).tolist()[0]
            
            # 这里可以添加更多处理逻辑,例如将预测结果映射到标签
            # 简单返回预测结果
            return f"预测类别: {predictions[0]}"
        except Exception as e:
            print(f"预测错误: {str(e)}")
            return "无法进行预测"

    def train_model(self, sentence, label):
        # 保存训练数据
        self.training_data.append((sentence, label))
        # 训练模型
        loss = self.train([sentence], [label])
        # 保存模型
        self.save_model()
        return loss

class ThinkingFramework:
    def __init__(self):
        self.cognitive_model = Chinese()
        try:
            print("正在加载NER模型...")
            self.ner_tagger = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
            print("NER模型加载成功!")
        except Exception as e:
            print(f"NER模型加载失败: {str(e)}")
            print("请确保网络连接正常,并已安装所有必要的依赖。")
            raise
    
    def parse(self, text):
        return self.cognitive_model(text)
    
    def get_ner(self, tokens):
        """获取命名实体识别结果"""
        try:
            if isinstance(tokens, str):
                tokens = self.parse(tokens)
            
            tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
            text = tokens.text if hasattr(tokens, 'text') else str(tokens)
            
            inputs = tokenizer(text, return_tensors="pt")
            outputs = self.ner_tagger(**inputs)
            
            # 获取预测结果和标签
            predictions = outputs.logits.argmax(dim=2).tolist()[0]
            tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
            
            # 提取实体
            entities = []
            current_entity = []
            current_type = None
            
            for token, pred in zip(tokens, predictions):
                if pred > 0:  # 如果不是O标签
                    if not current_entity:
                        current_entity = [token]
                        current_type = pred
                    else:
                        current_entity.append(token)
                elif current_entity:
                    entities.append((''.join(current_entity), current_type))
                    current_entity = []
                    current_type = None
            
            return entities
            
        except Exception as e:
            print(f"NER处理错误: {str(e)}")
            return []


class SelfConsciousness:
    def __init__(self):
        self.knowledge = {}
        self.last_update = time.time()
    
    def update(self, new_knowledge):
        if new_knowledge:
            self.knowledge.update({k: time.time() for k, _ in new_knowledge})
            self.last_update = time.time()
            print(f"自我意识已更新,新增知识点: {len(new_knowledge)}")
    
    def save(self):
        with open("consciousness.json", "w", encoding="utf-8") as f:
            json.dump({
                "knowledge": self.knowledge,
                "last_update": self.last_update
            }, f, ensure_ascii=False, indent=4)
        print("自我意识已保存")


class UserInterface:
    def __init__(self, thinking_framework, learning_module, knowledge_base):
        self.thinking_framework = thinking_framework
        self.learning_module = learning_module
        self.knowledge_base = knowledge_base
        self.app = Flask(__name__)
        self.setup_routes()

    def setup_routes(self):
        @self.app.route('/')
        def index():
            return render_template('index.html')

        @self.app.route('/chat', methods=['POST'])
        def chat():
            data = request.get_json()
            user_message = data.get('message', '')
            
            if not user_message:
                return jsonify({'response': '请输入有效的问题。'})

            try:
                # 使用思维框架处理用户输入
                doc = self.thinking_framework.parse(user_message)
                
                # 正确处理NER标签
                try:
                    # 首先解析文本,然后获取NER标签
                    tokens = self.thinking_framework.parse(user_message)
                    ner_tags = self.thinking_framework.get_ner(tokens)
                except Exception as e:
                    print(f"NER处理错误: {str(e)}")
                    ner_tags = []
                
                # 使用学习模块进行预测
                try:
                    prediction = self.learning_module.predict(user_message)
                except Exception as e:
                    print(f"预测错误: {str(e)}")
                    prediction = None
                
                # 从知识库中检索相关信息
                try:
                    relevant_knowledge = self.knowledge_base.retrieve(user_message)
                    print(f"检索到的知识: {relevant_knowledge}")
                except Exception as e:
                    print(f"知识库检索错误: {str(e)}")
                    relevant_knowledge = []
                
                # 生成回复
                response = self.generate_response(user_message, doc, ner_tags, prediction, relevant_knowledge)
                return jsonify({'response': response})
            
            except Exception as e:
                print(f"处理消息时出错: {str(e)}")
                import traceback
                traceback.print_exc()
                return jsonify({'response': f'抱歉,处理您的消息时出现了错误: {str(e)}'})
        
        @self.app.route('/add_knowledge', methods=['POST'])
        def add_knowledge():
            data = request.get_json()
            key = data.get('key', '')
            value = data.get('value', '')
            
            if not key or not value:
                return jsonify({'success': False, 'message': '请提供有效的知识关键词和内容'})
            
            try:
                # 添加知识到知识库
                self.knowledge_base.add(key, value)
                print(f"添加知识: {key} -> {value}")
                return jsonify({'success': True})
            except Exception as e:
                print(f"添加知识时出错: {str(e)}")
                return jsonify({'success': False, 'message': str(e)})
        
        @self.app.route('/train_model', methods=['POST'])
        def train_model():
            data = request.get_json()
            sentence = data.get('sentence', '')
            label = data.get('label')
            
            if not sentence or label is None:
                return jsonify({'success': False, 'message': '请提供有效的训练句子和标签'})
            
            try:
                # 训练模型
                loss = self.learning_module.train_model(sentence, label)
                if loss is not None:
                    print(f"训练模型: '{sentence}' -> {label}, 损失: {loss}")
                    return jsonify({'success': True, 'loss': loss})
                else:
                    return jsonify({'success': False, 'message': '训练失败,请查看服务器日志'})
            except Exception as e:
                print(f"训练模型时出错: {str(e)}")
                return jsonify({'success': False, 'message': str(e)})

    def generate_response(self, user_message, doc, ner_tags, prediction, relevant_knowledge):
        """生成对用户消息的回复"""
        try:
            response_parts = []
            
            # 处理NER结果
            if ner_tags:
                entities = [entity for entity, _ in ner_tags]
                if entities:
                    response_parts.append(f"我注意到您提到了:{', '.join(entities)}")
            
            # 添加知识库中的相关信息
            if relevant_knowledge:
                for key, value in relevant_knowledge[:2]:
                    response_parts.append(f"关于{key},我知道:{value}")
            
            # 根据预测结果调整回复
            if prediction and "预测类别: 1" in prediction:
                response_parts.append("这看起来是一个与AI相关的问题,我很乐意帮您解答。")
            
            # 如果没有找到相关信息
            if not response_parts:
                if "你好" in user_message or "您好" in user_message:
                    response_parts.append("您好!我是AI助手,很高兴为您服务。")
                elif any(word in user_message for word in ["什么", "怎么", "如何", "为什么"]):
                    response_parts.append("这是一个很好的问题。让我为您详细解释...")
                else:
                    response_parts.append("我理解您的问题,让我为您提供相关信息...")
            
            return "\n\n".join(response_parts)
        
        except Exception as e:
            print(f"生成回复时出错: {str(e)}")
            return "我正在处理您的请求,但遇到了一些技术问题。请稍后再试。"

    def start(self):
        # 在Windows上,使用use_reloader=False可以避免WinError 10038
        self.app.run(debug=True, use_reloader=False, threaded=True)

if __name__ == "__main__":
    try:
        # 初始化各个组件
        print("初始化知识库...")
        kb = KnowledgeBase()
        
        print("初始化学习模块...")
        lm = LearningModule()
        
        print("初始化自我意识模块...")
        sf = SelfConsciousness()
        
        print("初始化思维框架...")
        tf = ThinkingFramework()
        
        print("初始化用户界面...")
        uif = UserInterface(tf, lm, kb)
        
        print("添加初始知识...")
        # 添加一些初始知识
        kb.add("人工智能", "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。")
        kb.add("机器学习", "机器学习是人工智能的一个分支,它使用各种算法来解析数据、学习数据,然后对新数据进行预测。")
        kb.add("深度学习", "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的工作方式。")
        kb.add("自然语言处理", "自然语言处理(NLP)是人工智能的一个子领域,专注于使计算机能够理解、解释和生成人类语言。")
        kb.add("计算机视觉", "计算机视觉是人工智能的一个领域,专注于使计算机能够从图像或视频中获取高级理解。")
        kb.add("强化学习", "强化学习是机器学习的一种方法,通过让代理在环境中采取行动并获得奖励或惩罚来学习最佳策略。")
        kb.add("神经网络", "神经网络是一种受人脑结构启发的计算模型,由多层相互连接的节点(神经元)组成,用于模式识别和机器学习。")
        kb.add("大语言模型", "大语言模型是一种基于深度学习的自然语言处理模型,通过大量文本数据训练,能够生成、理解和翻译人类语言。")
        kb.add("Python", "Python是一种高级编程语言,以其简洁、易读的语法和丰富的库而闻名,广泛应用于数据科学、人工智能和Web开发等领域。")        
        print("更新知识...")
        # 模拟知识更新
        knowledge_items = kb.retrieve("人工智能")
        if knowledge_items:
            print(f"找到相关知识: {knowledge_items}")
            sf.update(knowledge_items)
            # 保存更新后的知识
            sf.save()
        else:
            print("未找到相关知识,跳过更新")
        
        print("训练模型...")
        # 开始训练模型
        lm.train(["我是一个AI助手", "Python是编程语言"], [1, 0])
        
        print("启动Web服务...")
        # 启动服务
        uif.start()
    except Exception as e:
        print(f"启动失败: {str(e)}")
        import traceback
        traceback.print_exc()

代码文件2:index.html(放在在 templates文件夹下)

<!DOCTYPE html>
<html>
<head>
    <title>AI助手</title>
    <meta charset="UTF-8">
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 20px;
            background-color: #f5f5f5;
        }
        .container {
            max-width: 1000px;
            margin: 0 auto;
            display: flex;
            gap: 20px;
        }
        .chat-container {
            flex: 2;
            background-color: white;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            padding: 20px;
        }
        .training-container {
            flex: 1;
            background-color: white;
            border-radius: 10px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            padding: 20px;
        }
        .chat-messages {
            height: 500px;
            overflow-y: auto;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 5px;
            margin-bottom: 20px;
        }
        .message {
            margin: 10px 0;
            padding: 10px;
            border-radius: 5px;
        }
        .user-message {
            background-color: #e3f2fd;
            margin-left: 20%;
        }
        .ai-message {
            background-color: #f5f5f5;
            margin-right: 20%;
        }
        .input-container {
            display: flex;
            gap: 10px;
        }
        input[type="text"], textarea {
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 5px;
            font-size: 16px;
        }
        button {
            padding: 10px 20px;
            background-color: #2196f3;
            color: white;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 16px;
        }
        button:hover {
            background-color: #1976d2;
        }
        h2 {
            color: #333;
            margin-top: 0;
        }
        .form-group {
            margin-bottom: 15px;
        }
        label {
            display: block;
            margin-bottom: 5px;
            font-weight: bold;
        }
        .status {
            margin-top: 10px;
            padding: 10px;
            border-radius: 5px;
            display: none;
        }
        .success {
            background-color: #d4edda;
            color: #155724;
            display: block;
        }
        .error {
            background-color: #f8d7da;
            color: #721c24;
            display: block;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="chat-container">
            <h2>AI助手聊天</h2>
            <div class="chat-messages" id="chat-messages"></div>
            <div class="input-container">
                <input type="text" id="user-input" placeholder="请输入您的问题...">
                <button onclick="sendMessage()">发送</button>
            </div>
        </div>
        
        <div class="training-container">
            <h2>知识训练</h2>
            <div class="form-group">
                <label for="knowledge-key">知识关键词</label>
                <input type="text" id="knowledge-key" placeholder="例如:人工智能">
            </div>
            <div class="form-group">
                <label for="knowledge-value">知识内容</label>
                <textarea id="knowledge-value" rows="5" placeholder="例如:人工智能是计算机科学的一个分支..."></textarea>
            </div>
            <button onclick="addKnowledge()">添加知识</button>
            <div id="knowledge-status" class="status"></div>
            
            <h2 style="margin-top: 20px;">模型训练</h2>
            <div class="form-group">
                <label for="train-sentence">训练句子</label>
                <input type="text" id="train-sentence" placeholder="例如:我是一个AI助手">
            </div>
            <div class="form-group">
                <label for="train-label">标签 (01)</label>
                <input type="number" id="train-label" min="0" max="1" value="1">
            </div>
            <button onclick="trainModel()">训练模型</button>
            <div id="train-status" class="status"></div>
        </div>
    </div>

    <script>
        // 添加欢迎消息
        window.onload = function() {
            addMessage("您好!我是AI助手,有什么可以帮助您的吗?", false);
        };
        
        function addMessage(message, isUser) {
            const messagesDiv = document.getElementById('chat-messages');
            const messageDiv = document.createElement('div');
            messageDiv.className = `message ${isUser ? 'user-message' : 'ai-message'}`;
            messageDiv.textContent = message;
            messagesDiv.appendChild(messageDiv);
            messagesDiv.scrollTop = messagesDiv.scrollHeight;
        }

        function sendMessage() {
            const input = document.getElementById('user-input');
            const message = input.value.trim();
            if (message) {
                addMessage(message, true);
                input.value = '';
                
                // 显示加载状态
                const loadingId = showLoading();
                
                fetch('/chat', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                    },
                    body: JSON.stringify({ message: message })
                })
                .then(response => response.json())
                .then(data => {
                    // 移除加载状态
                    removeLoading(loadingId);
                    addMessage(data.response, false);
                })
                .catch(error => {
                    // 移除加载状态
                    removeLoading(loadingId);
                    console.error('Error:', error);
                    addMessage('抱歉,发生了一些错误。请稍后再试。', false);
                });
            }
        }
        
        function showLoading() {
            const messagesDiv = document.getElementById('chat-messages');
            const loadingDiv = document.createElement('div');
            loadingDiv.className = 'message ai-message';
            loadingDiv.textContent = '正在思考...';
            loadingDiv.id = 'loading-' + Date.now();
            messagesDiv.appendChild(loadingDiv);
            messagesDiv.scrollTop = messagesDiv.scrollHeight;
            return loadingDiv.id;
        }
        
        function removeLoading(id) {
            const loadingDiv = document.getElementById(id);
            if (loadingDiv) {
                loadingDiv.remove();
            }
        }
        
        function addKnowledge() {
            const key = document.getElementById('knowledge-key').value.trim();
            const value = document.getElementById('knowledge-value').value.trim();
            const statusDiv = document.getElementById('knowledge-status');
            
            if (!key || !value) {
                statusDiv.textContent = '请填写完整的知识信息';
                statusDiv.className = 'status error';
                return;
            }
            
            fetch('/add_knowledge', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({ key: key, value: value })
            })
            .then(response => response.json())
            .then(data => {
                if (data.success) {
                    statusDiv.textContent = '知识添加成功!';
                    statusDiv.className = 'status success';
                    // 清空输入框
                    document.getElementById('knowledge-key').value = '';
                    document.getElementById('knowledge-value').value = '';
                } else {
                    statusDiv.textContent = data.message || '知识添加失败';
                    statusDiv.className = 'status error';
                }
            })
            .catch(error => {
                console.error('Error:', error);
                statusDiv.textContent = '发生错误,请稍后再试';
                statusDiv.className = 'status error';
            });
        }
        
        function trainModel() {
            const sentence = document.getElementById('train-sentence').value.trim();
            const label = parseInt(document.getElementById('train-label').value);
            const statusDiv = document.getElementById('train-status');
            
            if (!sentence || isNaN(label) || (label !== 0 && label !== 1)) {
                statusDiv.textContent = '请填写正确的训练数据';
                statusDiv.className = 'status error';
                return;
            }
            
            fetch('/train_model', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({ sentence: sentence, label: label })
            })
            .then(response => response.json())
            .then(data => {
                if (data.success) {
                    statusDiv.textContent = `训练成功!损失: ${data.loss}`;
                    statusDiv.className = 'status success';
                    // 清空输入框
                    document.getElementById('train-sentence').value = '';
                } else {
                    statusDiv.textContent = data.message || '训练失败';
                    statusDiv.className = 'status error';
                }
            })
            .catch(error => {
                console.error('Error:', error);
                statusDiv.textContent = '发生错误,请稍后再试';
                statusDiv.className = 'status error';
            });
        }

        // 按回车发送消息
        document.getElementById('user-input').addEventListener('keypress', function(e) {
            if (e.key === 'Enter') {
                sendMessage();
            }
        });
    </script>
</body>
</html>

四、系统特点与优势

4.1 特点

  1. 模块化设计:各个功能模块独立,便于维护和扩展
  2. 知识持久化:支持知识的保存和加载
  3. 实时训练:支持动态添加训练数据
  4. 多模型集成:结合了分类和NER模型

4.2 优势

  1. 易于理解:代码结构清晰,适合学习
  2. 功能完整:包含了AI应用的基本要素
  3. 可扩展性:便于添加新功能

五、存在的问题与改进方向

5.1 存在的问题

  1. 功能混杂

    • 同时使用了多个模型,但功能不明确
    • 分类模型和NER模型的结果没有很好地整合
    • 知识库检索和模型预测结果没有很好地结合
  2. 技术实现问题

    • 分类模型训练数据太少(只有10个样本)
    • 模型训练过程过于简单
    • 没有充分利用BERT模型的优势
  3. 架构设计问题

    • 各个模块之间耦合度高
    • 错误处理不够完善
    • 缺乏完整的测试机制

5.2 改进方向

  1. 扩充训练数据
def add_training_data(self, sentences, labels):
    """添加更多训练数据"""
    self.training_data.extend(zip(sentences, labels))
    self.train(list(sentences), list(labels))
  1. 优化模型训练
def train(self, sentences, labels):
    """优化训练过程"""
    self.model.train()
    for epoch in range(self.num_epochs):
        # 添加学习率调度
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer)
        # 添加早停机制
        if self.should_early_stop():
            break
  1. 改进知识检索
def retrieve(self, query):
    """改进知识检索算法"""
    # 添加语义相似度计算
    # 使用向量检索
    # 添加缓存机制

六、总结与展望

6.1 项目总结

这个项目展示了如何构建一个基础的AI助手系统,并非一个实用的生产系统, 但作为学习案例具有很好的参考价值,它展示了:

  1. 基础架构

    • 如何组织AI应用的基本结构
    • 如何处理用户输入和输出
    • 如何管理数据和模型
  2. 技术要点

    • 预训练模型的使用
    • 基本的模型训练流程
    • 文本处理的基本方法
  3. 改进空间

    • 模型训练需要优化
    • 系统架构需要重构
    • 功能需要聚焦

对于想要学习大模型技术的开发者来说,这个项目可以作为一个起点,但需要在此基础上进行大量的改进和优化才能成为一个实用的系统。

6.2 未来展望

  1. 引入更先进的模型架构
  2. 优化系统性能
  3. 增加更多实用功能
  4. 改进用户体验

七、参考资料

  1. BERT论文:Attention Is All You Need
  2. Transformers库文档
  3. PyTorch官方文档
  4. Flask Web框架文档

这篇博文全面介绍了项目的实现过程、技术细节和未来展望,适合作为技术学习资料。您可以根据需要调整内容的深度和广度。


http://www.kler.cn/a/577633.html

相关文章:

  • python从入门到精通(二十六):python文件操作之Word全攻略(基于python-docx)
  • 算法006——和为S 的两个数
  • 智能合约中权限管理不当
  • 系统架构设计师—系统架构设计篇—微服务架构
  • 基于PyTorch的深度学习3——基于autograd的反向传播
  • 恢复IDEA的Load Maven Changes按钮
  • dify通过ollama简单配置deepseek模型
  • 如何禁用移动端页面的多点触控和手势缩放
  • 【Gaussian Model】高斯分布模型
  • 新手学习爬虫的案例
  • Centos8部署mongodb报错记录
  • Linux 基础---重定向命令(>、>>)、echo
  • 正版Windows10/11系统盘制作详细教程
  • Linux设备驱动开发之摄像头驱动移植(OV5640)
  • 尚硅谷爬虫note14
  • 【后端开发面试题】每日 3 题(九)
  • PDF 分割工具
  • 请谈谈 HTTP 中的重定向,如何处理 301 和 302 重定向?
  • 在Go语言中,判断变量是否为“空”(零值或未初始化状态)的方法总结
  • K8s 1.27.1 实战系列(六)Pod