PyTorch C++系列教程1:用 VGG-16 识别 MNIST
PyTorch系列文章目录
文章目录
- PyTorch系列文章目录
- 前言
- 安装
- CPU 版本:
- GPU (CUDA 9.0) 版本:
- GPU (CUDA 10.0) 版本:
- 一、VGG-16 的网络结构
- 训练
- 总结
前言
本文讲解如何用 PyTorch C++ API(libtorch)实现 VGG-16 来识别 MNIST 数据集。
安装
首先下载 libtorch:
CPU 版本:
wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-latest.zip -O libtorch.zip
GPU (CUDA 9.0) 版本:
wget https://download.pytorch.org/libtorch/cu90/libtorch-shared-with-deps-latest.zip -O libtorch.zip
GPU (CUDA 10.0) 版本:
wget https://download.pytorch.org/libtorch/cu100/libtorch-shared-with-deps-latest.zip -O libtorch.zip
然后将下载的压缩包解压缩。后面我们将使用解压后的文件夹的绝对路径。
一、VGG-16 的网络结构
(原文此处为一张网络结构示意图,图片已缺失。)
首先引入头文件:
#include <torch/torch.h>
然后实现网络定义:
/* Sample code for training a VGG-16-style CNN on the MNIST dataset using the PyTorch C++ API */
/*
 * Net follows the VGG-16 layer layout: thirteen 3x3 convolutions (padding 1) arranged in
 * five groups, each group followed by a 2x2 max-pool, then three fully connected layers.
 * The channel widths are much smaller than in the original VGG-16 (10 ... 130 instead of
 * 64 ... 512), and the first convolution takes a single input channel.
 */
struct Net : torch::nn::Module {
    // Layer order:
    // conv1_1 - conv1_2 - pool 1 - conv2_1 - conv2_2 - pool 2 - conv3_1 - conv3_2 - conv3_3 - pool 3 -
    // conv4_1 - conv4_2 - conv4_3 - pool 4 - conv5_1 - conv5_2 - conv5_3 - pool 5 - fc1 - fc2 - fc3

    /// Creates and registers every layer so that parameters() and torch::save() see them.
    Net() {
        // On how to pass strides and padding:
        // https://github.com/pytorch/pytorch/issues/12649#issuecomment-430156160

        // Group 1 (followed by a pool in forward()).
        conv1_1 = register_module("conv1_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 10, 3).padding(1)));
        conv1_2 = register_module("conv1_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(10, 20, 3).padding(1)));

        // Group 2 (followed by a pool in forward()).
        conv2_1 = register_module("conv2_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(20, 30, 3).padding(1)));
        conv2_2 = register_module("conv2_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(30, 40, 3).padding(1)));

        // Group 3 (followed by a pool in forward()).
        conv3_1 = register_module("conv3_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(40, 50, 3).padding(1)));
        conv3_2 = register_module("conv3_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(50, 60, 3).padding(1)));
        conv3_3 = register_module("conv3_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(60, 70, 3).padding(1)));

        // Group 4 (followed by a pool in forward()).
        conv4_1 = register_module("conv4_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(70, 80, 3).padding(1)));
        conv4_2 = register_module("conv4_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(80, 90, 3).padding(1)));
        conv4_3 = register_module("conv4_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(90, 100, 3).padding(1)));

        // Group 5 (followed by a conditional pool in forward()).
        conv5_1 = register_module("conv5_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(100, 110, 3).padding(1)));
        conv5_2 = register_module("conv5_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(110, 120, 3).padding(1)));
        conv5_3 = register_module("conv5_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(120, 130, 3).padding(1)));

        // Classifier head: 130 features -> 10 outputs.
        fc1 = register_module("fc1", torch::nn::Linear(130, 50));
        fc2 = register_module("fc2", torch::nn::Linear(50, 20));
        fc3 = register_module("fc3", torch::nn::Linear(20, 10));
    }

    /// Runs the network on `x` and returns log-softmax scores taken over dimension 1.
    ///
    /// `x` must have a single channel (conv1_1 takes 1 input channel). The flatten step
    /// below reshapes to rows of 130 values, i.e. it expects the feature map to be 1x1
    /// spatially by that point.
    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(conv1_1->forward(x));
        x = torch::relu(conv1_2->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv2_1->forward(x));
        x = torch::relu(conv2_2->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv3_1->forward(x));
        x = torch::relu(conv3_2->forward(x));
        x = torch::relu(conv3_3->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv4_1->forward(x));
        x = torch::relu(conv4_2->forward(x));
        x = torch::relu(conv4_3->forward(x));
        x = torch::max_pool2d(x, 2);

        x = torch::relu(conv5_1->forward(x));
        x = torch::relu(conv5_2->forward(x));
        x = torch::relu(conv5_3->forward(x));
        // A 28x28 input is already 1x1 here (28 -> 14 -> 7 -> 3 -> 1 after four pools).
        // Pooling a 1x1 map with a 2x2 window would yield an empty output, which
        // max_pool2d rejects, so the fifth pool only runs when a full window fits.
        // Larger inputs (e.g. 32x32, which is 2x2 at this point) are pooled as before.
        if (x.size(-2) >= 2 && x.size(-1) >= 2) {
            x = torch::max_pool2d(x, 2);
        }

        // Flatten to rows of 130 features for the fully connected layers.
        x = x.view({-1, 130});

        x = torch::relu(fc1->forward(x));
        x = torch::relu(fc2->forward(x));
        x = fc3->forward(x);

        // Log-probabilities; pairs with torch::nll_loss during training.
        return torch::log_softmax(x, 1);
    }

    // Layer handles start out empty (nullptr) and are assigned in the constructor.
    torch::nn::Conv2d conv1_1{nullptr};
    torch::nn::Conv2d conv1_2{nullptr};
    torch::nn::Conv2d conv2_1{nullptr};
    torch::nn::Conv2d conv2_2{nullptr};
    torch::nn::Conv2d conv3_1{nullptr};
    torch::nn::Conv2d conv3_2{nullptr};
    torch::nn::Conv2d conv3_3{nullptr};
    torch::nn::Conv2d conv4_1{nullptr};
    torch::nn::Conv2d conv4_2{nullptr};
    torch::nn::Conv2d conv4_3{nullptr};
    torch::nn::Conv2d conv5_1{nullptr};
    torch::nn::Conv2d conv5_2{nullptr};
    torch::nn::Conv2d conv5_3{nullptr};
    torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};
训练
接下来我们测试训练网络,我们训练 10 个 epoch,学习率 0.01,使用 nll_loss 损失函数:
// Trains Net on MNIST with plain SGD and negative log-likelihood loss, printing the
// loss and writing a checkpoint to "net.pt" at a fixed batch interval.
int main() {
    // Commonly used MNIST normalization constants (mean 0.1307, std 0.3081).
    constexpr double kMnistMean = 0.1307;
    constexpr double kMnistStd = 0.3081;
    constexpr size_t kBatchSize = 64;
    constexpr size_t kNumEpochs = 10;
    constexpr double kLearningRate = 0.01;
    // Log the loss and save a checkpoint every kLogInterval batches.
    constexpr size_t kLogInterval = 100;

    // Normalize each image, then stack the samples of a batch into one tensor.
    auto dataset = torch::data::datasets::MNIST("../../data")
                       .map(torch::data::transforms::Normalize<>(kMnistMean, kMnistStd))
                       .map(torch::data::transforms::Stack<>());

    // Sequential (unshuffled) loader; no worker count is configured here, so the
    // loader runs with its default options apart from the batch size.
    auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
        std::move(dataset), kBatchSize);

    // Build the network and its optimizer.
    auto net = std::make_shared<Net>();
    torch::optim::SGD optimizer(net->parameters(), kLearningRate);

    // Make training mode explicit (net is a shared_ptr, hence ->).
    net->train();

    for (size_t epoch = 1; epoch <= kNumEpochs; ++epoch) {
        size_t batch_index = 0;
        // Iterate the data loader to yield batches from the dataset.
        for (auto& batch : *data_loader) {
            // Reset gradients accumulated by the previous step.
            optimizer.zero_grad();
            // Forward pass: log-probabilities per class.
            torch::Tensor prediction = net->forward(batch.data);
            // Negative log-likelihood loss against the target labels.
            torch::Tensor loss = torch::nll_loss(prediction, batch.target);
            // Backward pass.
            loss.backward();
            // Update the parameters.
            optimizer.step();

            if (++batch_index % kLogInterval == 0) {
                std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
                          << " | Loss: " << loss.item<float>() << std::endl;
                // Overwrites the same checkpoint file each time.
                torch::save(net, "net.pt");
            }
        }
    }
}
总结
完整代码请参考:
https://github.com/krshrimali/Digit-Recognition-MNIST-SVHN-PyTorch-CPP
参考资料
https://pytorch.org/cppdocs/
http://yann.lecun.com/exdb/mnist/