当前位置: 首页 > article >正文

PyTorch C++系列教程1:用 VGG-16 识别 MNIST

PyTorch系列文章目录


文章目录

  • PyTorch系列文章目录
  • 前言
    • 安装
      • CPU 版本:
      • GPU (CUDA 9.0) 版本:
      • GPU (CUDA 10.0) 版本:
  • 一、VGG-16 的网络结构
  • 训练
  • 总结


前言

本文讲解如何用 PyTorch C++ 前端(libtorch)实现 VGG-16 来识别 MNIST 数据集。

安装

首先下载 libtorch:

CPU 版本:

wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-latest.zip -O libtorch.zip

GPU (CUDA 9.0) 版本:

wget https://download.pytorch.org/libtorch/cu90/libtorch-shared-with-deps-latest.zip -O libtorch.zip

GPU (CUDA 10.0) 版本:

wget https://download.pytorch.org/libtorch/cu100/libtorch-shared-with-deps-latest.zip -O libtorch.zip

然后将下载的压缩包解压缩。后面我们将使用解压后的文件夹的绝对路径。

一、VGG-16 的网络结构

(原文此处为 VGG-16 网络结构示意图)

首先引入头文件:

#include <cstdint>
#include <iostream>
#include <memory>
#include <tuple>
#include <utility>

#include <torch/torch.h>

然后实现网络定义:

/* Sample code for training a FCN on MNIST dataset using PyTorch C++ API */
/* This code uses VGG-16 Layer Network */

struct Net: torch::nn::Module {
    // VGG-16 Layer
    // conv1_1 - conv1_2 - pool 1 - conv2_1 - conv2_2 - pool 2 - conv3_1 - conv3_2 - conv3_3 - pool 3 -
    // conv4_1 - conv4_2 - conv4_3 - pool 4 - conv5_1 - conv5_2 - conv5_3 - pool 5 - fc6 - fc7 - fc8
    Net() {
        // Initialize VGG-16
        // On how to pass strides and padding: https://github.com/pytorch/pytorch/issues/12649#issuecomment-430156160
        conv1_1 = register_module("conv1_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 10, 3).padding(1)));
        conv1_2 = register_module("conv1_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(10, 20, 3).padding(1)));
        // Insert pool layer
        conv2_1 = register_module("conv2_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(20, 30, 3).padding(1)));
        conv2_2 = register_module("conv2_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(30, 40, 3).padding(1)));
        // Insert pool layer
        conv3_1 = register_module("conv3_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(40, 50, 3).padding(1)));
        conv3_2 = register_module("conv3_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(50, 60, 3).padding(1)));
        conv3_3 = register_module("conv3_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(60, 70, 3).padding(1)));
        // Insert pool layer
        conv4_1 = register_module("conv4_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(70, 80, 3).padding(1)));
        conv4_2 = register_module("conv4_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(80, 90, 3).padding(1)));
        conv4_3 = register_module("conv4_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(90, 100, 3).padding(1)));
        // Insert pool layer
        conv5_1 = register_module("conv5_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(100, 110, 3).padding(1)));
        conv5_2 = register_module("conv5_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(110, 120, 3).padding(1)));
        conv5_3 = register_module("conv5_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(120, 130, 3).padding(1)));
        // Insert pool layer
        fc1 = register_module("fc1", torch::nn::Linear(130, 50));
        fc2 = register_module("fc2", torch::nn::Linear(50, 20));
        fc3 = register_module("fc3", torch::nn::Linear(20, 10));
    }

    // Implement Algorithm
    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(conv1_1->forward(x));
        x = torch::relu(conv1_2->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv2_1->forward(x));
        x = torch::relu(conv2_2->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv3_1->forward(x));
        x = torch::relu(conv3_2->forward(x));
        x = torch::relu(conv3_3->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv4_1->forward(x));
        x = torch::relu(conv4_2->forward(x));
        x = torch::relu(conv4_3->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv5_1->forward(x));
        x = torch::relu(conv5_2->forward(x));
        x = torch::relu(conv5_3->forward(x));
        x = torch::max_pool2d(x, 2);

        x = x.view({-1, 130});

        x = torch::relu(fc1->forward(x));
        x = torch::relu(fc2->forward(x));
        x = fc3->forward(x);

        return torch::log_softmax(x, 1);
    }

    // Declare layers
    torch::nn::Conv2d conv1_1{nullptr};
    torch::nn::Conv2d conv1_2{nullptr};
    torch::nn::Conv2d conv2_1{nullptr};
    torch::nn::Conv2d conv2_2{nullptr};
    torch::nn::Conv2d conv3_1{nullptr};
    torch::nn::Conv2d conv3_2{nullptr};
    torch::nn::Conv2d conv3_3{nullptr};
    torch::nn::Conv2d conv4_1{nullptr};
    torch::nn::Conv2d conv4_2{nullptr};
    torch::nn::Conv2d conv4_3{nullptr};
    torch::nn::Conv2d conv5_1{nullptr};
    torch::nn::Conv2d conv5_2{nullptr};
    torch::nn::Conv2d conv5_3{nullptr};

    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

训练

接下来我们测试训练网络,我们训练 10 个 epoch,学习率 0.01,使用 nll_loss 损失函数:

/// Trains Net on the MNIST training set with plain SGD and nll_loss,
/// printing the loss and writing a checkpoint to "net.pt" at a fixed interval.
int main() {
    // Training hyper-parameters.
    constexpr size_t kBatchSize = 64;
    constexpr size_t kEpochs = 10;
    constexpr double kLearningRate = 0.01;
    // How often (in batches) to print the loss and save a checkpoint.
    constexpr size_t kLogInterval = 100;

    // MNIST dataset, normalized with the commonly used MNIST mean/std
    // (0.1307 / 0.3081) and stacked into one tensor per batch.
    auto dataset = torch::data::datasets::MNIST("../../data")
                       .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
                       .map(torch::data::transforms::Stack<>());

    // Sequential (non-shuffled) data loader. No worker option is set here,
    // so the loader's default worker configuration is used.
    auto data_loader =
        torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
            std::move(dataset), kBatchSize);

    // Build the VGG-16-style network.
    auto net = std::make_shared<Net>();

    torch::optim::SGD optimizer(net->parameters(), kLearningRate);

    // Make the training mode explicit (net is a shared_ptr, hence "->").
    net->train();

    for (size_t epoch = 1; epoch <= kEpochs; ++epoch) {
        size_t batch_index = 0;
        // Iterate the data loader to yield batches from the dataset.
        for (auto& batch : *data_loader) {
            // Reset gradients accumulated by the previous step.
            optimizer.zero_grad();
            // Execute the model.
            torch::Tensor prediction = net->forward(batch.data);
            // Compute the loss; forward() returns log-probabilities, as nll_loss expects.
            torch::Tensor loss = torch::nll_loss(prediction, batch.target);
            // Compute gradients.
            loss.backward();
            // Update the parameters.
            optimizer.step();

            // Output the loss and checkpoint every kLogInterval batches.
            if (++batch_index % kLogInterval == 0) {
                // std::endl flushes so progress is visible immediately.
                std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                          << " | Loss: " << loss.item<float>() << std::endl;
                torch::save(net, "net.pt");
            }
        }
    }
}

总结

完整代码请参考:
https://github.com/krshrimali/Digit-Recognition-MNIST-SVHN-PyTorch-CPP

参考资料
https://pytorch.org/cppdocs/
http://yann.lecun.com/exdb/mnist/


http://www.kler.cn/a/6117.html

相关文章:

  • 【VUE 指令学习笔记】
  • Python创建GitHub标签的Django管理命令
  • BeanFactory与factoryBean 区别,请用源码分析,及spring中涉及的点,及应用场景
  • C++ 入门第23天:Lambda 表达式与标准库算法入门
  • BloombergGPT: A Large Language Model for Finance——面向金融领域的大语言模型
  • C4D2025 win版本安装完无法打开,提示请将你的maxon App更新至最新版本,如何解决
  • twitter开源算法(1)For You推荐系统架构
  • 10年 “自动化测试” 老鸟,写给 3-5 年测试员的几点建议,满满硬货指导
  • 牛客网Python入门103题练习|(05--运算符(2))
  • Vue3---手写Tree组件
  • leetcode 105.从前序与中序遍历序列构造二叉树
  • 【计算思维题】少儿编程 蓝桥杯青少组计算思维题真题及解析第2套
  • 一篇文章搞定《动手学深度学习》-(李牧)PyTorch版本的所有内容
  • 上班族适合大自考还是小自考?看完你就懂了
  • 【IAR工程】STM8S208RB基于ST标准库窗口看门狗(WWDG)
  • 华为手表开发:WATCH 3 Pro(12)http请求数据到服务器
  • 【Linux】传输层协议 — TCP协议
  • spark第四章:SparkSQL基本操作
  • 校招面试重点汇总之JVM(中大厂必备)
  • demo-helloworld,properties,actuator,admin-server/client
  • 大厂面试篇--2023软件测试八股文最全文档,有它直接大杀四方
  • leaflet绘制具有虚线框的多边形(125)
  • C# | 使用DataGridView展示JSON数组
  • 近万字的超详细C++类和对象(已完结)
  • 【网络应用开发】实验2--JSP技术及应用(HTTP状态400错误的请求的解决方法)
  • PMP一般要提前多久备考?