pytorch 均方误差损失函数
均方误差损失函数主要用于回归问题。它计算预测值与真实值之间差的平方,然后取平均值。这个损失函数通过惩罚大的误差,使得模型在训练时更加注重减少较大的偏差。
import torch
import torch.nn as nn
# 创建预测值和实际值张量
predicted = torch.tensor([0.5, 0.3, 0.2], requires_grad=True)
actual = torch.tensor([0.6, 0.1, 0.2])
# 定义 MSE 损失函数
criterion = nn.MSELoss()
# 计算损失
loss = criterion(predicted, actual)
print(loss.item()) # 输出损失值
解释:
predicted
是模型的预测输出,actual
是对应的真实值。nn.MSELoss()
定义了均方误差损失函数。loss = criterion(predicted, actual)
计算预测值和实际值之间的均方误差。.item()
用于从单个元素张量中提取数值。
参考
MSELoss — PyTorch 2.4 documentation