Using Gradient Checkpointing
Table of Contents
- 1. PyTorch
  - Example 1
  - Example 2
  - Example 3
- 2. Hugging Face
  - Example
- 3. DeepSpeed
1. PyTorch
torch.utils.checkpoint
Official documentation
- The Gradient Checkpoints section of the PyTorch Training Performance Guide
- Reference blog
Example 1
See the blog.
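For orientation, here is a minimal sketch of the torch.utils.checkpoint API. The toy module, tensor shapes, and layer sizes are made up for illustration, and the use_reentrant keyword assumes a reasonably recent PyTorch release.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Toy block whose intermediate activations we do not want to keep in memory.
block = nn.Sequential(
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
)
head = nn.Linear(128, 10)

x = torch.randn(4, 128)

# checkpoint() runs `block` without storing its intermediate activations;
# they are recomputed during the backward pass instead.
# use_reentrant=False (the non-reentrant variant) lets gradients reach the
# block's parameters even though `x` itself does not require grad.
h = checkpoint(block, x, use_reentrant=False)
loss = head(h).sum()
loss.backward()

The trade-off is extra compute in the backward pass (the checkpointed segment is run twice) in exchange for lower activation memory.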
Example 2
import os

import numpy as np
import torch
import torch.utils.checkpoint  # explicit import for torch.utils.checkpoint.checkpoint below
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 preprocessing and data loaders
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = torchvision.datasets.CIFAR10("/home/zjma/dataset/cifar10/", train=True, transform=transform_train, download=False)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
test_dataset = torchvision.datasets.CIFAR10("/home/zjma/dataset/cifar10/", train=False, transform=transform_test, download=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Baseline model: dropout stays inside the blocks, no checkpointing.
class CIFAR10Model_Original(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25),
        )
        self.cnn_block_2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25),
        )
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.head = nn.Sequential(
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.cnn_block_2(X)
        X = self.flatten(X)
        X = self.head(X)
        return X
# Checkpointed model: the two convolutional blocks are wrapped in
# torch.utils.checkpoint.checkpoint, and the Dropout layers are kept outside
# the checkpointed segments so the stochastic layers are not part of the
# recomputed forward pass.
class CIFAR10Model_Optimized(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_1 = nn.Dropout(0.25)
        self.cnn_block_2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_2 = nn.Dropout(0.25)
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.linearize = nn.Sequential(
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU(),
        )
        self.dropout_3 = nn.Dropout(0.5)
        self.out = nn.Linear(512, 10)

    def forward(self, X):
        # use_reentrant=False (non-reentrant checkpointing) lets gradients reach
        # cnn_block_1 even though the input images do not require grad.
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X, use_reentrant=False)
        X = self.dropout_1(X)
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_2, X, use_reentrant=False)
        X = self.dropout_2(X)
        X = self.flatten(X)
        X = self.linearize(X)
        X = self.dropout_3(X)
        X = self.out(X)
        return X
# clf = CIFAR10Model_Original()
clf = CIFAR10Model_Optimized()
start_epoch = 1
clf.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(clf.parameters(), lr=0.0001, weight_decay=1e-6)
def train():
    clf.train()
    NUM_EPOCHS = 10
    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        losses = []
        for i, (X_batch, y_cls) in enumerate(train_dataloader):
            optimizer.zero_grad()
            y = y_cls.cuda()
            X_batch = X_batch.cuda()
            y_pred = clf(X_batch)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss = loss.item()
            losses.append(train_loss)
            # Memory statistics after each batch
            mem_allocated = torch.cuda.memory_allocated()
            mem_reserved = torch.cuda.memory_reserved()
            max_mem_allocated = torch.cuda.max_memory_allocated()
            max_mem_reserved = torch.cuda.max_memory_reserved()
            if i % 10 == 0:
                print(
                    f'Finished epoch {epoch}/{NUM_EPOCHS}, batch {i}. loss: {train_loss:.3f}. '
                    f'Memory allocated: {mem_allocated / (1024 ** 2):.2f} MB, '
                    f'Memory reserved: {mem_reserved / (1024 ** 2):.2f} MB, '
                    f'Max memory allocated: {max_mem_allocated / (1024 ** 2):.2f} MB, '
                    f'Max memory reserved: {max_mem_reserved / (1024 ** 2):.2f} MB.'
                )
            # Reset peak memory stats for the next iteration
            torch.cuda.reset_peak_memory_stats()
        print(
            f'Finished epoch {epoch}. '
            f'avg loss: {np.mean(losses)}; median loss: {np.median(losses)}'
        )

train()
Before checkpoint optimization (CIFAR10Model_Original):
- Max memory allocated: 69.58 MB
- Max memory reserved: 96.00 MB
After checkpoint optimization (CIFAR10Model_Optimized):
- Max memory allocated: 40.80 MB
- Max memory reserved: 64.00 MB
Example 3
See the project.
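The project code is not reproduced here. As a stand-in, the sketch below shows another common pattern, torch.utils.checkpoint.checkpoint_sequential, which splits an nn.Sequential into segments and checkpoints each one; the depth, layer sizes, and segment count are arbitrary, and use_reentrant again assumes a recent PyTorch release.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# A deep Sequential stack: only the inputs to each of the `segments` chunks
# are stored; everything in between is recomputed during backward.
model = nn.Sequential(*[
    layer
    for _ in range(8)
    for layer in (nn.Linear(256, 256), nn.ReLU())
])

x = torch.randn(16, 256, requires_grad=True)
segments = 4  # more segments -> less activation memory, more recomputation

out = checkpoint_sequential(model, segments, x, use_reentrant=False)
out.sum().backward()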
2. Hugging Face
gradient_checkpointing_enable
Official documentation and usage
- The Gradient Checkpointing section of Methods and tools for efficient training on a single GPU
- The Gradient Checkpointing section of Performance and Scalability: How To Fit a Bigger Model and Train It Faster
- Reference blog 1
- Reference blog 2
Example
See the project.
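The project code is likewise not shown here. The sketch below illustrates the two usual entry points in transformers; the model name, output_dir, and batch size are placeholders for illustration.

from transformers import AutoModelForSequenceClassification, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

# Option 1: turn on gradient checkpointing directly on the model.
model.gradient_checkpointing_enable()

# Option 2: let the Trainer enable it via TrainingArguments.
training_args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=8,
    gradient_checkpointing=True,
)

For decoder/generative models, gradient checkpointing is typically combined with use_cache=False, since caching past key/values conflicts with recomputation.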
3. DeepSpeed
- https://eanyang7.github.io/transformers_docs/main_classes/deepspeed/#activation-checkpointing-gradient-checkpointing
- https://zhuanlan.zhihu.com/p/644656141
- https://blog.csdn.net/Scc_hy/article/details/138728380
- https://huggingface.co/docs/transformers/main/en/perf_train_gpu_one#deepspeed-zero
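The links above cover the details. As a rough sketch, DeepSpeed exposes activation (gradient) checkpointing through an activation_checkpointing section in its config; the fragment below, written as a Python dict, only illustrates the shape of that section, and the surrounding values are placeholders rather than tuned settings. When training Hugging Face models with DeepSpeed, gradient checkpointing is still usually enabled on the model itself (model.gradient_checkpointing_enable() or the --gradient_checkpointing flag), while this section configures DeepSpeed's own checkpointing API.

# Illustrative DeepSpeed config fragment (values are placeholders, not tuned).
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "zero_optimization": {"stage": 2},
    "activation_checkpointing": {
        "partition_activations": False,          # shard checkpointed activations across model-parallel ranks
        "cpu_checkpointing": False,              # offload checkpointed activations to CPU memory
        "contiguous_memory_optimization": False,
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    },
}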