当前位置: 首页 > article >正文

基于LSTM和SSUN模型的高光谱遥感分类实现

为什么高光谱遥感需要深度学习?

高光谱遥感数据因其丰富的光谱和空间信息在地物分类、农作物识别、环境监测等领域具有广泛应用。然而,传统的分类方法(如支持向量机SVM、随机森林RF等)难以充分挖掘高光谱数据的高维特性和时空依赖性。因此,引入深度学习模型,如LSTM(长短期记忆网络)和SSUN(光谱-空间联合网络),来解决这些问题成为一个研究热点。

在本项目中,我们结合了LSTM和**SSUN(Spectral-Spatial Unified Network)**两种模型,并通过PyTorch框架实现了高光谱遥感分类任务。我们重点展示如何从数据预处理、模型构建、训练到最终分类图的生成,完成端到端的遥感分类工作。


1. 项目目录结构与核心功能

该项目的主要功能包括高光谱数据加载、LSTM与SSUN模型的实现、分类任务的训练与测试,以及最终分类结果的可视化。以下是项目的目录结构与核心功能:

目录结构
  • configs/:存放模型参数和训练配置文件。
  • data/:数据预处理模块,包括训练集和测试集的划分。
  • model/
    • LSTM.py:实现基于LSTM的高光谱分类模型。
    • SSUN.py:实现光谱-空间联合分类网络(SSUN)。
  • tool/
    • train.py:训练模块。
    • test.py:测试模块。
    • show.py:分类结果的可视化工具。
  • weights/:存放模型训练后的权重文件。
  • main.py:主程序,整合各个模块,实现完整的训练-测试流程。

2. 数据处理与加载

高光谱遥感数据通常以.mat文件格式存储,包含光谱图像和对应的地物分类标签(Ground Truth)。在本项目中,我们使用经典数据集如Indian PinesSalinas

数据加载

数据加载与划分的代码位于data/目录下,通过HSI_data.py实现数据的读取和预处理,包括:

  1. 归一化:将每个光谱波段的值标准化至[0, 1]区间。
  2. 划分训练集与测试集:根据配置文件指定比例,生成训练集、测试集和无标签数据集。
from data.HSI_data import HSI_data
from data.get_train_test_set import get_train_test_set

data_sets = get_train_test_set(cfg_data)
train_data = HSI_data(data_sets, cfg_data['train_data'])
test_data = HSI_data(data_sets, cfg_data['test_data'])
no_gt_data = HSI_data(data_sets, cfg_data['no_gt_data'])

3. 模型设计

本项目中实现了两种深度学习模型:

  • LSTM模型:通过学习光谱特征的时序依赖性,实现光谱分类。
  • SSUN模型:联合建模光谱与空间信息,实现高精度的分类。
(1) LSTM模型

LSTM适合处理序列数据。在高光谱分类任务中,LSTM用于提取光谱维度上的时序特征。

LSTM的核心代码:model/LSTM.py

class lstm(nn.Module):
    def __init__(self, band_num=4, chose_model=1):
        super(lstm, self).__init__()
        self.band = band_num
        self.chose_model = chose_model
        self.lstm_model = nn.LSTM(
            input_size=55,   # 每个时间步的输入大小(光谱波段数)
            hidden_size=128, # LSTM隐层单元数
            num_layers=1,    # LSTM层数
            batch_first=True # batch位于第一维
        )
        self.outlayer = nn.Sequential(
            nn.Linear(128, 50),
            nn.Linear(50, 16)  # 最终输出分类数
        )

    def forward(self, x):
        # 数据维度变换与重构
        b, c, h_size, w_size = x.shape
        input = x[:, :, h_size // 2, w_size // 2]
        input = input.reshape(b, c)
        nb_features = int(c // self.band)
        input_reshape = torch.zeros((x.shape[0], self.band, nb_features)).type_as(x)

        if self.chose_model == 1:
            for j in range(0, self.band):
                input_reshape[:, j, :] = input[:, j:j + (nb_features - 1) * self.band + 1:self.band]
        else:
            for j in range(0, self.band):
                input_reshape[:, j, :] = input[:, j * nb_features:(j + 1) * nb_features]

        out, (h0, c0) = self.lstm_model(input_reshape)
        out = out[:, -1, :]
        out_finnal = self.outlayer(out)
        return out, out_finnal
(2) SSUN模型

SSUN(Spectral-Spatial Unified Network)联合建模光谱和空间信息。通过光谱维度的特征提取和空间维度的上下文建模,SSUN能够更好地捕捉地物类别的特性。

SSUN的核心代码:model/SSUN.py

class SSUN(nn.Module):
    def __init__(self):
        super(SSUN, self).__init__()
        # 光谱特征提取
        self.spectral_conv = nn.Conv3d(1, 16, kernel_size=(7, 1, 1), stride=(1, 1, 1), padding=(3, 0, 0))
        # 空间特征提取
        self.spatial_conv = nn.Conv3d(16, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
        # 分类头
        self.classifier = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 16)  # 输出16个类别
        )

    def forward(self, x):
        x = self.spectral_conv(x)
        x = F.relu(self.spatial_conv(x))
        x = x.view(x.size(0), -1)  # 展平
        x = self.classifier(x)
        return x

4. 模型训练与测试

训练过程

训练过程通过tool/train.py实现,包含以下步骤:

  1. 数据加载与划分。
  2. 初始化模型、优化器和损失函数。
  3. 迭代训练,更新模型权重。

from tool.train import train

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SSUN().to(device)
optimizer = optim.Adam(model.parameters(), lr=cfg_optim['lr'], weight_decay=cfg_optim['weight_decay'])
loss_fun = nn.CrossEntropyLoss()

train(train_data, model, loss_fun, optimizer, device, cfg_train)
测试与结果可视化

测试过程通过tool/test.py实现,评估模型在测试集上的分类精度,并生成预测结果。

可视化工具tool/show.py将模型的预测结果转换为图像形式,便于直观地观察分类效果。

from tool.test import test
from tool.show import Predict_Label2Img_slect

pred_test_label = test(test_data, data_sets['ori_gt'], model, device, cfg_test)
HSI = Predict_Label2Img_slect(pred_test_label, cfg_num)
plt.imshow(HSI)
plt.show()


http://www.kler.cn/a/446800.html

相关文章:

  • CTF学习24.12.21[隐写术进阶]
  • iOS + watchOS Tourism App(含源码可简单复现)
  • ViEW生命周期
  • CH340系列芯片驱动电路·CH340系列芯片驱动!!!
  • 《探秘 Qt Creator Manual 4.11.1》
  • web自动化测试知识总结
  • PCL点云库入门——PCL库中点云数据拓扑关系之K-D树(KDtree)
  • 1、学习大模型总纲
  • FreeRTOS的任务调度
  • 全志H618 Android12修改doucmentsui鼠标单击图片、文件夹选中区域
  • Suno Api V4模型无水印开发「高清音频WAV下载」 —— 「Suno Api系列」第6篇
  • netcore 集成Prometheus
  • 大数据-环保领域
  • 【1.排序】
  • 【Linux】-学习笔记10
  • 呼入机器人:24小时客户服务的未来趋势
  • 秒优科技-供应链管理系统 login/doAction SQL注入漏洞复现
  • Oracle筑基篇-通过一个事务流程筑基Oracle
  • 基于mmdetection进行语义分割(不修改源码)
  • 怎么通过亚矩阵云手机实现营销?
  • Go框架比较:goframe、beego、iris和gin
  • jvm栈帧中的动态链接
  • [Unity Shader]【图形渲染】 数学基础4 - 矩阵定义和矩阵运算详解
  • SQLAlchemy 2.0 高级特性详解
  • SpringMVC 进阶学习笔记
  • 【Python】【数据分析】深入探索 Python 数据可视化:Seaborn 可视化库详解