当前位置: 首页 > article >正文

t-SNE进行分类可视化

0、引入

我们在论文中通常可以看到下图这样的可视化效果,这就是使用t-SNE降维方法进行的可视化,当然除了t-SNE还有其他的比如PCA等降维等方法,关于这些算法的原理有很多文章可以借阅,这里不展开阐释,重点讲讲如何进行可视化。
在这里插入图片描述

1、基本原理

上面的图中一个点就是一个样本,我们需要明白的是一个样本用两个数值表示(x和y坐标),意味着原来高维的样本被降维到低维(2维)的空间中了。

比如在将一个样本图片输入到VGG网络中,在倒数第二了全连接层有4096个神经元,也就是该样本使用了4096维的向量表示。我们获取到这个向量表示后通过t-SNE进行降维,得到2维的向量表示,我们就可以在平面图中画出该点的位置。

我们清楚同一类的样本,它们的4096维向量是有相似性的,并且降维到2维后也是具有相似性的,所以在2维平面上面它们会倾向聚拢在一起。

可视化的过程中,大概步骤就是:

  • 对于500张图片(样本)进行模型推理,获取倒数第二层的4096特征向量,同时获取到标签,至此我们有两个值data_embed=[500,4096],label=[500,]
  • 使用t-SNE降维,将data_embed降维后,每个样本用2维表示,即data_embed=[500,2]
  • 使用plt将data_embed共500个样本绘制,并且同一类的颜色一致。

2、收集模型中的高维向量表示和标签

这一步主要是收集每个样本的高维特征及标签,高维特征是全连接层的输出,所以需要通过推理模型获取到(需要修改模型使同时输出全连接层输出)。

data_embed_collect=[]
label_collect=[]

for ......
	# inputs.shape=[BS,C,H,W]
	# embed_4096.shape=[BS,4096]
	# output.shape=[BS,1000]
	output,embed_4096=model(inputs)
	
	
	data_embed_collect.append(embed_4096)
	label_collect.append(label)
	......


# data_embed_collect.shape=[iters,BS,4096]
# label_collect.shape=[iters,BS,]

# 在这里,所有样本的4096特征都收集了,并且每个样本的标签也收集了
# data_embed_npy.shape=[samples,4096]
# label_npu.shape=[samples,]
data_embed_npy=torch.cat(data_embed_collect,axis=0).cpu().numpy()
label_npu=torch.cat(label_collect,axis=0).cpu().numpy()

np.save("data_embed_npy.npy",data_embed_npy)
np.save("label_npu.npy",label_npu).

3、进行t-SNE降维并可视化

代码也简单,首先调用t-SNE进行降维降到2维,然后使用plt将2维定位坐标进行绘制,代码如下:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def get_fer_data(data_path="vis_fer_data.npy",
                 label_path="vis_fer_label.npy"):
    """
	该函数读取上一步保存的两个npy文件,返回data和label数据
    Args:
        data_path:
        label_path:

    Returns:
        data: 样本特征数据,shape=(BS,embed)
        label: 样本标签数据,shape=(BS,)
        n_samples :样本个数
        n_features:样本的特征维度

    """
    data = np.load(data_path)
    label = np.load(label_path)
    n_samples, n_features = data.shape

    return data, label, n_samples, n_features

color_map = ['r','y','k','g','b','m','c'] # 7个类,准备7种颜色
def plot_embedding_2D(data, label, title):
	"""
	
	"""
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    data = (data - x_min) / (x_max - x_min)
    fig = plt.figure()
    for i in range(data.shape[0]):
        plt.plot(data[i, 0], data[i, 1],marker='o',markersize=1,color=color_map[label[i]])
    plt.xticks([])
    plt.yticks([])
    plt.title(title)
    return fig


def main():
    data, label, n_samples, n_features = get_fer_data()  # 根据自己的路径合理更改

    print('Begining......') 	

	# 调用t-SNE对高维的data进行降维,得到的2维的result_2D,shape=(samples,2)
    tsne_2D = TSNE(n_components=2, init='pca', random_state=0) 
    result_2D = tsne_2D.fit_transform(data)
    
    print('Finished......')
    fig1 = plot_embedding_2D(result_2D, label, 't-SNE')	# 将二维数据用plt绘制出来
    fig1.show()
    plt.pause(50)
    
if __name__ == '__main__':
    main()

7个类,共计3589个样本降维到2维后的效果:
在这里插入图片描述


http://www.kler.cn/news/10253.html

相关文章:

  • 【SpringMVC】7—文件上传
  • 详细讲讲Java线程的状态
  • 林长制信息系统主要建设思路
  • Java实现图片缩放裁剪,图片像素比例变更,批量转换图片像素比
  • 遗传算法优化深度信念网络DBN的分类预测,GA-DBN分类预测
  • C++ 的fcntl函数
  • ChatGPT搭建语音智能助手
  • 工作中英语学习的几个阶段
  • Three.js教程:第一个3D场景
  • 【MyBatis Plus】003 -- 配置(基本、进阶、DB策略) 条件构造器
  • Linux下使用ClamAV病毒查杀
  • Lottie加载的一些坑
  • 【OpenCV-Python】cvui 之 trackbar
  • 因果推断14--DRNet论文和代码学习
  • 如果让你做技术负责人,你会怎么设计后端架构?
  • 查看 Elasticsearch 分析器
  • selenium库有哪些功能呢?都是如何实现的呢?
  • ( “树” 之 DFS) 543. 二叉树的直径 ——【Leetcode每日一题】
  • Git的安装与基本使用
  • 2021蓝桥杯真题大写 C语言/C++
  • 计算机网络笔记(横向)
  • 代码随想录算法训练营第三十四天-贪心算法3| 1005.K次取反后最大化的数组和 134. 加油站 135. 分发糖果
  • 微服务+springcloud+springcloud alibaba学习笔记【Eureka服务注册中心】(3/9)
  • C++标准库--IO库(Primer C++ 第五版 · 阅读笔记)
  • 离散数学_第二章:基本结构:集合、函数、序列、求和和矩阵(1)
  • 探索树形数据结构,通识树、森林与二叉树的基础知识(专有名词),进一步利用顺序表和链表表示、遍历和线索树形结构
  • 梯度的看法
  • MyBatis配置文件 —— 相关标签详解
  • 干翻Hadoop系列之:Hadoop前瞻之分布式知识
  • Leetcode.1992 找到所有的农场组