当前位置: 首页 > article >正文

机器学习day8

自定义数据集 ,使用朴素贝叶斯对其进行分类

代码

import numpy as np
import matplotlib.pyplot as plt

class1_points = np.array([[2.1, 2.2], [2.4, 2.5], [2.2, 2.0], [2.0, 2.1], [2.3, 2.3], [2.6, 2.4], [2.5, 2.1]])
class2_points = np.array([[4.0, 3.5], [4.2, 3.9], [4.1, 3.8], [3.7, 3.4], [4.4, 3.6], [4.5, 3.7], [4.3, 3.9]])

X = np.concatenate((class1_points, class2_points), axis=0)
Y = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)
print(Y)

prior_prob = [np.sum(Y == 0) / len(Y), np.sum(Y == 1) / len(Y)]

class_μ = [np.mean(X[Y == 0], axis=0), np.mean(X[Y == 1], axis=0)]
class_cov = [np.cov(X[Y == 0], rowvar=False), np.cov(X[Y == 1], rowvar=False)]

def pdf(x, mean, cov):
    n = len(mean)
    coff = 1 / (2 * np.pi) ** (n / 2) * np.sqrt(np.linalg.det(cov))
    exponent = np.exp(-(1 / 2) * np.dot(np.dot((x - mean).T, np.linalg.inv(cov)), (x - mean)))
    return coff * exponent

xx, yy = np.meshgrid(np.arange(0, 5, 0.05), np.arange(0, 5, 0.05))
grid_points = np.c_[xx.ravel(), yy.ravel()]

grid_label = []
for point in grid_points:
    poster_prob = []
    for i in range(2):
        likelihood = pdf(point, class_μ[i], class_cov[i])
        poster_prob.append(prior_prob[i] * likelihood)
    pre_class = np.argmax(poster_prob)
    grid_label.append(pre_class)

plt.scatter(class1_points[:, 0], class1_points[:, 1], c="blue", label="class 1")
plt.scatter(class2_points[:, 0], class2_points[:, 1], c="red", label="class 2")
plt.legend()

grid_label = np.array(grid_label)
pre_grid_label = grid_label.reshape(xx.shape)
contour = plt.contour(xx, yy, pre_grid_label, level=0.5, color='green')

plt.show()

效果


http://www.kler.cn/a/532572.html

相关文章:

  • 『 C 』 `##` 在 C 语言宏定义中的作用解析
  • C语言第六课:数组与字符串
  • 读书笔记--分布式架构的异步化和缓存技术原理及应用场景
  • 玉米苗和杂草识别分割数据集labelme格式1997张3类别
  • PyTorch数据建模
  • Nginx的配置文件 conf/nginx.conf /etc/nginx/nginx.conf 笔记250203
  • sql中奇数、偶数、正则
  • 【L2JMobius】ZGC requires Windows version 1803 or later
  • 宝塔面板端口转发其它端口至MySQL的3306
  • 关于大模型 AGI 应知应会_生在AI发展的时代
  • 51单片机入门_05_LED闪烁(常用的延时方法:软件延时、定时器延时;while循环;unsigned char 可以表示的数字是0~255)
  • 统计满足条件的4位数(信息学奥赛一本通-1077)
  • 【通俗易懂说模型】线性回归(附深度学习、机器学习发展史)
  • 高斯过程处理大型数据集时返回值为 0 的问题
  • Spring Cloud工程搭建
  • 中继器与集线器
  • 用deepseek制作我的第一个长视频---使用AI解决尝试新领域没有经验拖延的问题!
  • 在Mac mini M4上部署DeepSeek R1本地大模型
  • Linux下使用C++设计线程池
  • 路由表转发表考研知识点
  • 强化学习驱动的自适应模型选择与融合用于监督学习
  • 蓝桥杯刷题 DAY4:小根堆 区间合并+二分
  • 体会SHAP分析局部解释在预测模型实践中的意义
  • 分享|LLM通过D-E-P-S完成长时间与多步骤的任务
  • idea隐藏无关文件
  • 炒股-基本面分析