当前位置: 首页 > article >正文

机器学习西瓜书——线性判别分析LDA

datawhale task04

线性判别分析LDA

线性判别分析是一种由监督的分类器,它的简单原理是将高维数据降维映射到一个向量上然后根据样本的欧式距离定进行分析分类。LDA的关键就是找到这个向量w。构造损失函数让同类样本的投影点尽可能近,即让同种样本的方差尽量小;让异类样本投影点尽可能远,即让类样本中心尽可能远,最优化损失函数求出w。在预测时制定一个阈值theta(通常为0),让新样本投影到w上,让投影值与theta做比较从而预测出样本的类别。下面以二分类为例子进行介绍

在这里插入图片描述

预测公式

y = w T x y=w^Tx y=wTx

其中w是方向向量,x是样本向量,得到的y是预测值,让y与theta做比较得出预测的类别

优化公式

在这里插入图片描述

定义类内散度矩阵:
S w = ∑ 0 + ∑ 1 = ∑ x ∈ X 0 ( x − μ 0 ) ( x − μ ) T + ∑ x ∈ X 1 ( x − μ 1 ) ( x − μ ) T \begin{align} S_w&=\sum{}_0+\sum{}_1\\ &=\sum_{x\in{X_0}}(x-\mu_0)(x-\mu)^T+\sum_{x\in{X_1}}(x-\mu_1)(x-\mu)^T \end{align} Sw=0+1=xX0(xμ0)(xμ)T+xX1(xμ1)(xμ)T
衡量的是所有类别各自类的样本点的离散度的和。

定义类间散度矩阵
S b = ( μ 0 − μ 1 ) ( μ 0 − μ 1 ) T S_b=(\mu_0-\mu_1)(\mu_0-\mu_1)^T Sb=(μ0μ1)(μ0μ1)T
重写3.32为
J = w T S b w w T S w w S w − 1 ( μ 0 − μ 1 ) J=\frac{w^TSbw}{w^TS_ww}\\ S_w^{-1}(\mu_0-\mu_1) J=wTSwwwTSbwSw1(μ0μ1)
J就是最优化的目标,由于w在这里只关注方向而不关注长度,换句话说长度在这里不影响J最终计算结果,所以使用拉格朗日乘子法,求w^Ts_ww=1的条件下-J的最小值。

最终化简得到w的解析解:
w = S w − 1 ( μ 0 − μ 1 ) w=S_w^{-1}(\mu_0-\mu_1) w=Sw1(μ0μ1)
mu0-mu1很好求得,难的是Sw的逆,为了数值解的稳定性,通常使用SVD法求解Sw的逆。
S w = U σ V T S w − 1 = V σ − 1 U T \begin{align} S_w&=U\sigma V^T\\ S_w^{-1}&=V\sigma^{-1}U^T \end{align} SwSw1=UσVT=Vσ1UT
最终求得w

代码实现

import torch
import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt
# 数据1、边界不清晰
# 设置随机种子,确保每次运行结果一致
np.random.seed(42)
# 生成正例数据
n_samples = 50
mean_positive = [0, 3]  # 正例的中心点
cov = [[1, 0], [0, 1]]  # 协方差矩阵
positive_samples = np.random.multivariate_normal(mean_positive, cov, n_samples)
# 生成负例数据
mean_negative = [0, -3]  # 负例的中心点,将中心点距离改为 -2 
negative_samples = np.random.multivariate_normal(mean_negative, cov, n_samples) 
# 合并数据
features = np.concatenate((positive_samples, negative_samples), axis=0).T
labels = np.concatenate((np.ones(n_samples), np.zeros(n_samples)))
# 打乱数据
indices = np.random.permutation(features.shape[1])
features = features[:, indices]
labels = labels[indices]
# 将数据转换为 PyTorch 张量
features = torch.tensor(features, dtype=float)
labels = torch.tensor(labels, dtype=float).reshape((-1, 1))
# 第一次画图,使用设置的坐标范围
x_min=-5
x_max=5
y_min=-5
y_max=5
plt.figure(figsize=(8, 6))
plt.scatter(features[0, :], features[1, :], c=labels.numpy().flatten())
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Logistic Regression Data Visualization')
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.show()

def classify(features,labels):
    positive=features[:,torch.where(labels==1)[0]]
    negative=features[:,torch.where(labels==0)[0]]
    return positive,negative

def get_Sw(positive,negative,positiveAvg,negativeAvg):
    return (positive-positiveAvg)@(positive-positiveAvg).T+(negative-negativeAvg)@(negative-negativeAvg).T

def F1_score(y_hat,y):
    '''输入预测值y_hat和实际值y,计算F1_score评估模型性能'''
    TP=torch.sum(torch.where((y==1)&(y_hat==1),torch.tensor(1),torch.tensor(0)))
    # TP=torch.sum(y==1&y_hat==1)
    FP=torch.sum(torch.where((y==0)&(y_hat==1),torch.tensor(1),torch.tensor(0)))
    # FP=torch.sum(y==0&y_hat==1)
    FN=torch.sum(torch.where((y==1)&(y_hat==0),torch.tensor(1),torch.tensor(0)))
    # FN=torch.sum(y==1&y_hat==0)
    P=TP/(TP+FP)
    R=TP/(TP+FN)
    return 2*P*R/(P+R)

positive,negative=classify(features,labels)
# 正例和负例的均值
positiveAvg=torch.mean(positive,dim=1).reshape(2,1)
negativeAvg=torch.mean(negative,dim=1).reshape(2,1)
Sw=get_Sw(positive,negative,positiveAvg,negativeAvg)
u,s,v=la.svd(Sw)
u=torch.tensor(u)
sigma=torch.tensor(np.diag(s))
v=torch.tensor(v)
Sw_inv=v@la.inv(sigma)@u.T
w=Sw_inv@(negativeAvg-positiveAvg)

#绘图
fig,ax=plt.subplots()
ax.scatter(positive[0,:],positive[1,:],c='b')
ax.scatter(negative[0,:],negative[1,:],c='g')
ax.set_xlim(-6,6)
ax.set_ylim(-6,6)
ax.set_title('Original')
x=np.linspace(-2,2,10)
theta=w[1]/w[0]
y=theta*x
ax.plot(x,y,c='r',linewidth=2.0)

predict=torch.where(w.T@features>0,torch.tensor(0),torch.tensor(1)).T
pre_positive,pre_negative=classify(features,predict)
plot,ax=plt.subplots()
ax.set_xlim(-6,6)
ax.set_ylim(-6,6)
ax.scatter(pre_positive[0,:],pre_positive[1,:],c='b')
ax.scatter(pre_negative[0,:],pre_negative[1,:],c='g')
ax.set_title('Predit')
x=np.linspace(-2,2,10)
theta=w[1]/w[0]
y=theta*x
ax.plot(x,y,c='r',linewidth=2.0)
#计算F1得分
F1=F1_score(predict,labels)
print(F1)

原始数据分类

预测出的数据分类

使用西瓜书3slpha数据

import numpy as np
from numpy import linalg as la
import torch
# 使用西瓜数据
features = torch.tensor([
    [0.697, 0.46],
    [0.774, 0.376],
    [0.634, 0.264],
    [0.608, 0.318],
    [0.556, 0.215],
    [0.403, 0.237],
    [0.481, 0.149],
    [0.437, 0.211],
    [0.666, 0.091],
    [0.243, 0.267],
    [0.245, 0.057],
    [0.343, 0.099],
    [0.639, 0.161],
    [0.657, 0.198],
    [0.36, 0.37],
    [0.593, 0.042],
    [0.719, 0.103]
])

labels = torch.tensor([
    1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0
]).reshape((-1, 1))
def classify(features,labels):
    positive=features[torch.where(labels==1)[0],:]
    negative=features[torch.where(labels==0)[0],:]
    return positive,negative
positive,negative=classify(features,labels)

def get_Sw(positive,negative,positiveAvg,negativeAvg):
    return (positive-positiveAvg).T@(positive-positiveAvg)+(negative-negativeAvg).T@(negative-negativeAvg)
Sw=get_Sw(positive,negative,positiveAvg,negativeAvg)
# 正例和负例的均值
positiveAvg=torch.mean(positive,dim=0).reshape(1,2)
negativeAvg=torch.mean(negative,dim=0).reshape(1,2)

u,s,v=la.svd(Sw)
u=torch.tensor(u)
sigma=torch.tensor(np.diag(s))
v=torch.tensor(v)
Sw_inv=v@la.inv(sigma)@u.T
w=Sw_inv@(negativeAvg-positiveAvg).T
#绘图
from matplotlib import pyplot as plt
fig,ax=plt.subplots()
ax.scatter(positive[:,0],positive[:,1],c='b')
ax.scatter(negative[:,0],negative[:,1],c='g')
x=np.linspace(0,1,50)
theta=w[1]/w[0]
y=theta*x
ax.plot(x,y,c='r',linewidth=2.0)

在这里插入图片描述

还可以使用sklearn中封装好的LDA分类器,计算出的w的方向与我们手搓的是一样的

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

lda = LDA()
lda.fit(features, labels.reshape(-1))
w = lda.coe

在这里插入图片描述


http://www.kler.cn/news/328314.html

相关文章:

  • 使用PyTorch实现自然语言处理:从基础到实践
  • Go基础学习07-map注意事项;多协程对map的资源竞争;sync.Mutex避免竟态条件
  • QEMU使用Qemu-Guest-Agent传输文件、执行指令等
  • http增删改查四种请求方式操纵数据库
  • k8s 1.28.2 集群部署 ingress 1.11.1 包含 admission-webhook
  • Qt 中的 QListWidget、QTreeWidget 和 QTableWidget:简化的数据展示控件
  • 开发微信小程序 基础03
  • scala 2.12 error: value foreach is not a member of Object
  • 低代码用户中心:构建高效便捷的用户管理平台
  • VUE 开发——AJAX学习(二)
  • 51单片机学习第五课---B站UP主江协科技
  • 【网络安全】消息鉴别
  • 五.海量数据实时分析-FlinkCDC+DorisConnector实现数据的全量增量同步
  • Require:基于雪花算法完成一个局部随机,全局离散没有热点切唯一的数值Id生成器。
  • FileLink跨网文件交换:高效、安全、灵活的企业文件传输新方案
  • 力扣10.1
  • 5QI(5G QoS Identifier)
  • 《Linux从小白到高手》理论篇(二):Linux的目录结构和磁盘管理
  • 基于贝叶斯优化CNN-GRU网络的数据分类识别算法matlab仿真
  • python画图|自制渐变柱状图
  • 鸿蒙开发(NEXT/API 12)【穿戴设备信息查询】手机侧应用开发
  • 影院管理新篇章:小徐的Spring Boot应用
  • 低代码时代的企业信息化:规范与标准化的重要性
  • Redis: Sentinel哨兵监控架构及环境搭建
  • 通信工程学习:什么是LAN局域网、MAN城域网、WAN广域网
  • HarmonyOS Next应用开发——@build构建函数的使用
  • 每天一个数据分析题(四百九十一)- 主成分分析与因子分析
  • linux下recoketmq安装教程
  • JVM有哪些参数以及如何使用
  • 基于Java+SQL Server2008开发的(CS界面)个人财物管理系统