基于LSTM和SSUN模型的高光谱遥感分类实现
为什么高光谱遥感需要深度学习?
高光谱遥感数据因其丰富的光谱和空间信息在地物分类、农作物识别、环境监测等领域具有广泛应用。然而,传统的分类方法(如支持向量机SVM、随机森林RF等)难以充分挖掘高光谱数据的高维特性和时空依赖性。因此,引入深度学习模型,如LSTM(长短期记忆网络)和SSUN(光谱-空间联合网络),来解决这些问题成为一个研究热点。
在本项目中,我们结合了LSTM和**SSUN(Spectral-Spatial Unified Network)**两种模型,并通过PyTorch框架实现了高光谱遥感分类任务。我们重点展示如何从数据预处理、模型构建、训练到最终分类图的生成,完成端到端的遥感分类工作。
1. 项目目录结构与核心功能
该项目的主要功能包括高光谱数据加载、LSTM与SSUN模型的实现、分类任务的训练与测试,以及最终分类结果的可视化。以下是项目的目录结构与核心功能:
目录结构
- configs/:存放模型参数和训练配置文件。
- data/:数据预处理模块,包括训练集和测试集的划分。
- model/:
LSTM.py
:实现基于LSTM的高光谱分类模型。SSUN.py
:实现光谱-空间联合分类网络(SSUN)。
- tool/:
train.py
:训练模块。test.py
:测试模块。show.py
:分类结果的可视化工具。
- weights/:存放模型训练后的权重文件。
- main.py:主程序,整合各个模块,实现完整的训练-测试流程。
2. 数据处理与加载
高光谱遥感数据通常以.mat
文件格式存储,包含光谱图像和对应的地物分类标签(Ground Truth)。在本项目中,我们使用经典数据集如Indian Pines和Salinas。
数据加载
数据加载与划分的代码位于data/
目录下,通过HSI_data.py
实现数据的读取和预处理,包括:
- 归一化:将每个光谱波段的值标准化至[0, 1]区间。
- 划分训练集与测试集:根据配置文件指定比例,生成训练集、测试集和无标签数据集。
from data.HSI_data import HSI_data
from data.get_train_test_set import get_train_test_set
data_sets = get_train_test_set(cfg_data)
train_data = HSI_data(data_sets, cfg_data['train_data'])
test_data = HSI_data(data_sets, cfg_data['test_data'])
no_gt_data = HSI_data(data_sets, cfg_data['no_gt_data'])
3. 模型设计
本项目中实现了两种深度学习模型:
- LSTM模型:通过学习光谱特征的时序依赖性,实现光谱分类。
- SSUN模型:联合建模光谱与空间信息,实现高精度的分类。
(1) LSTM模型
LSTM适合处理序列数据。在高光谱分类任务中,LSTM用于提取光谱维度上的时序特征。
LSTM的核心代码:model/LSTM.py
class lstm(nn.Module):
def __init__(self, band_num=4, chose_model=1):
super(lstm, self).__init__()
self.band = band_num
self.chose_model = chose_model
self.lstm_model = nn.LSTM(
input_size=55, # 每个时间步的输入大小(光谱波段数)
hidden_size=128, # LSTM隐层单元数
num_layers=1, # LSTM层数
batch_first=True # batch位于第一维
)
self.outlayer = nn.Sequential(
nn.Linear(128, 50),
nn.Linear(50, 16) # 最终输出分类数
)
def forward(self, x):
# 数据维度变换与重构
b, c, h_size, w_size = x.shape
input = x[:, :, h_size // 2, w_size // 2]
input = input.reshape(b, c)
nb_features = int(c // self.band)
input_reshape = torch.zeros((x.shape[0], self.band, nb_features)).type_as(x)
if self.chose_model == 1:
for j in range(0, self.band):
input_reshape[:, j, :] = input[:, j:j + (nb_features - 1) * self.band + 1:self.band]
else:
for j in range(0, self.band):
input_reshape[:, j, :] = input[:, j * nb_features:(j + 1) * nb_features]
out, (h0, c0) = self.lstm_model(input_reshape)
out = out[:, -1, :]
out_finnal = self.outlayer(out)
return out, out_finnal
(2) SSUN模型
SSUN(Spectral-Spatial Unified Network)联合建模光谱和空间信息。通过光谱维度的特征提取和空间维度的上下文建模,SSUN能够更好地捕捉地物类别的特性。
SSUN的核心代码:model/SSUN.py
class SSUN(nn.Module):
def __init__(self):
super(SSUN, self).__init__()
# 光谱特征提取
self.spectral_conv = nn.Conv3d(1, 16, kernel_size=(7, 1, 1), stride=(1, 1, 1), padding=(3, 0, 0))
# 空间特征提取
self.spatial_conv = nn.Conv3d(16, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
# 分类头
self.classifier = nn.Sequential(
nn.Linear(32, 64),
nn.ReLU(),
nn.Linear(64, 16) # 输出16个类别
)
def forward(self, x):
x = self.spectral_conv(x)
x = F.relu(self.spatial_conv(x))
x = x.view(x.size(0), -1) # 展平
x = self.classifier(x)
return x
4. 模型训练与测试
训练过程
训练过程通过tool/train.py
实现,包含以下步骤:
- 数据加载与划分。
- 初始化模型、优化器和损失函数。
- 迭代训练,更新模型权重。
from tool.train import train
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SSUN().to(device)
optimizer = optim.Adam(model.parameters(), lr=cfg_optim['lr'], weight_decay=cfg_optim['weight_decay'])
loss_fun = nn.CrossEntropyLoss()
train(train_data, model, loss_fun, optimizer, device, cfg_train)
测试与结果可视化
测试过程通过tool/test.py
实现,评估模型在测试集上的分类精度,并生成预测结果。
可视化工具tool/show.py
将模型的预测结果转换为图像形式,便于直观地观察分类效果。
from tool.test import test
from tool.show import Predict_Label2Img_slect
pred_test_label = test(test_data, data_sets['ori_gt'], model, device, cfg_test)
HSI = Predict_Label2Img_slect(pred_test_label, cfg_num)
plt.imshow(HSI)
plt.show()