当前位置: 首页 > article >正文

机器学习 AdaBoost 算法

AdaBoost 提升学习算法是通过训练多个弱分类算法实现一个强分类算法,做法非常朴素,在训练过程中,提供分类错误的数据权重,降低分类正确的权重,提高分类效果好的弱分类器权重,降低分类效果差的若分类器权重。

  • AdaBoost 公式
    sign 代表符号,如果 > 0 返回 1,如果 < 0 返回 -1。
    在这里插入图片描述

  • 损失函数

在这里插入图片描述

  • 权重计算
    ε 是错误值,ε < 0.5 比较好,= 0.5 和随机猜一样,> 0.5 还不如随机了。
    在这里插入图片描述

SKLearn 实现 AdaBoost

  • 生成测试数据
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
# 导入sklearn模拟二分类数据生成模块
from sklearn.datasets import make_blobs

# 生成模拟二分类数据集
X, y =  make_blobs(n_samples=150, n_features=2, centers=2,
  cluster_std=1.2, random_state=40)
# 将标签转换为1/-1
y_ = y.copy()
y_[y_==0] = -1
y_ = y_.astype(float)
# 训练/测试数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y_,
 test_size=0.3, random_state=43)
# 设置颜色参数
colors = {0:'r', 1:'g'}
# 绘制二分类数据集的散点图
plt.scatter(X[:,0], X[:,1], marker='o', c=pd.Series(y).map(colors))
plt.show();

在这里插入图片描述

  • 训练、测试
# 导入sklearn adaboost分类器
from sklearn.ensemble import AdaBoostClassifier
# 创建Adaboost模型实例
clf_ = AdaBoostClassifier(n_estimators=5, random_state=0)
# 模型拟合
clf_.fit(X_train, y_train)
# 模型预测
y_pred_ = clf_.predict(X_test)
# 计算模型预测准确率
accuracy = accuracy_score(y_test, y_pred_)
print("Accuracy of AdaBoost by sklearn:", accuracy)

在这里插入图片描述

总结

AdaBoost 分类器是一种多个弱分类器的组合,AdaBoost、SVM、逻辑回归各自适应不同的场景,下表列出了各个模型不同的特性,可以根据自己的业务场景进行选择。

标准AdaBoost逻辑回归支持向量机 (SVM)
噪声敏感性高(容易对噪声过拟合)中等(软间隔有所帮助)
非线性能力使用决策树效果好差(线性)非常好(使用非线性核)
计算成本中等到高使用非线性核时高
可解释性中等(线性 SVM)
离群值敏感性中等低(使用软间隔)

http://www.kler.cn/a/403045.html

相关文章:

  • java学习-集合
  • 常用Adb 命令
  • javascrip页面交互
  • 设计模式之 状态模式
  • 机器学习day6-线性代数2-梯度下降
  • 23种设计模式速记法
  • 使用Python推送FLV流
  • 《Vue零基础教程》(1)Vue简介
  • C# AutoMapper 10个常用方法总结
  • Spring Boot 项目 myblog 整理
  • 智能购物时代:AI在电商平台的革命性应用
  • 针对AI增强图像大规模鲁棒性测试的数据集
  • 15分钟学 Go 实战项目六 :统计分析工具项目(30000字完整例子)
  • ssl证书,以 Nginx 为例
  • 如何构建高效的接口自动化测试框架?
  • Halcon 分割之区域生长法
  • 拓展Git相关知识(⭐版控工具⭐)
  • 量化交易系统开发-实时行情自动化交易-3.4.3.3.期货市场深度数据
  • Golang语言整合jwt+gin框架实现token
  • 学习threejs,对模型多个动画切换展示
  • Matlab多输入单输出之倾斜手写数字识别
  • os库的常见使用
  • 星融元与焱融科技AI分布式存储软硬件完成兼容性互认证
  • 13.C++内存管理2(C++ new和delete的使用和原理详解,内存泄漏问题)
  • 数据结构(双向链表——c语言实现)
  • Restful API 规范详解