当前位置: 首页 > article >正文

机器学习10

自定义数据集 使用scikit-learn中svm的包实现svm分类

代码

import numpy as np
import matplotlib.pyplot as plt


class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])


x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]))
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]))
y = np.concatenate((np.ones(class1_points.shape[0]), -np.ones(class2_points.shape[0])))


w1 = 0.1
w2 = 0.1
b = 0
learning_rate = 0.05


l_data = x1_data.size


fig, (ax1, ax2) = plt.subplots(2, 1)

step_list = np.array([])  # 初始化为空数组
loss_values = np.array([])  # 初始化为空数组


num_iterations = 1000
for n in range(1, num_iterations + 1):
    z = w1 * x1_data + w2 * x2_data + b
    yz = y * z
    loss = 1 - yz
    loss[loss < 0] = 0
    hinge_loss = np.mean(loss)
    loss_values = np.append(loss_values, hinge_loss)
    step_list = np.append(step_list, n)


    gradient_w1 = 0
    gradient_w2 = 0
    gradient_b = 0
    for i in range(len(y)):
        if loss[i] > 0:
            gradient_w1 += -y[i] * x1_data[i]
            gradient_w2 += -y[i] * x2_data[i]
            gradient_b += -y[i]

    gradient_w1 /= len(y)
    gradient_w2 /= len(y)
    gradient_b /= len(y)


    w1 -= learning_rate * gradient_w1
    w2 -= learning_rate * gradient_w2
    b -= learning_rate * gradient_b

    # 显示频率设置
    frequence_display = 50
    if n % frequence_display == 0 or n == 1:

        if np.abs(w2) < 1e-5:
            continue
        x1_min, x1_max = 0, 6
        x2_min, x2_max = -(w1 * x1_min + b) / w2, -(w1 * x1_max + b) / w2
        ax1.clear()
        ax1.scatter(x1_data[:len(class1_points)], x2_data[:len(class1_points)], c='red', label='Class 1')
        ax1.scatter(x1_data[len(class1_points):], x2_data[len(class1_points):], c='blue', label='Class 2')
        ax1.plot((x1_min, x1_max), (x2_min, x2_max), 'r-')
        ax1.set_title(f"SVM: w1={round(w1.item(), 3)}, w2={round(w2.item(), 3)}, b={round(b.item(), 3)}")


        ax2.clear()
        ax2.plot(step_list, loss_values, 'g-')
        ax2.set_xlabel("Step")
        ax2.set_ylabel("Loss")
        # 显示图形
        plt.pause(1)
        

plt.show()

效果展示


http://www.kler.cn/a/530956.html

相关文章:

  • AI技术在SEO关键词优化中的应用策略与前景展望
  • docker gitlab arm64 版本安装部署
  • stm32硬件实现与w25qxx通信
  • 论文阅读(十):用可分解图模型模拟连锁不平衡
  • 96,【4】 buuctf web [BJDCTF2020]EzPHP
  • 9 点结构模块(point.rs)
  • guava:基于TypeToken解析泛型类的类型变量(TypeVariable)的具体类型
  • Python处理数据库:MySQL与SQLite详解
  • Python小游戏29乒乓球
  • SQL范式与反范式_优化数据库性能
  • 基于LLM的路由在专家混合应用:一种新颖的交易框架,该框架在夏普比率和总回报方面提升了超过25%
  • S4 HANA明确税金汇差科目(OBYY)
  • Windows 中的 WSL:开启你的 Linux 之旅
  • 掌握API和控制点(从Java到JNI接口)_36 JNI开发与NDK 04
  • 637. 二叉树的层平均值
  • 每日 Java 面试题分享【第 20 天】
  • OpenAI宣布ChatGPT集成到苹果操作系统,将带来哪些新功能?
  • Rust结构体方法语法:让数据拥有行为
  • DeepSeek 集成到个人网站的详细步骤
  • CompletableFuture使用
  • 简易CPU设计入门:指令单元(二)
  • Google C++ Style / 谷歌C++开源风格
  • 租房管理系统助力数字化转型提升租赁服务质量与用户体验
  • csapp笔记3.6节——控制(1)
  • 整形的存储形式和浮点型在计算机中的存储形式
  • 【Redis】安装配置Redis超详细教程 / Linux版