当前位置: 首页 > article >正文

TensorFlow 与 Matplotlib 核心知识点及实战案例

TensorFlow 与 Matplotlib 核心知识点及实战案例

一、训练监控可视化

1. 损失/准确率曲线

import matplotlib.pyplot as plt

# TensorFlow训练后获取历史数据
history = model.fit(...)

# 创建双纵坐标图表
fig, ax1 = plt.subplots(figsize=(10,5))

color = 'tab:red'
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss', color=color)
ax1.plot(history.history['loss'], color=color, label='Train Loss')
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  # 共享x轴
color = 'tab:blue'
ax2.set_ylabel('Accuracy', color=color)
ax2.plot(history.history['accuracy'], color=color, label='Accuracy')
ax2.tick_params(axis='y', labelcolor=color)

plt.title('Training Metrics')
fig.tight_layout()
plt.show()

场景:监控模型是否过拟合/欠拟合
解决问题:识别训练过程中的异常波动,确定最佳停止时机

2. 动态更新训练曲线

from IPython.display import clear_output

class LivePlotCallback(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.x = []
        self.losses = []
        self.acc = []
        
    def on_epoch_end(self, epoch, logs=None):
        self.x.append(epoch)
        self.losses.append(logs['loss'])
        self.acc.append(logs['accuracy'])
        
        clear_output(wait=True)
        plt.figure(figsize=(8,4))
        plt.subplot(1,2,1)
        plt.plot(self.x, self.losses, 'r-', label='loss')
        plt.legend()
        
        plt.subplot(1,2,2)
        plt.plot(self.x, self.acc, 'b-', label='accuracy')
        plt.legend()
        
        plt.tight_layout()
        plt.show()

# 在model.fit中调用
model.fit(..., callbacks=[LivePlotCallback()])

场景:实时监控训练过程(适用于Jupyter环境)
解决问题:即时发现训练异常,无需等待训练完成

二、数据可视化

1. 多通道图像显示

# 从TF Dataset中取一批数据
for images, labels in train_dataset.take(1):
    plt.figure(figsize=(10,10))
    for i in range(9):
        plt.subplot(3,3,i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
    plt.show()

场景:检查数据增强效果、验证标签正确性
解决问题:发现错误标注或异常数据样本

2. 特征分布直方图

# 可视化MNIST数据分布
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.hist(train_images.reshape(-1), bins=50, color='blue')
plt.title('Pixel Value Distribution')

plt.subplot(1,2,2)
plt.hist(train_labels, bins=10, edgecolor='black')
plt.xticks(range(10))
plt.title('Class Distribution')
plt.show()

场景:数据预处理阶段分析
解决问题:检测数据偏差(如类别不平衡、像素值异常)

三、模型分析

1. 特征映射可视化

# 获取CNN中间层输出
layer_outputs = [layer.output for layer in model.layers[:4]]
activation_model = tf.keras.Model(inputs=model.input, outputs=layer_outputs)

activations = activation_model.predict(img_array)

# 可视化卷积层激活
plt.figure(figsize=(12,6))
for i in range(32):  # 显示前32个滤波器
    plt.subplot(4,8,i+1)
    plt.imshow(activations[0,:,:,i], cmap='viridis')
    plt.axis('off')
plt.tight_layout()

场景:调试CNN模型特征提取能力
解决问题:验证卷积层是否有效捕捉特征

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns

# 生成预测结果
y_pred = model.predict(test_images)
y_pred = np.argmax(y_pred, axis=1)

# 绘制混淆矩阵
cm = confusion_matrix(test_labels, y_pred)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

场景:分类模型性能分析
解决问题:识别模型易混淆的类别组合

四、高级应用

1. 梯度分布可视化

# 在梯度计算后获取梯度值
with tf.GradientTape() as tape:
    predictions = model(x_batch)
    loss = loss_fn(y_batch, predictions)
grads = tape.gradient(loss, model.trainable_weights)

# 绘制梯度分布
plt.figure(figsize=(10,6))
for i, grad in enumerate(grads):
    if 'dense' in model.weights[i].name:  # 只显示全连接层
        plt.hist(grad.numpy().flatten(), bins=50, alpha=0.5, 
                label=model.weights[i].name)
plt.legend()
plt.title('Gradient Distribution')
plt.xlabel('Gradient Value')
plt.ylabel('Frequency')

场景:调试梯度消失/爆炸问题
解决问题:诊断网络层的学习效率问题

2. 决策边界可视化

# 生成网格数据
xx, yy = np.meshgrid(np.linspace(-3,3,100), np.linspace(-3,3,100))
grid = np.c_[xx.ravel(), yy.ravel()]
probs = model.predict(grid).reshape(xx.shape)

# 绘制决策边界
plt.figure(figsize=(8,6))
plt.contourf(xx, yy, probs>0.5, alpha=0.3)
plt.scatter(X[:,0], X[:,1], c=y, edgecolors='k')
plt.title('Decision Boundary')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.colorbar()

场景:二分类问题模型分析
解决问题:直观展示模型的分类能力边界

使用建议

  • 数据探索阶段:优先使用直方图和散点图分析数据分布
  • 模型训练时:实时监控损失曲线,配合TensorBoard使用更佳
  • 模型调试期:结合激活可视化和梯度分析诊断网络问题
  • 结果汇报时:使用混淆矩阵和决策边界图增强说服力

这些可视化技巧将贯穿机器学习项目的全生命周期,帮助您更高效地完成以下关键任务:验证数据质量、监控训练过程、分析模型行为、解释预测结果。


http://www.kler.cn/a/596198.html

相关文章:

  • 2953. 统计完全子字符串(将题目中给的信息进行分组循环)
  • JavaIO流的使用和修饰器模式(直击心灵版)
  • 解释什么是受控组件和非受控组件
  • git 查看某个函数的所有提交日志
  • Web爬虫利器FireCrawl:全方位助力AI训练与高效数据抓取。本地部署方式
  • 【入门初级篇】报表基础操作与功能介绍
  • 大数据处理最容易的开源平台
  • 基于Python编程语言实现“机器学习”,用于车牌识别项目
  • Android Audio基础(52)—— ASoC的PCM逻辑设备
  • AGI成立的条件
  • jieba中文分词模块,详细使用教程
  • 基于 PyTorch 的 MNIST 手写数字分类模型
  • 学习笔记:黑马程序员JavaWeb开发教程(2025.3.21)
  • 卷积神经网络 - 汇聚层
  • 使用Three.js渲染器创建炫酷3D场景
  • m4i.22xx-x8系列-PCIe总线直流耦合5G采集卡
  • 基于Django的动物图像识别分析系统
  • 阿里云平台Vue项目打包发布
  • EtherCAT 八口交换机方案测试介绍,FCE1100助力工业交换机国产芯快速移植。
  • 《Python实战进阶》No26: CI/CD 流水线:GitHub Actions 与 Jenkins 集成