深度学习·wandb
wandb
一个好用的可视化训练过程和调参工具,建议在深度学习中使用,语法来说更加方便
前置工作
这里是一些简单的网络结构,用于测试
数据集:
- Kaggle上HeartDisease的0-1分类问题
df=pd.read_csv('../data/heart_attack/heart.csv')
数据集的迭代:
X=torch.tensor(X.values,device=config.device,dtype=torch.float32) y=torch.tensor(y.values,device=config.device,dtype=torch.float32).reshape(-1,1) dataset=TensorDataset(X,y) dataloader=DataLoader(dataset,batch_size=config.batch_size,shuffle=True)
简单的DNN
class DNN(nn.Module):
def __init__(self,input_size,hidden_size,dropout:float):
super().__init__()
self.input_size=input_size
self.hidden_size=hidden_size
self.fc1=nn.Linear(self.input_size,self.hidden_size)
self.fc2=nn.Linear(self.hidden_size,self.hidden_size)
self.fc3=nn.Linear(self.hidden_size,1)
self.dropout=nn.Dropout(dropout)
def forward(self,x):
x=F.leaky_relu(self.fc1(x))
x=self.dropout(x)
x=F.leaky_relu(self.fc2(x))
x=self.dropout(x)
x=self.fc3(x)
return x
wandb监视训练过程
使用login()登陆
import os
os.environ["WANDB_API_KEY"] = "xxxx"
wandb.login(key=os.environ['WANDB_API_KEY'])
初始化wandb
- 建议使用系统时间:
current_time = datetime.now()
standard_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
name=standard_time
- 初始化:
注意保存wand.run.id方便继续监视该模型
wandb.init(project=config.project_name,name=name,config=config.__dict__)# 转换为dict
model_run_id=wandb.run.id
训练流程中记录参数
for epoch in tqdm(range(config.epochs)):
for X,y in dataloader:
# 反向传播
# 评估指标
wandb.log({'epoch':epoch+1,'val_acc':val_acc,'best_acc':best_metric})
wandb.finish()
wandb.log
从接口收到对应参数,wandb.finish()
完成记录,主要不要漏掉finish
继续训练
- 提供run.id并将
resume
设置为must
wandb.init(project=config.project_name,id=model_run_id,resume='must')
Artifact工件
工件可以是代码也可以是数据集
第一个参数是名称,第二个是类型
wandb.init(project=config.project_name,id=model_run_id,resume='must')
arti_dataset=wandb.Artifact('HeartDisease',type='dataset')
arti_dataset.add_dir('../data/heart_attack/')
wandb.log_artifact(arti_dataset)
```python
arti_code=wandb.Artifact('ipynb',type='code')
arti_code.add_file('./wand_test.ipynb')
wandb.log_artifact(arti_code)
wandb.finish()
Table
可视化分析
wandb.init(project=config.project_name,id=model_run_id,resume='must')
good_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])
bad_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])
在代码中加入如下:
good_cases.add_data(i,y,prediction)
一般是用于比对feature、label和prediction