当前位置: 首页 > article >正文

Python代码解析:问题分类器实现

Python代码解析:问题分类器实现

    • 引言
    • 代码结构概览
    • 代码详解
      • 1. 导入必要的库
      • 2. 定义`QuestionClassify`类
      • 3. 训练模型
      • 4. 预测分类
      • 5. 加载训练数据
      • 6. 主程序
    • 总结
    • 参考资料

引言

在自然语言处理(NLP)中,问题分类是一个重要的任务。通过将问题分类到预定义的类别中,可以帮助我们更好地理解和处理用户的问题。本文将通过解析一段Python代码,详细介绍如何实现一个基于朴素贝叶斯分类器的问题分类器。

代码结构概览

首先,我们来看一下代码的整体结构:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time     :2022/8/30 9:56
# @File     :question_classify.py
# @Description:问题分类
import os
import re

import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

from common import file_util, constant, nlp_util


class QuestionClassify:
    """
        问题分类
    """

    def __init__(self):
        self.train_x, self.train_y = load_train_data()
        # 文本向量化
        self.tfidf_vec = TfidfVectorizer()
        self.train_vec = self.tfidf_vec.fit_transform(self.train_x).toarray()
        self.model = self.train_model_nb()

    # 训练模型
    def train_model_nb(self):
        """
            利用朴素贝叶斯分类器训练模型
            :return:
        """
        nb = MultinomialNB(alpha=0.01)
        nb.fit(self.train_vec, self.train_y)
        return nb

    def predict(self, question):
        """
        预测分类
        :param question:
        :return:
        """
        # 词性标注
        text_cut_gen = nlp_util.posseg(question)

        # 获取模板
        # 替换nr(人名)works(作品)ng(名词词素)
        # 原始问题
        text_src_list = []
        # 一般化的问题,把人名替换为nr,依此类推
        text_normal_list = []
        for item in text_cut_gen:
            text_src_list.append(item.word)
            if item.flag in ['nr', 'works', 'ng']:
                text_normal_list.append(item.flag)
            else:
                text_normal_list.append(item.word)

        # 拼成一句话
        question_normal = [" ".join(text_normal_list)]
        print(question_normal)
        question_vector = self.tfidf_vec.transform(question_normal).toarray()
        predict = self.model.predict(question_vector)[0]
        return predict


def load_train_data():
    train_x = []
    train_y = []
    file_path_list = file_util.get_file_list(os.path.join(constant.DATA_DIR, "question"))
    for file_item in file_path_list:
        # 获取文件名中的label
        label = re.sub(r'\D', "", file_item)
        if label.isnumeric():
            label_num = int(label)
            # 读取文件内容
            with (open(file_item, "r", encoding="utf-8")) as file:
                lines = file.readlines()
                for line in lines:
                    # 分词
                    word_list = list(jieba.cut(str(line).strip()))
                    train_x.append(" ".join(word_list))
                    train_y.append(label_num)
    return train_x, train_y


if __name__ == '__main__':
    classify = QuestionClassify()
    category = classify.predict("钱学森的作品有哪些")
    print(category)

代码详解

1. 导入必要的库

import os
import re

import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

from common import file_util, constant, nlp_util
  • os:用于处理文件路径。
  • re:用于正则表达式操作,提取文件名中的标签。
  • jieba:用于中文分词。
  • TfidfVectorizer:用于文本向量化,计算TF-IDF值。
  • MultinomialNB:朴素贝叶斯分类器。
  • file_utilconstantnlp_util:自定义模块,分别用于文件操作、常量定义和NLP工具。

2. 定义QuestionClassify

class QuestionClassify:
    """
        问题分类
    """

    def __init__(self):
        self.train_x, self.train_y = load_train_data()
        # 文本向量化
        self.tfidf_vec = TfidfVectorizer()
        self.train_vec = self.tfidf_vec.fit_transform(self.train_x).toarray()
        self.model = self.train_model_nb()
  • __init__方法:初始化训练数据、TF-IDF向量化器和朴素贝叶斯模型。

3. 训练模型

def train_model_nb(self):
    """
        利用朴素贝叶斯分类器训练模型
        :return:
    """
    nb = MultinomialNB(alpha=0.01)
    nb.fit(self.train_vec, self.train_y)
    return nb
  • train_model_nb方法:使用朴素贝叶斯分类器训练模型,并返回训练好的模型。

4. 预测分类

def predict(self, question):
    """
    预测分类
    :param question:
    :return:
    """
    # 词性标注
    text_cut_gen = nlp_util.posseg(question)

    # 获取模板
    # 替换nr(人名)works(作品)ng(名词词素)
    # 原始问题
    text_src_list = []
    # 一般化的问题,把人名替换为nr,依此类推
    text_normal_list = []
    for item in text_cut_gen:
        text_src_list.append(item.word)
        if item.flag in ['nr', 'works', 'ng']:
            text_normal_list.append(item.flag)
        else:
            text_normal_list.append(item.word)

    # 拼成一句话
    question_normal = [" ".join(text_normal_list)]
    print(question_normal)
    question_vector = self.tfidf_vec.transform(question_normal).toarray()
    predict = self.model.predict(question_vector)[0]
    return predict
  • predict方法:对输入的问题进行词性标注,并将其转换为一般化的形式。然后使用TF-IDF向量化器将问题向量化,并使用训练好的模型进行预测。

5. 加载训练数据

def load_train_data():
    train_x = []
    train_y = []
    file_path_list = file_util.get_file_list(os.path.join(constant.DATA_DIR, "question"))
    for file_item in file_path_list:
        # 获取文件名中的label
        label = re.sub(r'\D', "", file_item)
        if label.isnumeric():
            label_num = int(label)
            # 读取文件内容
            with (open(file_item, "r", encoding="utf-8")) as file:
                lines = file.readlines()
                for line in lines:
                    # 分词
                    word_list = list(jieba.cut(str(line).strip()))
                    train_x.append(" ".join(word_list))
                    train_y.append(label_num)
    return train_x, train_y
  • load_train_data函数:从指定目录中读取训练数据文件,提取文件名中的标签,并对文件内容进行分词,生成训练数据集。

6. 主程序

if __name__ == '__main__':
    classify = QuestionClassify()
    category = classify.predict("钱学森的作品有哪些")
    print(category)
  • 创建QuestionClassify对象,并使用predict方法对输入的问题进行分类,打印分类结果。

总结

通过这段代码,我们学会了如何实现一个基于朴素贝叶斯分类器的问题分类器。这个过程包括导入必要的库、定义问题分类类、训练模型、预测分类、加载训练数据和运行主程序。希望这篇文章对你理解如何实现问题分类器有所帮助。

参考资料

  • Jieba中文分词
  • scikit-learn官方文档
  • Python re 模块文档
  • Python os 模块文档

希望这篇文章对你有所帮助!如果你有任何问题或建议,欢迎在评论区留言。


http://www.kler.cn/a/379604.html

相关文章:

  • PostgreSQL (八) 创建分区
  • 基于SpringBoot+Gpt个人健康管家管理系统【提供源码+答辩PPT+参考文档+项目部署】
  • 十四届蓝桥杯STEMA考试Python真题试卷第二套第五题
  • 如何对LabVIEW软件进行性能评估?
  • 【华为HCIP实战课程31(完整版)】中间到中间系统协议IS-IS路由汇总详解,网络工程师
  • 【每日一题】LeetCode - 三数之和
  • el-table type=“selection“换页多选数据丢失的解决办法
  • dify实战案例分享-基于多模态模型的发票识别
  • git submodule
  • 【AIGC】深入探索『后退一步』提示技巧:激发ChatGPT的智慧潜力
  • 【jvm】对象分配过程
  • PostgreSQL JOIN 操作深入解析
  • 《星光予你》系列网剧正式开机! “黑莲花”陷入时间循环攻略疯批霸总
  • 报错 sys_platform == “win32“ (from mmcv) (from versions: none)
  • excel表格文字识别-ocr表格文字提取api接口集成-python
  • 双向链表专题
  • word选择题转excel(一键转写,无格式要求)
  • 发货到印尼的海运报价
  • C++学习笔记----9、发现继承的技巧(七)---- 转换(1)
  • 蓝桥杯py组入门(bfs广搜)
  • git入门教程4:git工作流程
  • 【ARM Linux 系统稳定性分析入门及渐进 1.2 -- Crash 工具依赖内容】
  • 软考:通信系统架构设计
  • 【django】Django REST Framework 序列化与反序列化详解
  • 07.适配器模式设计思想
  • 论文学习——A Prompt Pattern Catalog to Enhance Prompt Engineering with ChatGPT