Lightning初探
portch-lightning是pytorch的抽象和包装,它的好处是可复用性强,易维护,逻辑清晰等。
学习使用portch-lightning可以使我们专注于模型,而不是其他的重复脏活。
一、传统的pytorch训练流程
# 模型
model = Modle(nn.Module)
# datasets
train_data_loader = DataLoader()
val_data_loader = DataLoader()
# 优化器
optimizer = torch.optim.SGD(model.parameters,lr)
for epoch in range(num_epoches):
# 设置为训练模式
model.train()
for batch in train_data_loader:
# 提取数据
x,target = batch
# 前向传播
y = model.forward(x)
# 计算损失
loss = MSELoss(y,target)
# 梯度清零
optimizer.zero_grad()
# 反向传播
loss.backward()
# 参数优化
optimizer.step()
# 设置为评估模式
model.eval()
# 停止梯度更新
with torch.no_grad():
Y = []
Y_target = []
for batch in val_data_loader:
# 提取数据
x,target = batch
# 前向传播
y = model.forward(x)
# 存储结果
Y_target.append(target)
Y.append(y)
# 统一计算分数
score = evaluation(Y_target,Y)
这个流程除了模型需要改之外,其他的部分每次都是相同的,需要做很多重复的工作,且加入可视化的逻辑越多,代码就越臃肿,越繁杂。
二、新的portch-lightning的训练流程
"""
原本定义模型时继承的是nn.Module,但现在换成了L.LightningModule
其相对于nn.Module又把部分训练过程封装成了函数
最后又把整体流程封装成了一个大函数
最后实现了:模型定义好之后,只需要调一个大函数就可以完成训练流程
"""
class model(L.LightningModule):
def init(self):
pass
def forward(self):
pass
def configure_optimizers(self):
optimizer = torch.optim.SGD(parameters,lr)
return optimizer
def train_step():
loss = []
return loss
def evaluation_step():
loss = []
"""
函数替代后的具体过程
"""
# 模型(替换,包含优化器)
model
# datasets(没有变化)
train_data_loader = DataLoader()
val_data_loader = DataLoader()
for epoch in range(num_epoches):
# 设置为训练模式
model.train()
for batch in train_data_loader:
# 训练步骤
loss = train_step(self,batch,batch_index)
# 反向传播
backward()
# 参数优化
optimizers_step()
# 设置为评估模式
model.eval()
# 停止梯度更新
with torch.no_grad():
val_loss_list = []
for batch in val_data_loader:
# 评估步骤
val_loss = evaluation_step(self,batch,batch_index)
# 保存结果
val_loss_list.append(val_loss)
# 统一计算分数
score = mean(val_loss_list)
"""
把替代后的流程封装为函数
"""
model = model()
trainer = L.Trainer("GPU")
trainer.fit(model,train_data_loader,val_data_loader)