PyTorch使用教程- Tensor包
### PyTorch使用教程- Tensor包
PyTorch是一个流行的深度学习框架,它提供了一个易于使用的API来创建和操作张量(Tensors)。张量是一个多维数组,类似于NumPy中的ndarray,但它是基于GPU的,支持自动求导。本文将详细介绍PyTorch中的Tensor包,包括张量的创建、运算、形状变换、索引与切片、以及重要的张量处理方式。
#### 一、张量的创建
在PyTorch中,可以使用多种方法创建张量。以下是一些常用的创建张量的方法:
1. **torch.tensor()**
使用`torch.tensor()`函数可以直接从数据中创建张量。数据类型会自动推断。
```python
import torch
data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)
print(x_data)
```
2. **torch.from_numpy()**
如果有一个NumPy数组,可以使用`torch.from_numpy()`函数将其转换为PyTorch张量。
```python
import numpy as np
np_array = np.array(data)
x_np = torch.from_numpy(np_array)
print(x_np)
```
3. **torch.ones_like() 和 torch.zeros_like()**
使用`torch.ones_like()`和`torch.zeros_like()`函数可以创建一个与给定张量形状相同但所有元素分别为1和0的新张量。
```python
x_ones = torch.ones_like(x_data)
print(x_ones)
x_zeros = torch.zeros_like(x_data)
print(x_zeros)
```
4. **torch.rand() 和 torch.randn()**
`torch.rand()`函数创建一个形状为指定维度的张量,其元素是从[0, 1)区间均匀分布的随机数。`torch.randn()`函数则创建一个形状为指定维度的张量,其元素是从标准正态分布(均值为0,方差为1)中抽取的随机数。
```python
rand_tensor = torch.rand((2, 3))
print(rand_tensor)
randn_tensor = torch.randn((2, 3))
print(randn_tensor)
```
5. **torch.full()**
使用`torch.full()`函数可以创建一个填充常数的张量。
```python
full_tensor = torch.full((3, 3), 5.)
print(full_tensor)
```
6. **torch.arange()**
`torch.arange()`函数生成一个从起始值到结束值(不包括结束值),步长为指定值的张量。
```python
range_tensor = torch.arange(1, 20, 2)
print(range_tensor)
```
7. **torch.empty()**
`torch.empty()`函数创建一个指定形状的未初始化张量。
```python
empty_tensor = torch.empty((2, 3))
print(empty_tensor)
```
#### 二、张量的运算
PyTorch提供了丰富的张量运算操作,包括算术运算、线性代数运算、矩阵操作等。以下是一些常用的张量运算:
1. **算术运算**
- **加法**:使用`torch.add()`函数或`+`运算符进行加法运算。
```python
a = torch.randn(2, 3)
b = torch.randn(2, 3)
c = torch.add(a, b)
print(c)
d = a + b
print(d)
```
- **标量加法**:将一个张量与一个标量相加。
```python
e = torch.add(d, 10)
print(e)
```
- **绝对值**:使用`torch.abs()`函数计算张量的绝对值。
```python
abs_tensor = torch.abs(a)
print(abs_tensor)
```
- **乘法**:使用`torch.mul()`函数或`*`运算符进行乘法运算。
```python
z1 = a * a
z2 = torch.mul(a, a)
print(z1)
print(z2)
```
- **除法**:使用`torch.div()`函数或`/`运算符进行除法运算。
```python
div_tensor = torch.div(a, b)
print(div_tensor)
```
- **幂运算**:使用`torch.pow()`函数计算张量的幂。
```python
pow_tensor = torch.pow(a, 2)
print(pow_tensor)
```
2. **矩阵运算**
- **矩阵乘法**:使用`torch.mm()`函数或`@`运算符进行矩阵乘法运算。
```python
matrix1 = torch.tensor([[1, 2], [3, 4]])
matrix2 = torch.tensor([[5, 6], [7, 8]])
result_matrix = torch.mm(matrix1, matrix2)
print(result_matrix)
result_matrix_2 = matrix1 @ matrix2
print(result_matrix_2)
```
- **矩阵转置**:使用`torch.t()`函数计算矩阵的转置。
```python
transposed_matrix = torch.t(matrix1)
print(transposed_matrix)
```
3. **统计运算**
- **求和**:使用`torch.sum()`函数计算张量元素的和。
```python
tensor = torch.tensor([[1, 2], [3, 4]])
sum_result = torch.sum(tensor)
print(sum_result)
```
- **平均值**:使用`torch.mean()`函数计算张量元素的平均值。
```python
mean_result = torch.mean(tensor)
print(mean_result)
```
- **标准差**:使用`torch.std()`函数计算张量元素的标准差。
```python
std_result = torch.std(tensor)
print(std_result)
```
- **最大值和最小值**:使用`torch.max()`和`torch.min()`函数找到张量中的最大值和最小值及其索引。
```python
max_value, max_index = torch.max(tensor, dim=1)
print(max_value)
print(max_index)
min_value, min_index = torch.min(tensor, dim=1)
print(min_value)
print(min_index)
```
#### 三、张量的形状变换
1. **torch.view()**
使用`torch.view()`函数可以改变张量的形状,但要确保元素数量不变。
```python
original_tensor = torch.arange(1, 9)
reshaped_tensor = original_tensor.view(2, 4)
print(reshaped_tensor)
```
2. **torch.reshape()**
`torch.reshape()`函数与`torch.view()`类似,但提供了更多的灵活性,可以在某些情况下自动推断维度。
```python
reshaped_tensor_2 = original_tensor.reshape(2, 4)
print(reshaped_tensor_2)
```
#### 四、张量的拼接与索引
1. **torch.cat()**
使用`torch.cat()`函数可以沿指定维度拼接张量。
```python
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])
concatenated_tensor = torch.cat((tensor1, tensor2), dim=0)
print(concatenated_tensor)
```
2. **索引与切片**
使用索引和切片可以获取张量的特定元素或子集。
```python
matrix_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
first_row = matrix_tensor[0, :]
print(first_row)
first_column = matrix_tensor[:, 0]
print(first_column)
subset_tensor = matrix_tensor[1:, 1:]
print(subset_tensor)
```
#### 五、重要的张量处理方式
1. **clamp()**
`torch.clamp()`函数对输入张量按照自定义的范围进行裁剪。
```python
a = torch.