当前位置: 首页 > article >正文

【单层神经网络】基于MXNet库简化实现线性回归

写在前面

同最开始的两篇文章

完整程序及注释

'''
导入使用的库
'''
# 基本
from mxnet import autograd, nd, gluon
# 模型、网络
from mxnet.gluon import nn                     
from mxnet import init
# 学习
from mxnet.gluon import loss as gloss
# 数据集
from mxnet.gluon import data as gdata
'''
生成测试数据集
'''
# 被拟合参数
true_w = [2, -3.4]      # 特征的权重系数
true_b = 4.2            # 整体模型的偏置
# 创建训练数据集
num_inputs = 2          
num_examples = 1000
features = nd.random.normal(loc=0, scale=1, shape=(num_examples, num_inputs))  # 均值为0,标准差为1
labels = true_w[0]*features[:,0] + true_w[1]*features[:,1] + true_b
labels_noise = labels + nd.random.normal()
'''
确定模型
'''
net = nn.Sequential()                       # 声明一个Sequential容器,存放Neural Network
net.add(nn.Dense(1))                        # 向容器中添加一个全连接层,且不使用激活函数,“1”表示该全连接层的输出神经元有1个
net.initialize(init.Normal(sigma=0.01))     # 权重参数随机取自均值=0,标准差=0.01的高斯分布,bias默认=0
'''
确定学习方式
'''
loss = gloss.L2Loss()       # L2范数损失 等价于 平方损失
# .collect_params()方法获取net实例的全部参数,并提供给trainer
# 选择小批量随机梯度下降法(sgd)寻优,学习率为0.03
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})
'''
数据集采样
'''
batch_size = 10
dataset = gdata.ArrayDataset(features, labels_noise)        # 将标签和特征组合成完整数据集
# DataLoader返回一个迭代器,每次从数据集中提取一个长度为batch_size的子集出来
data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True) # shuffle=True 打乱数据集(随机采样)
'''
开始训练
'''
num_epoch = 3       # 训练轮次
for epoch in range(0, num_epoch):
    for x, y in data_iter:          # 随机取出一组小批量,同时做到遍历
        with autograd.record():     # 自动保存梯度数据
            l = loss(net(x), y)     # 将得到的一组特征放入网络,求得到的输出与对应的标签(含噪声)的损失
        l.backward()                # 计算该次损失的梯度
        trainer.step(batch_size)    # 反向传播,基于l.backward()得到的梯度来更新模型的参数
    l = loss(net(features), labels_noise)     # 该轮训练结束后,求网络对数据集特征的输出,再求输出和含噪声标签的损失
    print('epoch %d, mean loss: %f' % (epoch+1, l.mean().asnumpy()))  # 展示训练轮次和数据集损失的平均

具体函数解释

trainer.step(batch_size):batch_size指定了当前批的大小,用于计算这次梯度下降的步长

with autograd.record():这行代码的作用是在其作用域内的计算将会被记录下来,以便自动求导


http://www.kler.cn/a/529339.html

相关文章:

  • vue2项目(一)
  • x86-64数据传输指令
  • Pandas基础07(Csv/Excel/Mysql数据的存储与读取)
  • 异步编程进阶:Python 中 asyncio 的多重应用
  • Ubuntu 18.04安装Emacs 26.2问题解决
  • 表格结构标签
  • Python sider-ai-api库 — 访问Claude、llama、ChatGPT、gemini、o1等大模型API
  • ollama和deepseek-r1-1.5b和AnythingLLM
  • Java 有很多常用的库
  • 【4. C++ 变量类型详解与创新解读】
  • UI线程用到COM只能选单线程模型
  • [CVPR 2024] AnyDoor: Zero-shot Object-level Image Customization
  • 17.2 图形绘制7
  • ES的机架感知-Rack Awareness
  • kimi,天工,gpt,deepseek效果对比
  • 【Arxiv 大模型最新进展】TOOLGEN:探索Agent工具调用新范式
  • python 从知网的期刊导航页面抓取与农业科技相关的数据
  • 网络测试工具
  • 前端学习-事件委托(三十)
  • 简单易懂的倒排索引详解
  • 仿真设计|基于51单片机的温湿度、一氧化碳、甲醛检测报警系统
  • AI 计算的未来:去中心化浪潮与全球竞争格局重塑
  • 迪杰斯特拉(Dijkstra)算法
  • “新月之智”智能战术头盔系统(CITHS)
  • 抖♬♬__ac_signature 算法逆向分析
  • mybatis辅助配置