当前位置: 首页 > article >正文

深度学习·wandb

wandb

一个好用的可视化训练过程和调参工具,建议在深度学习中使用,语法来说更加方便

前置工作

这里是一些简单的网络结构,用于测试

数据集:

  • Kaggle上HeartDisease的0-1分类问题
    df=pd.read_csv('../data/heart_attack/heart.csv')

数据集的迭代:

  • X=torch.tensor(X.values,device=config.device,dtype=torch.float32) y=torch.tensor(y.values,device=config.device,dtype=torch.float32).reshape(-1,1) dataset=TensorDataset(X,y) dataloader=DataLoader(dataset,batch_size=config.batch_size,shuffle=True)

简单的DNN

class DNN(nn.Module):
    def __init__(self,input_size,hidden_size,dropout:float):
        super().__init__()
        self.input_size=input_size
        
        self.hidden_size=hidden_size
        
        self.fc1=nn.Linear(self.input_size,self.hidden_size)
        
        self.fc2=nn.Linear(self.hidden_size,self.hidden_size)
        
        self.fc3=nn.Linear(self.hidden_size,1)
        
        self.dropout=nn.Dropout(dropout)
    def forward(self,x):
        x=F.leaky_relu(self.fc1(x))
        x=self.dropout(x)
        x=F.leaky_relu(self.fc2(x))
        x=self.dropout(x)
        x=self.fc3(x)
        return x

wandb监视训练过程

使用login()登陆

import os
os.environ["WANDB_API_KEY"] = "xxxx"
wandb.login(key=os.environ['WANDB_API_KEY'])

初始化wandb

  • 建议使用系统时间:
    current_time = datetime.now()
    standard_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
    name=standard_time
  • 初始化:
    注意保存wand.run.id方便继续监视该模型
    wandb.init(project=config.project_name,name=name,config=config.__dict__)# 转换为dict    
    model_run_id=wandb.run.id

训练流程中记录参数

    for epoch in tqdm(range(config.epochs)):
        for X,y in dataloader:
        # 反向传播
        # 评估指标
        wandb.log({'epoch':epoch+1,'val_acc':val_acc,'best_acc':best_metric})
    wandb.finish()

wandb.log从接口收到对应参数,wandb.finish()完成记录,主要不要漏掉finish

继续训练

  • 提供run.id并将resume设置为must
    wandb.init(project=config.project_name,id=model_run_id,resume='must')

Artifact工件

工件可以是代码也可以是数据集
第一个参数是名称,第二个是类型

wandb.init(project=config.project_name,id=model_run_id,resume='must')
arti_dataset=wandb.Artifact('HeartDisease',type='dataset')
arti_dataset.add_dir('../data/heart_attack/')
wandb.log_artifact(arti_dataset)
```python
arti_code=wandb.Artifact('ipynb',type='code')
arti_code.add_file('./wand_test.ipynb')
wandb.log_artifact(arti_code)
wandb.finish()

Table

可视化分析

wandb.init(project=config.project_name,id=model_run_id,resume='must')
good_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])
bad_cases=wandb.Table(columns=['id','GroundTrue','Prediction'])

在代码中加入如下:

good_cases.add_data(i,y,prediction)

一般是用于比对feature、label和prediction


http://www.kler.cn/a/328597.html

相关文章:

  • 第6章详细设计 -6.7 PCB工程需求表单
  • 开源项目低代码表单设计器FcDesigner获取表单的层级结构与组件数据
  • 主机型入侵检测系统(HIDS)——Elkeid在Centos7的保姆级安装部署教程
  • 抖音热门素材去哪找?优质抖音视频素材网站推荐!
  • 【AlphaFold3】开源本地的安装及使用
  • RTSP播放器EasyPlayer.js播放器UniApp或者内嵌其他App里面webview需要截图下载
  • 自然语言处理问答系统技术
  • html5 + css3(下)
  • STL容器适配器
  • OpenCV 形态学相关函数详解及用法示例
  • 字符串逆序
  • 滚雪球学MySQL[3.3讲]:MySQL复杂查询详解:CASE语句、自连接与视图管理
  • OpenCV视频I/O(11)视频采集类VideoCapture之设置视频捕获设备的属性函数 set()的使用
  • Go语言入门:掌握基础语法与核心概念
  • 决策树的损失函数公式详细说明和例子说明
  • JS+HTML基础
  • 小徐影院:探索Spring Boot的影院管理
  • 您的计算机已被Lockbit3.0勒索病毒感染?恢复您的数据的方法在这里!
  • Windows 上安装 PostgreSQL
  • Qt界面优化——QSS
  • hystrix微服务部署
  • Raft 协议解读:简化分布式一致性
  • 美洽客户服务AI Agent 1.0,全渠道多场景赋能业务增长
  • linux 网络序
  • 快速实现AI搜索!Fivetran 支持 Milvus 作为数据迁移目标
  • 【Linux】进程概念-2