当前位置: 首页 > article >正文

从0书写一个softmax分类 李沐pytorch实战

 输出维度

在softmax 分类中 我们输出与类别一样多。 数据集有10个类别,所以网络输出维度为10。

 初始化权重和偏置

torch.norma 生成一个均值为 0,标准差为0.01,一个形状为size=(num_inputs, num_outputs)的张量

偏置生成一个num_outputs =10 的一维张量,并用0初始化 

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)###
b = torch.zeros(num_outputs, requires_grad=True)

requires_grad=True,PyTorch 会在后向传播过程中自动计算该张量的梯度,这对于优化模型参数非常重要。

sum 运算符工作机制:

X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
X.sum(0, keepdim=True), X.sum(1, keepdim=True)

sum = 0 张量按列求和 sum = 1 张量按行求和 

定义softmax函数 

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

测试

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
print(X_prob)
print(X_prob.sum(1))

定义sofrmax模型:

softmax 回归模型 定义

def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

 W.shape[0]表示权重的第一维大小

reshape 函数会根据原始张量 X 的元素总数和你提供的其他维度来计算出 -1 代表的维度

定义交叉熵损失函数:

回顾

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(y_hat[[0, 1], y])

y 张量是两个真实类别,第0类和第二类

y_hat 是对两个类别 在三种类别上的预测,真实的第0类预测结果为0.1,第2类预测结果为0.5

输出,取出y_hat的指定索引,

  • 对于第一个样本(索引 0),取 y_hat[0, 0],即 0.1
  • 对于第二个样本(索引 1),取 y_hat[1, 2],即 0.5
def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

print(f'交叉熵损失为{cross_entropy(y_hat, y)}')

交叉熵损失为tensor([2.3026, 0.6931])

定义分类精度:

计算出 正确预测数量与总预测数量之比

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
print(accuracy(y_hat, y) / len(y))
  • y_hat 的形状为 (2, 3),这意味着:
    • 第一维的大小是 2,表示有 2 个样本(行)。
    • 第二维的大小是 3,表示每个样本有 3 个类别的预测概率(列)。

因此,y_hat.shape[1] 返回的是第二维的大小,也就是 3。这个值表示每个样本的类别数。在这个例子中,y_hat 中的每一行包含了对应样本对 3 个类别的预测概率。

print :y_hat.argmax(axis=1) 预测值张量为[2,2],与y = torch.tensor([0, 2])做对比

将布尔张量转换为整型张量

将[False,True],转换为0,1形式

(cmp.type(y.dtype).sum())

报错RuntimeError: DataLoader worker (pid(s) 12452, 3084, 29000, 29444) exited unexpectedly解决方法:

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
test_iter.num_workers = 0
train_iter.num_workers = 0

再训练迭代器和测试迭代器后加入

定义评估模型准确率函数:

def evaluate_accuracy(net, data_iter):  #@save
    """计算在指定数据集上模型的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

 if isinstance(net, torch.nn.Module)是一个Python内置函数,用于检查对象net是否是torch.nn.Module类的实例。

  • 创建一个Accumulator对象,用于累加正确预测的数量和预测的总数量。Accumulator类会存储两个值。
  • 调用模型net对输入X进行预测,得到预测结果,然后使用accuracy函数计算预测的准确数量。y.numel()返回标签y中的元素总数(即样本数量),这两个值一起传递给metric.add()进行累加。

代码分析:

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
  • zip函数将self.data(当前存储的累积值)和args(输入的多个参数)配对。假设self.data是一个包含n个元素的列表,而args也是一个包含n个元素的可变参数列表。
  • 例如,如果self.data = [3.0, 5.0],而args = (2, 1)zip会生成[(3.0, 2), (5.0, 1)]的迭代器。
  • [a + float(b) for a, b in zip(self.data, args)]是一个列表推导式,用于遍历zip生成的配对。在这个过程中:
    • aself.data中的当前元素。
    • bargs中的当前元素。
  • 对每一对(a, b),该表达式计算a + float(b),将b转换为浮点数并与a相加。

 预测:

def predict_ch3(net, test_iter, n=9):  #@save
    """预测标签(定义见第3章)"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
predict_ch3(net, test_iter)
d2l.plt.show()

 

 


http://www.kler.cn/a/312085.html

相关文章:

  • 车-路-站-网”信息耦合的汽车有序充电
  • 论文阅读《BEVFormer v2》
  • 什么是数字图像?
  • gdb编译教程(支持linux下X86和ARM架构)
  • SpringBoot(八)使用AES库对字符串进行加密解密
  • RS®SZM 倍频器
  • 《深入了解 Linux 操作系统》
  • Scrapy爬虫框架 Pipeline 数据传输管道
  • K8S容器实例Pod安装curl-vim-telnet工具
  • 人工智能在鼻咽癌中的应用综述|文献精析·24-09-13
  • Python中使用Redis布隆过滤器
  • 苹果为什么不做折叠屏手机?
  • 2024蓝桥杯省B好题分析
  • vulnhub靶机:Holynix: v1
  • GO CronGin
  • 【Flask教程】 flask安装简明教程
  • Visual Studio配置opencv环境
  • Web Worker 简单使用
  • 2024永久激活版 Studio One 6 Pro for mac 音乐创作编辑软件 完美兼容
  • 基于STM32设计的路灯故障定位系统(微信小程序)(229)
  • flink自定义process,使用状态求历史总和(scala)
  • spring boot启动报错:so that it conforms to the canonical names requirements
  • 【系统架构设计师-2017年真题】案例分析-答案及详解
  • C# Socket网络通信【高并发场景】
  • 【QT】重载信号Connect链接使用方式
  • cuda中使用二维矩阵