当前位置: 首页 > article >正文

机器学习-朴素贝叶斯

文章目录

  • 一、朴素贝叶斯简介
    • 1.含义
    • 2.公式
  • 二、代码实现
    • 1.数据加载和预处理
    • 2.切分数据集
    • 3.模型训练
    • 4.性能评估
    • 5.测试集预测
    • 6.详细代码
  • 三、总结

一、朴素贝叶斯简介

1.含义

朴素贝叶斯(Naive Bayes)是一种基于贝叶斯定理与特征条件独立假设的分类方法。它之所以被称为“朴素”,是因为它假设特征之间相互独立,即一个特征的出现与另一个特征无关,这在现实世界中往往不成立,但这一假设使得朴素贝叶斯分类器变得简单且高效。

2.公式

P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) P(A∣B)=\frac{P(B|A)P(A)}{P(B)} P(AB)=P(B)P(BA)P(A)
其中:

P(A∣B) 是后验概率,即在给定数据B下,属于类别A的概率。
P(B∣A) 是似然概率,即在类别A下观测到数据B的概率。
P(A) 是先验概率,即类别A出现的概率。
P(B) 是证据因子,对于所有类别是相同的,因此不影响分类决策。

二、代码实现

下面这段代码的主要目的是使用朴素贝叶斯分类器来对鸢尾花数据集进行分类,代码实现了使用多项式朴素贝叶斯对鸢尾花数据集进行分类的基本流程,包括数据加载、预处理、模型训练、预测和性能评估,我们将其一步步拆分,进行更详细的讲解。

1.数据加载和预处理

import pandas as pd

def cm_plot(y, yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y, yp)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center', verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt

data = pd.read_csv("iris.csv")
data_a = data.drop(columns=data.columns[0])

x = data_a.iloc[:, :-1]
y = data_a.iloc[:, -1]
  • 使用pandas读取iris.csv文件,该文件应包含鸢尾花数据集。
  • 数据集的第一列是不需要的列,因此删除。
  • 数据集被分为特征集x和目标集y,其中x包含除最后一列外的所有列,y包含最后一列。

2.切分数据集

from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, random_state=0)
  • 使用train_test_split函数将数据集切分为训练集和测试集,测试集占总数据的20%,随机种子设为0,以便拥有可重复性。

3.模型训练

from sklearn.naive_bayes import MultinomialNB#导入朴素贝叶斯分类器
classifier = MultinomialNB(alpha=1)
classifier.fit(x_train,y_train)

  • 使用多项式朴素贝叶斯分类器对训练集进行训练。

4.性能评估

#绘制混淆矩阵
from sklearn import metrics
train_pred = classifier.predict(x_train)
cm_plot(y_train,train_pred).show()
print(metrics.classification_report(y_train, train_pred))
score = classifier.score(x_train, y_train)
print(score)
  • 对训练集进行预测,得到预测结果train_pred。
  • 使用cm_plot函数绘制训练集的混淆矩阵,并进行可视化。
  • 打印分类报告,该报告提供了主要分类指标的文本报告,如精确度、召回率、F1分数等。
  • 打印训练集上的准确度分数,来评估训练集上的性能。

5.测试集预测

test_pred = classifier.predict(x_test)
cm_plot(y_test,test_pred).show()
print(metrics.classification_report(y_test, test_pred))
  • 对测试集进行预测,并展示测试集的混淆矩阵,最后将分类报告打印出来。

6.详细代码

import pandas as pd

def cm_plot(y, yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y, yp)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center', verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt

data = pd.read_csv("iris.csv")
data_a = data.drop(columns=data.columns[0])

x = data_a.iloc[:, :-1]
y = data_a.iloc[:, -1]

"""切分数据集"""
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, random_state=0)

from sklearn.naive_bayes import MultinomialNB#导入朴素贝叶斯分类器
classifier = MultinomialNB(alpha=1)
classifier.fit(x_train,y_train)

"""预测训练集"""
#绘制混淆矩阵
from sklearn import metrics
train_pred = classifier.predict(x_train)
cm_plot(y_train,train_pred).show()
print(metrics.classification_report(y_train, train_pred))
score = classifier.score(x_train, y_train)
print(score)
"""测试集预测"""
test_pred = classifier.predict(x_test)
cm_plot(y_test,test_pred).show()
print(metrics.classification_report(y_test, test_pred))

三、总结

朴素贝叶斯分类器因其简单性和高效性,在文本分类、垃圾邮件检测、情感分析等领域有着广泛的应用。但同时也有着自己的优缺点。

  • 优点
    • 简单高效:由于假设特征之间相互独立,大大简化了计算。
    • 处理缺失数据:对缺失数据不敏感,可以通过忽略该特征或使用该特征的先验概率来处理。
    • 易于实现:算法实现相对简单,易于理解和应用。
  • 缺点
    • 特征独立性假设:现实中特征之间往往存在相关性,这一假设限制了朴素贝叶斯的性能。
    • 参数估计问题:如果某个特征在训练数据中未出现,则条件概率为零,这会导致整个后验概率为零,即所谓的“零概率问题”。

http://www.kler.cn/a/282768.html

相关文章:

  • 释放高级功能:Nexusflows Athene-V2-Agent在工具使用和代理用例方面超越 GPT-4o
  • NPOI 实现Excel模板导出
  • C++——视频问题总结
  • 【Linux】Linux 权限的理解
  • debian 系统更新升级
  • 创建型设计模式与面向接口编程
  • LeetCode题练习与总结:添加与搜索单词 - 数据结构设计--211
  • 模糊C-means算法原理及Python实践
  • 电容应用原理
  • 如何构建基于Java SpringBoot和Vue的受灾救援物资管理系统?——四步实现物资高效调配,提升救援响应速度
  • 速盾:企业在使用高防 IP 和 CDN 时如何确保数据的安全性?
  • MYSQL数据库(三)
  • 使用Python从图像中提取文本的OCR库详解
  • 易保全线上赋强公证解决方案,助力业务纠纷高质效化解
  • 【设计模式】单例模式、工厂模式、策略模式、观察者模式、装饰器模式
  • 云存储服务器租用的好处有哪些?
  • HCIP是什么?HCIP认证解析!
  • npm创建项目一直等待
  • 视频压缩怎么操作?三个办法教你无损压缩视频
  • SQL 语句及其分类
  • cesium 实现克里金生成矢量等值面,使用worker浏览器线程
  • 速盾:如何选择适合企业的高防 IP 和 CDN?
  • Nginx负载均衡实现:深入配置与最佳实践
  • 提交保存,要做重复请求拦截,避免出现重复保存的问题
  • 数据结构与算法——深度优先搜索(DFS)和广度优先搜索(BFS)
  • j9、vue、uni-app、小程序的页面传参方式