
[Machine Learning] Basic Categories of Machine Learning - Semi-supervised Learning - Semi-supervised GANs

A semi-supervised generative adversarial network (Semi-supervised GAN, or SGAN) combines a generative adversarial network (GAN) with semi-supervised learning, so that a classifier can be trained from a small amount of labeled data plus a large amount of unlabeled data. It extends the standard GAN architecture: the discriminator no longer just separates real samples from generated ones, it also classifies the labeled samples.


The Core Idea of Semi-supervised GANs

  1. Generator:

    • Mimics the data distribution and produces samples that resemble real data.
    • Takes noise z \sim p(z) as input.
  2. Discriminator:

    • In a standard GAN, the discriminator is only a binary classifier that separates generated samples from real ones.
    • In SGAN, the discriminator is turned into a multi-class classifier: besides telling real from fake, it also classifies the real data (the supervised task).
  3. Use of unlabeled data:

    • SGAN treats unlabeled data as "real" samples; during training they improve both the quality of the generator and the classification ability of the discriminator (see the target-encoding sketch below).
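
A minimal sketch of how the three data streams can be encoded as targets for a (K+1)-way discriminator, assuming K = 10 as in the MNIST example later in this article. Giving unlabeled samples a uniform target over the real classes is a simplification that is convenient when a single categorical cross-entropy is used; the exact unsupervised loss is given in the next section.

import numpy as np

K = 10                       # number of real classes
num_classes = K + 1          # index K is the extra "generated" class

# Labeled real sample of class y: one-hot over the K real classes, 0 for "generated"
y = 3
t_labeled = np.eye(num_classes)[y]

# Generated (fake) sample: one-hot on the extra class K
t_fake = np.eye(num_classes)[K]

# Unlabeled real sample: class unknown; a uniform target over the K real classes
# (0 for "generated") pushes the discriminator's "generated" probability down
t_unlabeled = np.concatenate([np.full(K, 1.0 / K), [0.0]])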

Architecture of a Semi-supervised GAN

Discriminator output

The output layer of the discriminator D has K + 1 neurons, where:

  • The first K neurons give the probabilities of the K classes of the labeled samples.
  • The (K + 1)-th neuron gives the probability that the sample is a generated one (see the snippet below).
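
Concretely, the probability that a sample is generated, the probability that it is real, and its predicted class can all be read off a single softmax output. A small illustration with made-up numbers:

import numpy as np

# One (K + 1)-way softmax output of the discriminator, K = 10 real classes,
# index 10 reserved for "generated" (illustrative values only)
d_out = np.array([0.02, 0.05, 0.60, 0.03, 0.02, 0.05, 0.04, 0.03, 0.03, 0.03, 0.10])

p_fake = d_out[10]                   # D_{K+1}(x): probability the sample is generated
p_real = 1.0 - p_fake                # probability the sample comes from the data distribution
pred_class = np.argmax(d_out[:10])   # predicted label among the K real classes

print(p_fake, p_real, pred_class)    # 0.1 0.9 2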
Loss functions

  1. Classification loss (supervised):

    \mathcal{L}_{\text{supervised}} = -\mathbb{E}_{(x, y) \sim p_{\text{data}}} \log D_y(x)

    where D_y(x) is the probability the discriminator assigns to class y.

  2. Adversarial (unsupervised) loss:

    \mathcal{L}_{\text{unsupervised}} = -\mathbb{E}_{x \sim p_{\text{data}}} \log(1 - D_{K+1}(x)) - \mathbb{E}_{z \sim p(z)} \log D_{K+1}(G(z))

    where D_{K+1}(x) is the probability the discriminator assigns to "generated sample".

  3. Generator loss:

    \mathcal{L}_G = -\mathbb{E}_{z \sim p(z)} \log(1 - D_{K+1}(G(z)))

The total loss is a weighted sum of the terms above; the three terms are written out in code below.
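
A minimal sketch of the three terms in TensorFlow, assuming probs_labeled, probs_real and probs_fake are the (batch, K + 1) softmax outputs of the discriminator on labeled real, unlabeled real and generated batches, and y_true is the one-hot label over the K real classes:

import tensorflow as tf

K = 10  # number of real classes

def supervised_loss(probs_labeled, y_true):
    # -E[log D_y(x)]: cross-entropy over the K real classes of labeled samples
    p_y = tf.reduce_sum(probs_labeled[:, :K] * y_true, axis=1)
    return -tf.reduce_mean(tf.math.log(p_y + 1e-8))

def unsupervised_loss(probs_real, probs_fake):
    # -E[log(1 - D_{K+1}(x))] - E[log D_{K+1}(G(z))]
    return (-tf.reduce_mean(tf.math.log(1.0 - probs_real[:, K] + 1e-8))
            - tf.reduce_mean(tf.math.log(probs_fake[:, K] + 1e-8)))

def generator_loss(probs_fake):
    # -E[log(1 - D_{K+1}(G(z)))]: the generator wants its samples to look real
    return -tf.reduce_mean(tf.math.log(1.0 - probs_fake[:, K] + 1e-8))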


Implementing a Semi-supervised GAN

Below is an example SGAN implementation using TensorFlow/Keras:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization
from keras.optimizers import Adam
from keras.losses import CategoricalCrossentropy

# Hyperparameters
latent_dim = 100  # dimensionality of the random noise
image_shape = (28, 28, 1)  # input image shape
num_classes = 11  # 10 real classes (MNIST digits) + 1 "generated" class


# Build the generator
def build_generator(latent_dim):
    model = Sequential([
        Dense(128 * 7 * 7, activation='relu', input_dim=latent_dim),
        Reshape((7, 7, 128)),
        BatchNormalization(),
        Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu'),
        BatchNormalization(),
        Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', activation='relu'),
        BatchNormalization(),
        Conv2D(1, kernel_size=7, activation='tanh', padding='same')  # output shape (28, 28, 1)
    ])
    return model


# Build the discriminator
def build_discriminator(image_shape, num_classes):
    model = Sequential([
        Conv2D(64, kernel_size=3, strides=2, padding='same', input_shape=image_shape),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),
        Conv2D(128, kernel_size=3, strides=2, padding='same'),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),
        Flatten(),
        Dense(num_classes, activation='softmax')  # class probability distribution over the K + 1 classes
    ])
    return model


# Training loop
def train_sgan(generator, discriminator, latent_dim, X_labeled, y_labeled, X_unlabeled, epochs=10000, batch_size=64):
    # Compile the discriminator (updated directly on labeled, unlabeled and generated batches)
    discriminator.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                          loss=CategoricalCrossentropy(),
                          metrics=['accuracy'])

    # Build the combined generator-discriminator model; the discriminator is frozen
    # inside it so that only the generator is updated through this model
    discriminator.trainable = False
    sgan = Sequential([generator, discriminator])
    sgan.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss=CategoricalCrossentropy())

    # Fixed targets for the single categorical cross-entropy:
    # generated samples -> one-hot on the extra "generated" class (index num_classes - 1);
    # unlabeled real samples -> uniform over the real classes, a simplification of the
    # unsupervised term that pushes the "generated" probability down
    fake_target = np.tile(np.eye(num_classes)[num_classes - 1], (batch_size, 1))
    real_target = np.tile(np.append(np.full(num_classes - 1, 1.0 / (num_classes - 1)), 0.0),
                          (batch_size, 1))

    for epoch in range(epochs):
        # --- Train the discriminator ---
        # Labeled real samples (supervised classification loss)
        idx = np.random.randint(0, X_labeled.shape[0], batch_size)
        d_loss_labeled = discriminator.train_on_batch(X_labeled[idx], y_labeled[idx])

        # Unlabeled real samples (treated as "real", class unknown)
        idx = np.random.randint(0, X_unlabeled.shape[0], batch_size)
        d_loss_real = discriminator.train_on_batch(X_unlabeled[idx], real_target)

        # Generated samples (labeled as the extra "generated" class)
        noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
        fake_images = generator.predict(noise, verbose=0)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_target)

        d_loss = np.mean([d_loss_labeled, d_loss_real, d_loss_fake], axis=0)

        # --- Train the generator ---
        # The generator wants its samples to be classified as one of the real classes,
        # so it is trained against the uniform real-class target
        noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
        generator_loss = sgan.train_on_batch(noise, real_target)

        # Logging
        if epoch % 100 == 0:
            print(
                f"Epoch {epoch}/{epochs} | D Loss: {d_loss[0]:.4f}, D Accuracy: {d_loss[1]:.4f} | G Loss: {generator_loss:.4f}")


# Data loading example (MNIST)
from keras.datasets import mnist
from keras.utils import to_categorical

(X_train, y_train), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5  # normalize to [-1, 1]
X_train = np.expand_dims(X_train, axis=-1)  # reshape to (N, 28, 28, 1)
y_train = to_categorical(y_train, num_classes=num_classes)  # 11 columns; the "generated" column stays 0

# Split into labeled and unlabeled data
X_labeled = X_train[:1000]
y_labeled = y_train[:1000]
X_unlabeled = X_train[1000:]

# Initialize the models
generator = build_generator(latent_dim)
discriminator = build_discriminator(image_shape, num_classes)

# Train the SGAN
train_sgan(generator, discriminator, latent_dim, X_labeled, y_labeled, X_unlabeled)

Partial output

Epoch 0/10000 | D Loss: 2.4036, D Accuracy: 0.0859 | G Loss: 2.4005
Epoch 100/10000 | D Loss: 2.3941, D Accuracy: 0.1250 | G Loss: 2.3827
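
Once training finishes, the discriminator itself is the classifier: at test time only its first 10 outputs are used and the "generated" column is ignored. A minimal evaluation sketch that reuses the variables defined in the code above:

# Evaluate the trained discriminator as a classifier on the MNIST test set
(_, _), (X_test, y_test) = mnist.load_data()
X_test = (X_test.astype(np.float32) - 127.5) / 127.5
X_test = np.expand_dims(X_test, axis=-1)

probs = discriminator.predict(X_test, verbose=0)        # shape (N, 11)
pred = np.argmax(probs[:, :num_classes - 1], axis=1)    # ignore the "generated" column
print(f"Discriminator-as-classifier test accuracy: {np.mean(pred == y_test):.4f}")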

Summary

The core of the semi-supervised GAN is extending the discriminator into a multi-class classifier and exploiting both unlabeled data and generated samples through adversarial training to improve the classifier. Compared with a standard GAN or a purely supervised approach, SGAN can reach better classification performance when labeled data is scarce.


