当前位置: 首页 > article >正文

深度学习网格搜索实战

还是使用房价数据集进行实战。因为模型简单,使用超参数搜索的时候速度快。

在之前的回归代码的基础上加入for循环:

for lr in [1e-2, 3e-2, 3e-1, 1e-3]: # 把参数组合放在这,参数代表学习率
    #每次拿一个参数就要重新实例化一个模型
    epoch = 100
    model = NeuralNetwork()

    # 1. 定义损失函数 采用MSE损失
    loss_fct = nn.MSELoss()
    # 2. 定义优化器 采用SGD
    # Optimizers specified in the torch.optim package
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # 3. early stop
    early_stop_callback = EarlyStopCallback(patience=10, min_delta=1e-3)

    model = model.to(device)
    record = training(
        model, 
        train_loader, 
        val_loader, 
        epoch, 
        loss_fct, 
        optimizer, 
        early_stop_callback=early_stop_callback,
        eval_step=len(train_loader)
        )
    print("lr: {}".format(lr))
    plot_learning_curves(record)
    model.eval()
    loss = evaluating(model, val_loader, loss_fct)
    print(f"loss:     {loss:.4f}")

效果:


http://www.kler.cn/a/576537.html

相关文章:

  • GPO 配置的 4 种常见安全错误及安全优化策略
  • 机器学习(李宏毅)——Life-Long Learning
  • 在 Docker 中安装并配置 Nginx
  • Service 无法访问后端 Pod,如何逐步定位问题
  • C++ `bitset` 入门指南
  • 蓝桥杯P17153-班级活动 题解
  • Elastic如何获取当前系统时间
  • 【算法方法总结·四】字符串操作的一些技巧和注意事项
  • Chrome 扩展开发:Chrome 扩展的作用和开发意义(一)
  • Ollama 框架本地部署教程:开源定制,为AI 项目打造专属解决方案!
  • 网络与网络安全
  • 每天记录一道Java面试题---day28
  • 3.6 登录认证
  • el-table一格两行;概率;find
  • 面向服务的架构风格
  • P63 C++当中的计时
  • Vim复制内容到系统剪切板
  • 深入HarmonyOS NEXT开发中的ArkData操作SQLite数据库
  • 如何收集 Kubernetes 集群的日志
  • 在 k8s中查看最大 CPU 和内存的极限