TensorFlow 与 Matplotlib 核心知识点及实战案例
TensorFlow 与 Matplotlib 核心知识点及实战案例
一、训练监控可视化
1. 损失/准确率曲线
import matplotlib.pyplot as plt
# TensorFlow训练后获取历史数据
history = model.fit(...)
# 创建双纵坐标图表
fig, ax1 = plt.subplots(figsize=(10,5))
color = 'tab:red'
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss', color=color)
ax1.plot(history.history['loss'], color=color, label='Train Loss')
ax1.tick_params(axis='y', labelcolor=color)
ax2 = ax1.twinx() # 共享x轴
color = 'tab:blue'
ax2.set_ylabel('Accuracy', color=color)
ax2.plot(history.history['accuracy'], color=color, label='Accuracy')
ax2.tick_params(axis='y', labelcolor=color)
plt.title('Training Metrics')
fig.tight_layout()
plt.show()
场景:监控模型是否过拟合/欠拟合
解决问题:识别训练过程中的异常波动,确定最佳停止时机
2. 动态更新训练曲线
from IPython.display import clear_output
class LivePlotCallback(tf.keras.callbacks.Callback):
def on_train_begin(self, logs=None):
self.x = []
self.losses = []
self.acc = []
def on_epoch_end(self, epoch, logs=None):
self.x.append(epoch)
self.losses.append(logs['loss'])
self.acc.append(logs['accuracy'])
clear_output(wait=True)
plt.figure(figsize=(8,4))
plt.subplot(1,2,1)
plt.plot(self.x, self.losses, 'r-', label='loss')
plt.legend()
plt.subplot(1,2,2)
plt.plot(self.x, self.acc, 'b-', label='accuracy')
plt.legend()
plt.tight_layout()
plt.show()
# 在model.fit中调用
model.fit(..., callbacks=[LivePlotCallback()])
场景:实时监控训练过程(适用于Jupyter环境)
解决问题:即时发现训练异常,无需等待训练完成
二、数据可视化
1. 多通道图像显示
# 从TF Dataset中取一批数据
for images, labels in train_dataset.take(1):
plt.figure(figsize=(10,10))
for i in range(9):
plt.subplot(3,3,i+1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.title(class_names[labels[i]])
plt.axis("off")
plt.show()
场景:检查数据增强效果、验证标签正确性
解决问题:发现错误标注或异常数据样本
2. 特征分布直方图
# 可视化MNIST数据分布
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.hist(train_images.reshape(-1), bins=50, color='blue')
plt.title('Pixel Value Distribution')
plt.subplot(1,2,2)
plt.hist(train_labels, bins=10, edgecolor='black')
plt.xticks(range(10))
plt.title('Class Distribution')
plt.show()
场景:数据预处理阶段分析
解决问题:检测数据偏差(如类别不平衡、像素值异常)
三、模型分析
1. 特征映射可视化
# 获取CNN中间层输出
layer_outputs = [layer.output for layer in model.layers[:4]]
activation_model = tf.keras.Model(inputs=model.input, outputs=layer_outputs)
activations = activation_model.predict(img_array)
# 可视化卷积层激活
plt.figure(figsize=(12,6))
for i in range(32): # 显示前32个滤波器
plt.subplot(4,8,i+1)
plt.imshow(activations[0,:,:,i], cmap='viridis')
plt.axis('off')
plt.tight_layout()
场景:调试CNN模型特征提取能力
解决问题:验证卷积层是否有效捕捉特征
2. 混淆矩阵
from sklearn.metrics import confusion_matrix
import seaborn as sns
# 生成预测结果
y_pred = model.predict(test_images)
y_pred = np.argmax(y_pred, axis=1)
# 绘制混淆矩阵
cm = confusion_matrix(test_labels, y_pred)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
场景:分类模型性能分析
解决问题:识别模型易混淆的类别组合
四、高级应用
1. 梯度分布可视化
# 在梯度计算后获取梯度值
with tf.GradientTape() as tape:
predictions = model(x_batch)
loss = loss_fn(y_batch, predictions)
grads = tape.gradient(loss, model.trainable_weights)
# 绘制梯度分布
plt.figure(figsize=(10,6))
for i, grad in enumerate(grads):
if 'dense' in model.weights[i].name: # 只显示全连接层
plt.hist(grad.numpy().flatten(), bins=50, alpha=0.5,
label=model.weights[i].name)
plt.legend()
plt.title('Gradient Distribution')
plt.xlabel('Gradient Value')
plt.ylabel('Frequency')
场景:调试梯度消失/爆炸问题
解决问题:诊断网络层的学习效率问题
2. 决策边界可视化
# 生成网格数据
xx, yy = np.meshgrid(np.linspace(-3,3,100), np.linspace(-3,3,100))
grid = np.c_[xx.ravel(), yy.ravel()]
probs = model.predict(grid).reshape(xx.shape)
# 绘制决策边界
plt.figure(figsize=(8,6))
plt.contourf(xx, yy, probs>0.5, alpha=0.3)
plt.scatter(X[:,0], X[:,1], c=y, edgecolors='k')
plt.title('Decision Boundary')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.colorbar()
场景:二分类问题模型分析
解决问题:直观展示模型的分类能力边界
使用建议
- 数据探索阶段:优先使用直方图和散点图分析数据分布
- 模型训练时:实时监控损失曲线,配合TensorBoard使用更佳
- 模型调试期:结合激活可视化和梯度分析诊断网络问题
- 结果汇报时:使用混淆矩阵和决策边界图增强说服力
这些可视化技巧将贯穿机器学习项目的全生命周期,帮助您更高效地完成以下关键任务:验证数据质量、监控训练过程、分析模型行为、解释预测结果。