PyTorch 张量数据类型定义和转换
在 PyTorch 中,张量(Tensor)是进行深度学习和数值计算的核心数据结构。张量的数据类型(dtype)定义了张量中存储的数值类型,例如浮点数、整数等。正确地定义和转换张量的数据类型对于确保计算的精度和效率非常重要。以下是关于 PyTorch 张量数据类型定义和转换的详细介绍:
1. PyTorch 支持的数据类型
PyTorch 提供了多种数据类型,主要包括以下几类:
浮点类型
-
torch.float32
或torch.float
:32位浮点数,是默认的浮点类型。 -
torch.float64
或torch.double
:64位浮点数,精度更高。 -
torch.float16
或torch.half
:16位浮点数,适用于半精度计算,常用于 GPU 加速。 -
torch.bfloat16
:16位脑浮点数,适合某些特定的硬件加速。
整数类型
-
torch.int8
:8位有符号整数。 -
torch.uint8
:8位无符号整数。 -
torch.int16
:16位有符号整数。 -
torch.int32
:32位有符号整数。 -
torch.int64
:64位有符号整数。
其他类型
-
torch.bool
:布尔类型,值为True
或False
。 -
torch.complex64
和torch.complex128
:复数类型,分别对应单精度和双精度复数。
2. 定义张量数据类型
在创建张量时,可以通过 dtype
参数指定数据类型。以下是一些示例:
import torch
# 创建一个浮点型张量
tensor_float = torch.tensor([1.0, 2.0, 3.0