torch.fft 出现 ComplexHalf 或 Half 不支持
错误报告
错误信息
-
错误类型:
RuntimeError: Unsupported dtype Half
发生位置:torch.fft.rfft2
-
错误类型:
RuntimeError: "roll_cuda" not implemented for 'ComplexHalf'
发生位置:fft.fftshift
环境信息
bash
pip list | grep torch
torch 1.12.1+cu102
torchaudio 0.12.1+cu102
torchvision 0.13.1+cu102
原因
- 在 PyTorch 的 FFT 操作中,许多操作不支持
ComplexHalf
数据类型。
解决
- 将
ComplexHalf
转换为ComplexFloat
:
在调用 FFT 操作之前,确保将输入数据类型转换为ComplexFloat
。
python
if x.dtype == torch.complex64:
x = x.to(torch.complex64)
- 使用 CPU 处理:
如果性能要求不高,可以选择在 CPU 上执行 FFT 操作,避免 CUDA 相关问题。
python
xfreq = fft.fftn(x.cpu(), dim=(-2, -1))
xfreq = fft.fftshift(x_freq, dim=(-2, -1))