当前位置: 首页 > article >正文

深度学习——回归实战

线性回归:

线性:自变量和应变量之间是线性关系,如:y = wx +b

回归:拟合一条曲线,使真实值和拟合值差距尽可能小

目标:求解参数w和b         所用算法:梯度下降算法

梯度下降:向着梯度方向(下降最快的方向)走一步,不停迭代

梯度下降过程:

训练循环(核心步骤):

  • 前向传播
    • 将训练集中的一批数据(一个批次)输入到模型中,通过模型的各层计算得到输出。
    • 这个过程中,数据按照模型的架构顺序依次通过各层,每层根据其权重和偏置对数据进行计算,如在全连接层中,计算是通过矩阵乘法和加法实现的。
  • 计算损失
    • 使用定义好的损失函数,计算模型预测输出与该批次数据真实标签之间的损失。
    • 损失值反映了当前模型在这一批次数据上的预测误差大小。
  • 反向传播
    • 根据计算得到的损失,通过链式法则从最后一层开始,逐层计算损失对每个模型参数(权重和偏置)的梯度。
    • 反向传播算法使得模型能够知道每个参数对损失的影响程度,从而为参数更新提供依据。
  • 更新参数
    • 使用选择的优化器,根据计算得到的梯度更新模型的参数。例如,在使用 SGD 优化器时,参数更新公式为:参数 = 参数 - 学习率 * 梯度。
    • 更新后的参数将用于下一次迭代的前向传播,通过不断重复这个过程,模型的参数逐渐调整,使得损失函数不断减小。
      import torch #深度学习框架
      import matplotlib.pyplot as plt #画图
      import random #随机
      
      
      def create_data(w, b, data_num):   #生成数据
          x = torch.normal(0, 1, (data_num, len(w)))   #平均数为0,方差为1,长度为data_num,宽度为len(w)
          y = torch.matmul(x, w) + b #通过矩阵乘法将输入数据x与权重w相乘,然后加上偏置项b,生成新的输出y
      
          noise = torch.normal(0, 0.01, y.shape)   #噪声要加到y上
          y+= noise      #模拟真实数据的不确定性、防止模型过拟合
      
          return x, y
      
      num = 500   # 生成的数据数量
      
      true_w = torch.tensor([8.1, 2, 2, 4])  # 真实的权重
      true_b = torch.tensor(1.1)   # 真实的偏置
      
      
      X, Y = create_data(true_w, true_b, num)  # 生成数据
      
      plt.scatter(X[:, 0], Y, 1)  #对x张量进行切片,选择所有行、第一列
      plt.show()
      
      def data_provider(data, label, batchsize):   #每次访问这个函数,就能提供一批数据,传入参数依次为:数据、标签、步长
          length = len(label)   # 获取标签数据的长度,由于数据和标签一一对应,以此代表整个数据集的长度
          indices = list(range(length))  #  创建一个从0到length - 1的索引列表,每个索引对应数据集中的一个样本,用于后续操作
          random.shuffle(indices)  # 对索引列表进行随机打乱,确保每次取数据批次时是随机顺序,避免数据顺序依赖,增强模型训练效果
      
          # 按照指定的批量大小batchsize遍历整个数据集,每次取出一个批次的数据范围
          for each in range(0, length, batchsize):
              get_indices = indices[each: each+batchsize]  # 从打乱后的索引列表中取出当前批次对应的索引范围
              get_data = data[get_indices]   # 根据取出的索引范围从数据张量data中获取当前批次的数据
              get_label = label[get_indices]  # 根据取出的索引范围从标签张量label中获取当前批次对应的标签
      
              yield get_data, get_label  # 使用yield关键字返回当前批次的数据和标签,使函数成为生成器,下次调用继续返回下一批次
      
      batchsize = 16  #步长设置为16
      # for batch_x, batch_y in data_provider(X, Y, batchsize):
      #     print(batch_x, batch_y)
      #     #break
      
      # 定义函数fun,用于根据输入数据x、权重w和偏置b进行线性变换计算,得到预测输出
      def fun(x, w, b):
          pred_y = torch.matmul(x, w) + b   # 使用torch.matmul对输入数据x和权重w进行矩阵乘法运算,然后加上偏置b,得到预测输出pred_y
          return pred_y
      
      # 定义函数maeloss,用于计算预测值pre_y和真实值y之间的平均绝对误差(MAE)损失
      def maeloss(pre_y, y):
          return torch.sum(abs(pre_y-y))/len(y)  # 先计算预测值和真实值之间差值的绝对值,再对所有差值的绝对值求和,最后除以数据数量len(y),得到平均绝对误差作为损失值并返回
      
      # 定义函数sgd,实现随机梯度下降算法,用于更新模型参数
      def sgd(paras, lr):          #随机梯度下降,更新参数
          # 使用torch.no_grad()上下文管理器,在这个范围内的操作不会进行梯度计算,因为参数更新阶段不需要对更新操作本身计算梯度
          with torch.no_grad():  #属于这句代码的部分,不计算梯度
              for para in paras:# 遍历要更新的参数列表paras中的每一个参数
                  # 根据随机梯度下降算法规则,将当前参数para减去其梯度para.grad与学习率lr的乘积,实现参数更新(注意要用 -= 操作符进行原位更新)
                  para -= para.grad* lr  #不能写成  para = para - para.grad*lr
                  para.grad.zero_()     #使用过的梯度,归0,避免下一次迭代时梯度累积,导致参数更新错误
      
      lr = 0.03  #学习率
      w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)   # 使用torch.normal函数按照正态分布来初始化权重参数w_0,其中均值为0,方差为0.01,形状与之前定义的真实权重true_w保持一致,并且设置requires_grad为True,意味着这个参数在后续的计算中需要跟踪计算梯度,以便进行自动求导来更新它
      b_0 = torch.tensor(0.01, requires_grad=True)  # 初始化偏置参数b_0,将其设置为值是0.01的标量张量,同时设置requires_grad为True,这样该参数就能参与到梯度计算以及后续的参数更新过程中
      print(w_0, b_0)
      
      epochs  = 50#训练多少轮
      
      # 按训练轮数epochs循环,每轮训练模型
      for epoch in range(epochs):
          data_loss = 0   # 初始化本轮累计损失为0
          for batch_x, batch_y in data_provider(X, Y, batchsize):   # 按批次获取数据,循环处理每批次
              pred_y = fun(batch_x, w_0, b_0)   # 用当前参数预测本批次数据,得预测值
              loss = maeloss(pred_y, batch_y)  # 计算预测值与真实值的损失
              loss.backward()  # 反向传播求梯度
              sgd([w_0, b_0], lr)  # 用sgd更新模型参数
              data_loss += loss  # 累加本批次损失到本轮累计损失
      
          print("epoch %03d: loss: %.6f"%(epoch, data_loss))  # 打印本轮轮数和累计损失,观察训练情况
      
      print("真实的函数值是", true_w, true_b)
      print("训练得到的参数值是", w_0, b_0)
      
      idx = 0  # 初始化一个索引变量idx为0,这个索引通常用于选择数据张量X中的某一列数据,后续可能用于可视化等相关操作
      plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())  # 绘制拟合直线,取X第idx列数据转numpy作横坐标,按线性回归公式用训练参数算纵坐标来绘制
      plt.scatter(X[:, idx], Y, 1)  # 绘制散点图,以输入数据张量X的第idx列数据作为横坐标,以对应的输出数据Y作为纵坐标,展示原始数据的分布情况
      plt.show()
      Tips:
      1、yield get_data, get_label
        当函数执行到 yield 语句时,函数会暂停执行,将 yield 后面的值返回给调用者,但函数并没有结束。下次调用这个函数时,它会从上次暂停的地方继续执行,直到遇到下一个 yield 或者函数结束。这意味着在一个生成器函数中,可以通过多个 yield 语句多次返回不同的值。就像在 data_provider 函数中,每次调用会返回一个新的数据批次和标签批次,直到所有批次都返回完。
      2、para.grad.zero_() 
        如果我们不清零这个梯度,在第二次训练批次进行反向传播时,新计算出来的梯度会和第一次遗留下来的梯度相加。就好像你在走迷宫,第一次得到的指示(梯度)是向左走三步,但是你没记住这个指示,第二次又得到一个指示(新的梯度)是向右走两步,但是你把两次的指示混在一起,变成了向左走一步(假设梯度相加的情况),这样就会让你的方向(参数更新方向)变得混乱。
      3、plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
        ①绘制一条直线,用于表示拟合的线性关系(在二维平面上展示线性回归拟合情况)。
        ②首先从输入数据张量X中取出第idx列数据(通过X[:, idx]),并将其转换为numpy数组(使用detach().numpy()方法,目的是从计算图中分离出来并转为numpy格式方便绘图),作为横坐标。
       然后根据线性回归的公式y = w * x + b,计算对应的纵坐标,这里使用当前训练得到的权重w_0的第idx个元素(w_0[idx])和偏置b_0,同样转换为numpy数组后参与计算,以此绘制出拟合的直线。
      

http://www.kler.cn/a/472841.html

相关文章:

  • WebSocket监听接口
  • Openwrt @ rk3568平台 固件编译实践(二)- ledeWRT版本
  • 杭州市有哪些大学能够出具论文检索报告?
  • 【大模型】百度千帆大模型对接LangChain使用详解
  • 【Axios使用手册】如何使用axios向后端发送请求并进行数据交互
  • Lianwei 安全周报|2024.1.7
  • rust学习——环境搭建
  • 海思Linux-DEMO(1)-sample_venc(h265,h264)视频流文件的获取
  • TRAVEO™ T2G的SWAP功能
  • 服务器及MySQL安全设置指南
  • 使用Postman进行Base64解码
  • 使用 Rust 实现零拷贝数据处理:性能优化的极致探索
  • 如何监听Vuex数据的变化?
  • 第四届智能系统、通信与计算机网络国际学术会议(ISCCN 2025)
  • 虚拟机配置静态ip后出现两个ip问题
  • 单片机毕业设计项目分享(4)
  • 实验四 数组和函数
  • Mysql--基础篇--事务(ACID特征及实现原理,事务管理模式,隔离级别,并发问题,锁机制,行级锁,表级锁,意向锁,共享锁,排他锁,死锁,MVCC)
  • 深入学习RabbitMQ的Direct Exchange(直连交换机)
  • 常见的http状态码 + ResponseEntity
  • 设计模式 结构型 桥接模式(Bridge Pattern)与 常见技术框架应用 解析
  • 【centos8 ES】Elasticsearch linux 龙晰8操作系统安装
  • Flink分区方式有哪些
  • Unity:删除注册表内的项目记录
  • 05、Docker学习,常用安装:Mysql、Redis、Nginx、Nacos
  • springboot点餐平台网站