
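[Deep Learning] Key Techniques: Model Training

Model Training

Model training is the process, in machine learning and deep learning, of adjusting a model's parameters to optimize its performance. It usually consists of the following steps:

  1. Data preparation: load the data, preprocess it, and split it into training and test sets.
  2. Model definition: choose a suitable algorithm or network architecture.
  3. Loss function and optimizer: define the training objective (the loss function) and the optimization algorithm.
  4. Training: iteratively update the model parameters so that the model performs better on the training set.
  5. Validation and testing: evaluate the model on a validation or test set to detect overfitting.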

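Key Steps in Model Training

1. Data Preparation
  • Data loading: load the raw data (e.g., CSV files or image files).
  • Data preprocessing
    • Feature normalization or standardization (compute the statistics on the training set only).
    • Data augmentation (e.g., flipping or cropping images).
  • Dataset splitting: divide the data into training, validation, and test sets.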
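Image augmentation is normally done with framework layers or transforms. The plain-Python sketch below only illustrates what a horizontal flip and a random crop do, assuming an image is represented as a nested list of pixel rows; `horizontal_flip` and `random_crop` are hypothetical helper names, not a library API.

```python
import random

def horizontal_flip(image):
    """Mirror an image (a list of pixel rows) from left to right."""
    return [list(reversed(row)) for row in image]

def random_crop(image, crop_h, crop_w, rng=random):
    """Cut a crop_h x crop_w window at a random position inside the image."""
    h, w = len(image), len(image[0])
    if crop_h > h or crop_w > w:
        raise ValueError("crop size must not exceed the image size")
    top = rng.randint(0, h - crop_h)   # randint includes both end points
    left = rng.randint(0, w - crop_w)
    return [row[left:left + crop_w] for row in image[top:top + crop_h]]

# A tiny 3x4 "image" of pixel intensities (made-up values)
image = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12]]

print(horizontal_flip(image))  # [[4, 3, 2, 1], [8, 7, 6, 5], [12, 11, 10, 9]]
crop = random_crop(image, 2, 2, rng=random.Random(0))
print(crop)                    # some 2x2 window of the image
```

Each call produces a slightly different training example from the same image, which is why augmentation also appears later as a remedy for overfitting. In Keras, preprocessing layers such as `RandomFlip` and `RandomCrop` apply the same ideas to tensors.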

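Example code

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

# Load the data
data = load_iris()
X, y = data.data, data.target

# Split first, so the test set cannot influence preprocessing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize features: fit on the training set only, then reuse those statistics for the test set
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)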
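`StandardScaler` rescales each feature to z = (x - μ) / σ, where μ and σ are that feature's mean and population standard deviation on the data it was fitted on. The sketch below reproduces that rule in plain Python to make the "fit on train, transform test" pattern explicit; `fit_standardizer` and `transform` are illustrative names, and the tiny dataset is made up.

```python
import math

def fit_standardizer(rows):
    """Compute per-feature mean and population std (ddof=0) from training rows."""
    n = len(rows)
    n_features = len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(n_features)]
    stds = []
    for j in range(n_features):
        var = sum((r[j] - means[j]) ** 2 for r in rows) / n
        stds.append(math.sqrt(var) or 1.0)  # constant feature: divide by 1 instead of 0
    return means, stds

def transform(rows, means, stds):
    """Apply z = (x - mean) / std using statistics learned from the training set."""
    return [[(x - m) / s for x, m, s in zip(r, means, stds)] for r in rows]

train = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
test = [[4.0, 40.0]]

means, stds = fit_standardizer(train)
print(means)                          # [2.0, 20.0]
print(transform(test, means, stds))   # test rows are scaled with the training statistics
```

The test row ends up outside the range seen in training, which is expected: the point is that its scaling uses only information available at training time.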


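2. Define the Model
  • Choose a model architecture that fits the task.
  • In deep learning, this means defining the layer structure of the neural network.

Example code (TensorFlow)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Build a simple fully connected neural network
model = Sequential([
    Dense(64, activation='relu', input_shape=(4,)),  # 4 input features
    Dense(32, activation='relu'),
    Dense(3, activation='softmax')  # one output probability per class (3 classes)
])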
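A Dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases. The small helper below (a hypothetical name, not part of Keras) applies that formula to the 4 → 64 → 32 → 3 network above; the total should agree with the parameter count `model.summary()` reports for this architecture.

```python
def dense_param_count(layer_sizes):
    """Count weights + biases of a stack of fully connected layers.

    layer_sizes lists the input dimension followed by each layer's units,
    e.g. [4, 64, 32, 3] for the model above.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total

print(dense_param_count([4, 64, 32, 3]))  # 2499 = 320 + 2080 + 99
```

Even this small network has far more parameters (2499) than the 120 training rows of Iris, which is one reason to watch the validation metrics during training.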


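3. Define the Loss Function and Optimizer
  • Loss function: defines the objective the model is optimized for.
    • Regression: MSE, MAE.
    • Classification: cross-entropy, KL divergence.
  • Optimizer: choose a suitable optimization algorithm (e.g., SGD, Adam).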
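To make the loss functions listed above concrete, here are plain-Python versions of MSE, MAE, sparse categorical cross-entropy and KL divergence. This is an illustrative sketch: frameworks compute the same quantities on batched tensors with additional numerical safeguards. "Sparse" means the label is a class index (e.g. 2) rather than a one-hot vector, which is why the Iris model can be trained on integer targets.

```python
import math

def mse(y_true, y_pred):
    """Mean squared error (regression)."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def mae(y_true, y_pred):
    """Mean absolute error (regression)."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def sparse_categorical_crossentropy(labels, probs, eps=1e-12):
    """Average of -log(probability assigned to the true class index)."""
    return -sum(math.log(max(p[y], eps)) for y, p in zip(labels, probs)) / len(labels)

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

print(mse([1.0, 2.0], [1.5, 2.5]))   # 0.25
print(mae([1.0, 2.0], [1.5, 2.0]))   # 0.25

# Two samples, three classes; the true classes are 0 and 2
probs = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
print(sparse_categorical_crossentropy([0, 2], probs))  # about 0.2899
```

With a one-hot target, cross-entropy and KL divergence coincide (the entropy of a one-hot distribution is zero), so `kl_divergence([1, 0, 0], probs[0])` equals `-log(0.7)`.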

Example code

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # multi-class classification with integer labels
              metrics=['accuracy'])  # metric reported during training and evaluation
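`optimizer='adam'` selects Adam, which keeps running averages of the gradient (m) and of the squared gradient (v) and uses them to scale every step. The sketch below contrasts one plain SGD step with Adam on a single scalar parameter, minimizing a hypothetical toy objective f(w) = (w - 3)² rather than a real model loss. It uses the common defaults β1 = 0.9 and β2 = 0.999, with ε = 1e-7 as in Keras; the learning rate 0.1 is an arbitrary choice for the toy problem.

```python
import math

def grad(w):
    """Gradient of the toy objective f(w) = (w - 3) ** 2."""
    return 2 * (w - 3)

def sgd_step(w, lr=0.1):
    """Plain gradient descent: move against the gradient."""
    return w - lr * grad(w)

def adam(w, steps, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-7):
    """Run Adam on the toy objective and return the final parameter value."""
    m, v = 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(w)
        m = beta1 * m + (1 - beta1) * g        # first moment: running mean of gradients
        v = beta2 * v + (1 - beta2) * g * g    # second moment: running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)           # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w

print(sgd_step(0.0))         # about 0.6: the step size is lr * |gradient|
print(adam(0.0, steps=1))    # about 0.1: the first Adam step has size about lr
print(adam(0.0, steps=300))  # close to the minimum at w = 3
```

The contrast is the design point: SGD's step scales with the gradient's magnitude, while Adam normalizes by the gradient's recent size, so its steps stay near the learning rate regardless of how the loss is scaled.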


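4. Train the Model
  • Mini-batch training: update the parameters on small batches of data, which is computationally efficient.
  • Epochs: training is iterative; one epoch is one full pass over the training data.
  • Validation: at the end of each epoch, evaluate the model on the validation set.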
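The loop below is a minimal sketch of mini-batch training on a hypothetical one-feature linear model y ≈ w·x + b (not the Keras internals): each epoch shuffles the data, walks through it in batches, and performs one parameter update per batch. The number of updates per epoch is ceil(n_train / batch_size). For the Iris split used here (150 samples, 20% held out for testing, then `validation_split=0.2`), 96 samples remain for training, and ceil(96 / 16) = 6 is consistent with the 6/6 progress bars in the log that follows.

```python
import math
import random

def steps_per_epoch(n_samples, batch_size):
    """Mini-batches needed to see every sample once (the last batch may be smaller)."""
    return math.ceil(n_samples / batch_size)

def minibatch_sgd(xs, ys, epochs=300, batch_size=4, lr=0.05, seed=0):
    """Fit y ≈ w * x + b with mini-batch gradient descent on mean squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    order = list(range(len(xs)))
    for _ in range(epochs):                  # one epoch = one full pass over the data
        rng.shuffle(order)                   # new sample order every epoch
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            errors = [w * xs[i] + b - ys[i] for i in batch]
            # Gradients of the batch MSE with respect to w and b
            grad_w = sum(2 * e * xs[i] for e, i in zip(errors, batch)) / len(batch)
            grad_b = sum(2 * e for e in errors) / len(batch)
            w -= lr * grad_w                 # one update per mini-batch
            b -= lr * grad_b
    return w, b

# Noise-free toy data from y = 2x + 1 (made up for illustration)
xs = [i / 10 for i in range(20)]
ys = [2 * x + 1 for x in xs]

print(steps_per_epoch(96, 16))   # 6
w, b = minibatch_sgd(xs, ys)
print(round(w, 3), round(b, 3))  # close to 2.0 and 1.0
```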

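Example code

history = model.fit(X_train, y_train,
                    validation_split=0.2,  # hold out the last 20% of the training data for validation
                    epochs=50,  # train for 50 epochs
                    batch_size=16,  # 16 samples per batch
                    verbose=1)  # print the training log

Output

Epoch 1/50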
6/6 [==============================] - 1s 37ms/step - loss: 1.2995 - accuracy: 0.3229 - val_loss: 1.1235 - val_accuracy: 0.1667
Epoch 2/50
6/6 [==============================] - 0s 6ms/step - loss: 1.1551 - accuracy: 0.3542 - val_loss: 1.0245 - val_accuracy: 0.5833
Epoch 3/50
6/6 [==============================] - 0s 6ms/step - loss: 1.0217 - accuracy: 0.5104 - val_loss: 0.9356 - val_accuracy: 0.7917
Epoch 4/50
6/6 [==============================] - 0s 6ms/step - loss: 0.9087 - accuracy: 0.7500 - val_loss: 0.8568 - val_accuracy: 0.8750
Epoch 5/50
6/6 [==============================] - 0s 6ms/step - loss: 0.8113 - accuracy: 0.8021 - val_loss: 0.7840 - val_accuracy: 0.8750
Epoch 6/50
6/6 [==============================] - 0s 5ms/step - loss: 0.7256 - accuracy: 0.8125 - val_loss: 0.7158 - val_accuracy: 0.8750
Epoch 7/50
6/6 [==============================] - 0s 6ms/step - loss: 0.6508 - accuracy: 0.8125 - val_loss: 0.6570 - val_accuracy: 0.8750
Epoch 8/50
6/6 [==============================] - 0s 6ms/step - loss: 0.5906 - accuracy: 0.8229 - val_loss: 0.6048 - val_accuracy: 0.8750
Epoch 9/50
6/6 [==============================] - 0s 6ms/step - loss: 0.5433 - accuracy: 0.8229 - val_loss: 0.5616 - val_accuracy: 0.8750
Epoch 10/50
6/6 [==============================] - 0s 7ms/step - loss: 0.5063 - accuracy: 0.8229 - val_loss: 0.5259 - val_accuracy: 0.8750
Epoch 11/50
6/6 [==============================] - 0s 6ms/step - loss: 0.4748 - accuracy: 0.8229 - val_loss: 0.4969 - val_accuracy: 0.8750
Epoch 12/50
6/6 [==============================] - 0s 6ms/step - loss: 0.4490 - accuracy: 0.8229 - val_loss: 0.4719 - val_accuracy: 0.8750
Epoch 13/50
6/6 [==============================] - 0s 6ms/step - loss: 0.4273 - accuracy: 0.8229 - val_loss: 0.4504 - val_accuracy: 0.8750
Epoch 14/50
6/6 [==============================] - 0s 6ms/step - loss: 0.4076 - accuracy: 0.8229 - val_loss: 0.4334 - val_accuracy: 0.8750
Epoch 15/50
6/6 [==============================] - 0s 6ms/step - loss: 0.3913 - accuracy: 0.8229 - val_loss: 0.4184 - val_accuracy: 0.8750
Epoch 16/50
6/6 [==============================] - 0s 7ms/step - loss: 0.3756 - accuracy: 0.8229 - val_loss: 0.4042 - val_accuracy: 0.8750
Epoch 17/50
6/6 [==============================] - 0s 7ms/step - loss: 0.3615 - accuracy: 0.8333 - val_loss: 0.3941 - val_accuracy: 0.8750
Epoch 18/50
6/6 [==============================] - 0s 7ms/step - loss: 0.3480 - accuracy: 0.8333 - val_loss: 0.3801 - val_accuracy: 0.9167
Epoch 19/50
6/6 [==============================] - 0s 6ms/step - loss: 0.3355 - accuracy: 0.8333 - val_loss: 0.3654 - val_accuracy: 0.9167
Epoch 20/50
6/6 [==============================] - 0s 6ms/step - loss: 0.3234 - accuracy: 0.8333 - val_loss: 0.3600 - val_accuracy: 0.9167
Epoch 21/50
6/6 [==============================] - 0s 6ms/step - loss: 0.3139 - accuracy: 0.8750 - val_loss: 0.3546 - val_accuracy: 0.9167
Epoch 22/50
6/6 [==============================] - 0s 6ms/step - loss: 0.3020 - accuracy: 0.8854 - val_loss: 0.3429 - val_accuracy: 0.9167
Epoch 23/50
6/6 [==============================] - 0s 7ms/step - loss: 0.2913 - accuracy: 0.8854 - val_loss: 0.3291 - val_accuracy: 0.9167
Epoch 24/50
6/6 [==============================] - 0s 14ms/step - loss: 0.2832 - accuracy: 0.8750 - val_loss: 0.3242 - val_accuracy: 0.9167
Epoch 25/50
6/6 [==============================] - 0s 6ms/step - loss: 0.2721 - accuracy: 0.8958 - val_loss: 0.3144 - val_accuracy: 0.9167
Epoch 26/50
6/6 [==============================] - 0s 6ms/step - loss: 0.2634 - accuracy: 0.9062 - val_loss: 0.3064 - val_accuracy: 0.9167
Epoch 27/50
6/6 [==============================] - 0s 6ms/step - loss: 0.2542 - accuracy: 0.9271 - val_loss: 0.2990 - val_accuracy: 0.9167
Epoch 28/50
6/6 [==============================] - 0s 6ms/step - loss: 0.2459 - accuracy: 0.9271 - val_loss: 0.2888 - val_accuracy: 0.9167
Epoch 29/50
6/6 [==============================] - 0s 7ms/step - loss: 0.2369 - accuracy: 0.9271 - val_loss: 0.2868 - val_accuracy: 0.9167
Epoch 30/50
6/6 [==============================] - 0s 7ms/step - loss: 0.2275 - accuracy: 0.9375 - val_loss: 0.2825 - val_accuracy: 0.9583
Epoch 31/50
6/6 [==============================] - 0s 7ms/step - loss: 0.2188 - accuracy: 0.9479 - val_loss: 0.2746 - val_accuracy: 0.9583
Epoch 32/50
6/6 [==============================] - 0s 6ms/step - loss: 0.2105 - accuracy: 0.9479 - val_loss: 0.2625 - val_accuracy: 0.9583
Epoch 33/50
6/6 [==============================] - 0s 5ms/step - loss: 0.2016 - accuracy: 0.9479 - val_loss: 0.2563 - val_accuracy: 0.9583
Epoch 34/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1936 - accuracy: 0.9479 - val_loss: 0.2478 - val_accuracy: 0.9583
Epoch 35/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1853 - accuracy: 0.9479 - val_loss: 0.2445 - val_accuracy: 0.9583
Epoch 36/50
6/6 [==============================] - 0s 8ms/step - loss: 0.1774 - accuracy: 0.9479 - val_loss: 0.2341 - val_accuracy: 0.9583
Epoch 37/50
6/6 [==============================] - 0s 7ms/step - loss: 0.1686 - accuracy: 0.9479 - val_loss: 0.2267 - val_accuracy: 0.9583
Epoch 38/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1607 - accuracy: 0.9479 - val_loss: 0.2247 - val_accuracy: 0.9583
Epoch 39/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1540 - accuracy: 0.9479 - val_loss: 0.2189 - val_accuracy: 0.9583
Epoch 40/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1470 - accuracy: 0.9479 - val_loss: 0.2126 - val_accuracy: 0.9583
Epoch 41/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1406 - accuracy: 0.9375 - val_loss: 0.2035 - val_accuracy: 0.9583
Epoch 42/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1348 - accuracy: 0.9479 - val_loss: 0.1886 - val_accuracy: 0.9583
Epoch 43/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1303 - accuracy: 0.9583 - val_loss: 0.1836 - val_accuracy: 0.9583
Epoch 44/50
6/6 [==============================] - 0s 7ms/step - loss: 0.1271 - accuracy: 0.9583 - val_loss: 0.1846 - val_accuracy: 0.9583
Epoch 45/50
6/6 [==============================] - 0s 7ms/step - loss: 0.1210 - accuracy: 0.9583 - val_loss: 0.1815 - val_accuracy: 0.9583
Epoch 46/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1169 - accuracy: 0.9583 - val_loss: 0.1786 - val_accuracy: 0.9583
Epoch 47/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1134 - accuracy: 0.9688 - val_loss: 0.1766 - val_accuracy: 0.9583
Epoch 48/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1105 - accuracy: 0.9583 - val_loss: 0.1703 - val_accuracy: 0.9583
Epoch 49/50
6/6 [==============================] - 0s 6ms/step - loss: 0.1074 - accuracy: 0.9688 - val_loss: 0.1729 - val_accuracy: 0.9583
Epoch 50/50
6/6 [==============================] - 0s 7ms/step - loss: 0.1051 - accuracy: 0.9583 - val_loss: 0.1777 - val_accuracy: 0.9583
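
5. Validation and Testing
  • Validation: monitor the model's performance on the validation set during training.
  • Testing: after training, evaluate the final performance on an independent test set.

Example code

# Evaluate on the test set
test_loss, test_accuracy = model.evaluate(X_test, y_test)
print(f"Test Accuracy: {test_accuracy}")

Output

1/1 [==============================] - 0s 24ms/step - loss: 0.0885 - accuracy: 1.0000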
Test Accuracy: 1.0
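Monitoring the validation set is what enables early stopping: training stops once `val_loss` has not improved for a given number of epochs (the patience), and the best epoch is kept. Keras provides this as the `tf.keras.callbacks.EarlyStopping` callback (with `monitor`, `patience` and `restore_best_weights` arguments, passed via `callbacks=[...]` to `fit`). The plain-Python sketch below only reproduces the core bookkeeping; its exact stopping rule is an assumption, not a copy of the Keras implementation. The input list is the `val_loss` column from the complete example's training log later in this article.

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return (best_epoch, stop_epoch), both 1-based.

    Training stops after `patience` consecutive epochs without an
    improvement larger than `min_delta` over the best val_loss so far.
    """
    best = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0   # new best: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch              # patience exhausted: stop here
    return best_epoch, len(val_losses)                # never triggered: all epochs ran

# val_loss per epoch, copied from the complete example's training log
val_losses = [0.7246, 0.5841, 0.4945, 0.4262, 0.3757, 0.3413, 0.3223, 0.3086,
              0.3015, 0.2896, 0.2847, 0.2862, 0.2773, 0.2808, 0.2711, 0.2724,
              0.2703, 0.2695, 0.2718, 0.2752, 0.2690, 0.2752, 0.2709, 0.2765,
              0.2728, 0.2765, 0.2784, 0.2847, 0.2766, 0.2854]

print(early_stopping_epoch(val_losses, patience=5))  # (21, 26)
```

With patience 5, this rule would have stopped that run at epoch 26 and kept the weights from epoch 21, where the validation loss was lowest.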


Visualizing the Training Process

Plotting how the loss and accuracy change across epochs helps you analyze the model's behavior, for example to spot overfitting when the validation curve starts to diverge from the training curve.

Example code

import matplotlib.pyplot as plt

# Plot training and validation loss
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Plot training and validation accuracy
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
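`history.history` is a plain dict that maps metric names (`'loss'`, `'val_loss'`, `'accuracy'`, `'val_accuracy'`) to per-epoch lists, so the curves can also be checked numerically. The hypothetical helper below reports the epoch with the lowest validation loss and the final gap between validation and training loss; a gap that keeps widening is a typical sign of overfitting. The example numbers are made up.

```python
def summarize_history(history_dict):
    """Report the best validation epoch and the final train/validation loss gap."""
    val_loss = history_dict["val_loss"]
    train_loss = history_dict["loss"]
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1  # 1-based
    return {
        "best_epoch": best_epoch,
        "best_val_loss": val_loss[best_epoch - 1],
        "final_gap": val_loss[-1] - train_loss[-1],
    }

# Made-up numbers with the same structure as history.history
example = {"loss": [0.9, 0.5, 0.3, 0.2], "val_loss": [0.8, 0.5, 0.45, 0.5]}
print(summarize_history(example))
```

With a real run, call it as `summarize_history(history.history)`.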


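Common Problems in Model Training

  1. Overfitting: the model performs very well on the training set but noticeably worse on the validation or test set.

    • Solutions
      • Data augmentation.
      • Regularization (L1, L2).
      • Add Dropout layers.
      • Use a smaller model.
  2. Underfitting: the model performs poorly even on the training set.

    • Solutions
      • Increase model complexity.
      • Tune the learning rate.
      • Train for more epochs or reduce regularization.
  3. Vanishing or exploding gradients

    • Solutions
      • Use suitable activation functions (e.g., ReLU).
      • Batch Normalization.
      • Gradient Clipping.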
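Three of the remedies listed above reduce to simple arithmetic. The plain-Python sketch below shows inverted dropout (surviving activations are scaled by 1 / (1 - rate) during training, as the Keras `Dropout` layer documents), an L2 penalty added to the loss, and clipping gradients by their global L2 norm. In Keras the corresponding building blocks are the `Dropout` layer, `kernel_regularizer=tf.keras.regularizers.l2(...)`, and the optimizer's `clipnorm` (and, in recent versions, `global_clipnorm`) argument; the functions here are illustrative stand-ins, not those implementations.

```python
import math
import random

def dropout(activations, rate, rng, training=True):
    """Inverted dropout: zero each unit with probability `rate`, rescale the survivors."""
    if not training or rate == 0.0:
        return list(activations)             # no-op at inference time
    keep = 1.0 - rate
    return [a / keep if rng.random() < keep else 0.0 for a in activations]

def l2_penalty(weights, lam):
    """L2 regularization term lam * sum(w^2), added to the data loss."""
    return lam * sum(w * w for w in weights)

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together so their combined L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]

rng = random.Random(0)
print(dropout([1.0, 2.0, 3.0, 4.0], rate=0.5, rng=rng))  # each entry is 0.0 or doubled
print(l2_penalty([0.5, -1.0], lam=0.01))                  # 0.01 * (0.25 + 1.0) = 0.0125
print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))      # norm 5 rescaled to 1: about [0.6, 0.8]
```

Scaling by 1 / (1 - rate) keeps the expected activation unchanged, so nothing needs to be rescaled at inference time; clipping by the global norm preserves the direction of the gradient and only shrinks its length.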

Complete Example Code

The following is a complete model training example for a binary classification task:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Generate a synthetic binary classification dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)

# Split the dataset before preprocessing to avoid test-set leakage
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize: fit on the training set, reuse the same statistics for the test set
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Define the model
model = Sequential([
    Dense(64, activation='relu', input_shape=(20,)),
    Dense(32, activation='relu'),
    Dense(2, activation='softmax')  # output layer: one unit per class (2 classes)
])

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.001),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
history = model.fit(X_train, y_train,
                    validation_split=0.2,
                    epochs=30,
                    batch_size=32,
                    verbose=1)

# Evaluate on the test set
test_loss, test_accuracy = model.evaluate(X_test, y_test)
print(f"Test Accuracy: {test_accuracy}")

# Visualize the training process
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Output

Epoch 1/30
20/20 [==============================] - 1s 11ms/step - loss: 0.9144 - accuracy: 0.5141 - val_loss: 0.7246 - val_accuracy: 0.6313
Epoch 2/30
20/20 [==============================] - 0s 2ms/step - loss: 0.6462 - accuracy: 0.6531 - val_loss: 0.5841 - val_accuracy: 0.7563
Epoch 3/30
20/20 [==============================] - 0s 2ms/step - loss: 0.5367 - accuracy: 0.7703 - val_loss: 0.4945 - val_accuracy: 0.8188
Epoch 4/30
20/20 [==============================] - 0s 2ms/step - loss: 0.4608 - accuracy: 0.8266 - val_loss: 0.4262 - val_accuracy: 0.8500
Epoch 5/30
20/20 [==============================] - 0s 2ms/step - loss: 0.4074 - accuracy: 0.8531 - val_loss: 0.3757 - val_accuracy: 0.8625
Epoch 6/30
20/20 [==============================] - 0s 3ms/step - loss: 0.3724 - accuracy: 0.8609 - val_loss: 0.3413 - val_accuracy: 0.8438
Epoch 7/30
20/20 [==============================] - 0s 3ms/step - loss: 0.3463 - accuracy: 0.8734 - val_loss: 0.3223 - val_accuracy: 0.8562
Epoch 8/30
20/20 [==============================] - 0s 2ms/step - loss: 0.3289 - accuracy: 0.8844 - val_loss: 0.3086 - val_accuracy: 0.8875
Epoch 9/30
20/20 [==============================] - 0s 2ms/step - loss: 0.3155 - accuracy: 0.8875 - val_loss: 0.3015 - val_accuracy: 0.8875
Epoch 10/30
20/20 [==============================] - 0s 2ms/step - loss: 0.3036 - accuracy: 0.8922 - val_loss: 0.2896 - val_accuracy: 0.9000
Epoch 11/30
20/20 [==============================] - 0s 3ms/step - loss: 0.2926 - accuracy: 0.8938 - val_loss: 0.2847 - val_accuracy: 0.8875
Epoch 12/30
20/20 [==============================] - 0s 3ms/step - loss: 0.2847 - accuracy: 0.8969 - val_loss: 0.2862 - val_accuracy: 0.8875
Epoch 13/30
20/20 [==============================] - 0s 3ms/step - loss: 0.2761 - accuracy: 0.8984 - val_loss: 0.2773 - val_accuracy: 0.8938
Epoch 14/30
20/20 [==============================] - 0s 2ms/step - loss: 0.2671 - accuracy: 0.9000 - val_loss: 0.2808 - val_accuracy: 0.8875
Epoch 15/30
20/20 [==============================] - 0s 2ms/step - loss: 0.2593 - accuracy: 0.9062 - val_loss: 0.2711 - val_accuracy: 0.8813
Epoch 16/30
20/20 [==============================] - 0s 3ms/step - loss: 0.2518 - accuracy: 0.9047 - val_loss: 0.2724 - val_accuracy: 0.8813
Epoch 17/30
20/20 [==============================] - 0s 2ms/step - loss: 0.2454 - accuracy: 0.9094 - val_loss: 0.2703 - val_accuracy: 0.8875
Epoch 18/30
20/20 [==============================] - 0s 2ms/step - loss: 0.2355 - accuracy: 0.9234 - val_loss: 0.2695 - val_accuracy: 0.8875
Epoch 19/30
20/20 [==============================] - 0s 2ms/step - loss: 0.2286 - accuracy: 0.9234 - val_loss: 0.2718 - val_accuracy: 0.8750
Epoch 20/30
20/20 [==============================] - 0s 3ms/step - loss: 0.2210 - accuracy: 0.9187 - val_loss: 0.2752 - val_accuracy: 0.8875
Epoch 21/30
20/20 [==============================] - 0s 3ms/step - loss: 0.2143 - accuracy: 0.9203 - val_loss: 0.2690 - val_accuracy: 0.8813
Epoch 22/30
20/20 [==============================] - 0s 2ms/step - loss: 0.2074 - accuracy: 0.9266 - val_loss: 0.2752 - val_accuracy: 0.8875
Epoch 23/30
20/20 [==============================] - 0s 2ms/step - loss: 0.2008 - accuracy: 0.9281 - val_loss: 0.2709 - val_accuracy: 0.8875
Epoch 24/30
20/20 [==============================] - 0s 2ms/step - loss: 0.1933 - accuracy: 0.9312 - val_loss: 0.2765 - val_accuracy: 0.8875
Epoch 25/30
20/20 [==============================] - 0s 5ms/step - loss: 0.1863 - accuracy: 0.9359 - val_loss: 0.2728 - val_accuracy: 0.8938
Epoch 26/30
20/20 [==============================] - 0s 2ms/step - loss: 0.1803 - accuracy: 0.9359 - val_loss: 0.2765 - val_accuracy: 0.8938
Epoch 27/30
20/20 [==============================] - 0s 2ms/step - loss: 0.1734 - accuracy: 0.9375 - val_loss: 0.2784 - val_accuracy: 0.8938
Epoch 28/30
20/20 [==============================] - 0s 2ms/step - loss: 0.1686 - accuracy: 0.9422 - val_loss: 0.2847 - val_accuracy: 0.8813
Epoch 29/30
20/20 [==============================] - 0s 3ms/step - loss: 0.1633 - accuracy: 0.9469 - val_loss: 0.2766 - val_accuracy: 0.9000
Epoch 30/30
20/20 [==============================] - 0s 3ms/step - loss: 0.1554 - accuracy: 0.9516 - val_loss: 0.2854 - val_accuracy: 0.9000
7/7 [==============================] - 0s 1ms/step - loss: 0.4057 - accuracy: 0.8200
Test Accuracy: 0.8199999928474426
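For a two-class problem like this one, `Dense(2, activation='softmax')` with `sparse_categorical_crossentropy` and a single `Dense(1, activation='sigmoid')` unit with `binary_crossentropy` describe the same family of models: a two-way softmax over logits (z0, z1) gives class 1 the probability sigmoid(z1 - z0). A quick numerical check of that identity in plain Python, with arbitrary logits:

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    """Logistic function 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + math.exp(-z))

for z0, z1 in [(0.3, 1.2), (-2.0, 0.5), (4.0, -1.0)]:
    p_softmax = softmax([z0, z1])[1]
    p_sigmoid = sigmoid(z1 - z0)
    print(f"{p_softmax:.6f} {p_sigmoid:.6f}")  # the two columns agree
```

The sigmoid variant has a smaller output layer (32 + 1 = 33 parameters instead of 32 × 2 + 2 = 66), but the two choices should reach comparable accuracy.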

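With the steps and code above, you can train a model, evaluate its performance, and diagnose and address the key problems that arise during training.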


