当前位置: 首页 > article >正文

jiehun_DEMO

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt


# 定义卷积自编码器模型
def build_autoencoder(conv_filters=48, learning_rate=0.001, patch_size=5, num_bands=330):
    input_layer = layers.Input(shape=(patch_size, patch_size, num_bands))

    # diff
    # x = layers.Conv2D(conv_filters, (3, 3), activation='leaky_relu', padding='same')(input_layer)
    x = layers.Conv2D(conv_filters, (3, 3), padding='same')(input_layer)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # diff
    # x = layers.Conv2D(16, (1, 1), activation='leaky_relu', padding='same')(x)
    x = layers.Conv2D(16, (1, 1), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    abundance_layer = layers.Lambda(lambda x: tf.nn.softmax(x * 3.5), output_shape=lambda input_shape: input_shape)(x)

    decoded = layers.Conv2D(num_bands, (3, 3), activation='linear', padding='same')(abundance_layer)

    autoencoder = models.Model(inputs=input_layer, outputs=decoded)

    autoencoder.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss='mse')

    return autoencoder


# 数据生成器函数
def preprocess_data_generator(hsi_data, patch_size=5):
    H, W, B = hsi_data.shape
    for i in tqdm(range(0, H - patch_size + 1, patch_size), desc="Generating patches"):
        for j in range(0, W - patch_size + 1, patch_size):
            patch = hsi_data[i:i + patch_size, j:j + patch_size, :].astype(np.float32)
            yield patch.reshape(1, patch_size, patch_size, B)


# 手动定义超参数
conv_filters = 64  # 卷积滤波器数量
# diff
# learning_rate = 0.001  # 学习率
learning_rate = 0.0001  # 学习率

# 读取和处理高光谱数据(这里假设你已经加载了数据)
hsi_data = np.random.rand(2432, 2372, 330)  # 模拟数据,替换为真实的高光谱数据

# 构建卷积自编码器
autoencoder = build_autoencoder(conv_filters=conv_filters, learning_rate=learning_rate, num_bands=hsi_data.shape[2])

# 打印模型结构
autoencoder.summary()

# 准备数据
data_generator = preprocess_data_generator(hsi_data)

# 计算每个 epoch 的步骤
steps_per_epoch = (hsi_data.shape[0] // 5) * (hsi_data.shape[1] // 5) // 32

# 清理计算图
tf.keras.backend.clear_session()

# 训练模型
for epoch in range(10):  # 设置总的 epoch 数量
    print(f"Epoch {epoch + 1}/10")
    for step in tqdm(range(steps_per_epoch)):
        x_batch = next(data_generator)
        autoencoder.train_on_batch(x_batch, x_batch)


# 获取丰度图
def get_abundance_maps(model, hsi_data, patch_size=5):
    abundance_maps = []
    H, W, B = hsi_data.shape
    for i in tqdm(range(0, H - patch_size + 1, patch_size), desc="Extracting abundance maps"):
        for j in range(0, W - patch_size + 1, patch_size):
            patch = hsi_data[i:i + patch_size, j:j + patch_size, :].astype(np.float32)
            abundance_map = model.predict(patch.reshape(1, patch_size, patch_size, B))
            abundance_maps.append(abundance_map.reshape(patch_size, patch_size, B))

    return np.array(abundance_maps)


abundance_maps = get_abundance_maps(autoencoder, hsi_data)


# 展示丰度图
def display_abundance_maps(abundance_maps, num_bands):
    plt.figure(figsize=(15, 15))
    for i in range(num_bands):
        plt.subplot(10, 10, i + 1)
        plt.imshow(abundance_maps[i], cmap='jet')
        plt.axis('off')
        plt.title(f'Band {i + 1}')
    plt.tight_layout()
    plt.show()


# 假设我们只展示前 10 个丰度图
display_abundance_maps(abundance_maps, 10)


# 进行聚类分析
def cluster_abundance_maps(abundance_maps, num_clusters=5):
    reshaped_abundance = abundance_maps.reshape(-1, abundance_maps.shape[2])
    # diff
    # kmeans = KMeans(n_clusters=num_clusters)
    kmeans = KMeans(n_clusters=num_clusters, init='k-means++')

    print("Clustering abundance maps...")
    kmeans.fit(reshaped_abundance)
    cluster_labels = kmeans.labels_.reshape(abundance_maps.shape[0], abundance_maps.shape[1])
    return cluster_labels


# 聚类岩性
num_clusters = 5
cluster_labels = cluster_abundance_maps(abundance_maps, num_clusters)

# 可视化聚类结果
plt.figure(figsize=(8, 8))
plt.imshow(cluster_labels, cmap='jet')
plt.title('Clustered Lithology Map')
plt.axis('off')
plt.show()

http://www.kler.cn/news/360226.html

相关文章:

  • without OpenSSL
  • 【ArcGIS微课1000例】0125:ArcGIS矢量化无法自动完成面解决方案
  • itext自定义pdf
  • 【Python实战】---- 自动生成前端项目图标管理文件
  • windows安装mysql,跳过自定义的密码验证
  • 【力扣打卡系列】滑动窗口与双指针(两数之和)
  • “射线沿其正向平移可变为其真子集”这一中学“常识”其实是几百年重大错误——百年病态集论的症结
  • 【Qt】绘图API
  • YashanDB学习-服务启停
  • 【Java 22 | 7】 深入解析Java 22 :密封类(Sealed Classes)增强详解
  • LTD助力经营数字化,浙商数智营销学堂开讲入站营销新理念
  • 【视频编码】视频编码中拉格朗日乘子法的简单理解
  • 基于SSM+微信小程序的家庭记账本管理系统(家庭1)
  • 08_实现 reactive
  • DAPLINK 源码学习(1)BL 之 main() 函数
  • typescript 的类型注解和类型断言
  • C#学习笔记(十)
  • 拥抱“新市民” ,数字银行的“谋与变”
  • jetson agx orin 的pytorch、torchvision安装
  • el-table表格数据处理,列表将变更前数据放置在前面,变更后数据放在表格后面