当前位置: 首页 > article >正文

batc和mini-batch

一、概念介绍

batch

批处理,在机器学习中,batch 是指一次处理整个训练数据集的方式。例如,如果有 1000 个训练样本,使用 batch 训练时,模型会同时使用这 1000 个样本进行一次参数更新。也就是说,计算损失函数(如均方误差、交叉熵等)是基于整个数据集的所有样本。

mini-batch

小批次,将整个训练数据集分成多个较小的子集(批次)来进行训练。比如还是 1000 个训练样本,我们可以将其分成 10 个 mini - batch,每个 mini - batch 包含 100 个样本。模型在训练时,每次使用一个 mini - batch 来计算损失和更新参数。

二、区别

batch 训练参数更新方向更稳定但可能陷入局部最优;mini - batch 在训练中有一定随机性,有助于寻找全局最优,但批次过小时可能使训练不稳定。

三、使用场景

数据量较小:使用batch;
数据量较大:使用mini-batch;在神经网络基本使用这个

四、mini-batch代码

以下是在深度学习模型中使用batch和mini - batch的方法:

1. 数据准备阶段

  • 数据加载:首先,需要将原始数据加载到程序中。对于图像数据,可以使用ImageDataLoader(PyTorch中)等工具;对于文本数据,可以使用DataLoader结合自定义的文本处理函数。例如,在PyTorch中:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# 加载MNIST数据集
train_data = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
test_data = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
  • 划分批次(针对mini - batch):使用数据加载器将数据集划分为指定大小的批次。例如,继续上面的代码,设置batch_size为64来创建训练集和测试集的数据加载器:
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

这里batch_size参数决定了每个mini - batch的样本数量,shuffle参数用于在每个训练轮次(epoch)开始时是否打乱数据顺序,对于mini - batch训练通常设置为True以增加随机性;对于测试集,一般不需要打乱数据。

2. 模型训练阶段

  • 使用mini - batch进行训练:在训练循环中,每次从数据加载器中获取一个mini - batch的数据进行训练。以下是一个典型的使用PyTorch训练神经网络的示例:
# 假设model是已经定义好的模型,criterion是损失函数,optimizer是优化器
for epoch in range(num_epochs):
    for i, (x_batch, y_batch) in enumerate(train_loader):
        # 前向传播
        outputs = model(x_batch)
        loss = criterion(outputs, y_batch)

        # 反向传播和更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')

在这个示例中,train_loader每次迭代返回一个mini - batch的输入数据x_batch和对应的标签y_batch。模型使用这些数据进行前向传播计算预测值,然后计算损失,接着进行反向传播更新模型参数。

enumerate将一个可遍历的数据对象(如列表、元组、字符串或迭代器)组合为一个索引序列。

3. 模型评估阶段

  • 使用mini - batch评估:在测试循环中,使用测试数据加载器以mini - batch的方式获取数据进行评估。例如:
correct = 0
total = 0
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        outputs = model(x_batch)
        _, predicted = torch.max(outputs.data, 1)
        total += y_batch.size(0)
        correct += (predicted == y_batch).sum().item()

accuracy = correct / total
print(f'Test Accuracy: {accuracy}')

这里使用test_loader以mini - batch方式获取测试数据,对每个mini - batch进行预测,并统计正确预测的样本数量,最后计算模型在整个测试集上的准确率。


http://www.kler.cn/a/375050.html

相关文章:

  • 数字工厂管理系统就是ERP系统吗
  • QT-基础-1-Qt 中的字符串处理与常见数据类型
  • 深入解析:Python中的决策树与随机森林
  • log4j2漏洞复现(CVE-2021-44228)
  • 突发!!!GitLab停止为中国大陆、港澳地区提供服务,60天内需迁移账号否则将被删除
  • ajax中get和post的区别,datatype返回的数据类型有哪些?web开发中数据提交的几种方式,有什么区别。
  • 苹果开发 IOS 证书生成步骤
  • HT71672 13V,12A全集成同步升压转换器
  • Linux系统块存储子系统分析记录
  • stm32不小心把SWD和JTAG都给关了,程序下载不进去,怎么办?
  • CSS--导航栏案例
  • Python小白学习教程从入门到入坑------第十七课 内置函数拆包(语法基础)
  • 100种算法【Python版】第30篇——IDA*算法
  • Altium Designer使用技巧(一)
  • 向量数据库:PGVector 为AI知识库做准备
  • qt QRadioButton详解
  • 人工智能:改变未来生活与工作的无尽可能
  • 汽车免拆诊断案例 | 2010款起亚赛拉图车发动机转速表指针不动
  • Doris集群搭建
  • 服务器被攻击黑洞后如何自救
  • Debian下载ISO镜像的方法
  • yum不能使用: cannot find a valid baseurl for repo: base/7/x86_64
  • ASP.NET创建网站(一):创建新项目login页面设计
  • Gradio DataFrame分页功能详解:从入门到实战
  • 你的网站需要防护吗?
  • linux使用jar包部署solr