【机器学习】机器学习的基本分类-半监督学习-半监督生成对抗网络(Semi-supervised GANs)
半监督生成对抗网络(Semi-supervised GANs,简称 SGAN)是一种结合生成对抗网络(GAN)和半监督学习的模型,能够在有限标注数据和大量未标注数据的情况下训练分类器。它扩展了传统 GAN 的结构,使得判别器不仅仅用于区分真假样本,还用于对标注样本进行分类。
半监督 GAN 的核心思想
-
生成器 (Generator):
- 模仿数据分布,生成与真实数据类似的样本。
- 输入为噪声 。
-
判别器 (Discriminator):
- 在传统 GAN 中,判别器只是二分类器,用于区分生成样本和真实样本。
- 在 SGAN 中,判别器被改造成多分类器,除了区分真假样本外,还负责对真实数据进行分类(有监督任务)。
-
未标注数据的利用:
- SGAN 将未标注数据作为“真实样本”,它们在训练中被用来提升生成器的质量以及判别器的分类能力。
半监督 GAN 的结构
判别器的输出
判别器 DDD 的输出层被设计为 K+1 个神经元,其中:
- 前 K 个神经元表示标注样本的 K 个类别概率。
- 第 K+1 个神经元表示“生成样本”的概率。
损失函数
-
分类损失(监督):
其中 表示判别器对类别 y 的预测概率。
-
生成对抗损失:
其中 表示判别器认为样本是生成样本的概率。
-
生成器损失:
总损失为上述损失的加权和。
半监督 GAN 的实现
以下是一个使用 TensorFlow/Keras 的 SGAN 实现示例:
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization
from keras.optimizers import Adam
from keras.losses import CategoricalCrossentropy
# 超参数
latent_dim = 100 # 随机噪声维度
image_shape = (28, 28, 1) # 输入图像形状
num_classes = 11 # 类别数
# 创建生成器
def build_generator(latent_dim):
model = Sequential([
Dense(128 * 7 * 7, activation='relu', input_dim=latent_dim),
Reshape((7, 7, 128)),
BatchNormalization(),
Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu'),
BatchNormalization(),
Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', activation='relu'),
BatchNormalization(),
Conv2D(1, kernel_size=7, activation='tanh', padding='same') # 输出形状为 (28, 28, 1)
])
return model
# 创建判别器
def build_discriminator(image_shape, num_classes):
model = Sequential([
Conv2D(64, kernel_size=3, strides=2, padding='same', input_shape=image_shape),
LeakyReLU(alpha=0.2),
Dropout(0.3),
Conv2D(128, kernel_size=3, strides=2, padding='same'),
LeakyReLU(alpha=0.2),
Dropout(0.3),
Flatten(),
Dense(num_classes, activation='softmax') # 输出类别概率分布
])
return model
# 定义训练过程
def train_sgan(generator, discriminator, latent_dim, X_labeled, y_labeled, X_unlabeled, epochs=10000, batch_size=64):
# 编译判别器
discriminator.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
loss=CategoricalCrossentropy(),
metrics=['accuracy'])
# 构建生成器-判别器联合模型
discriminator.trainable = False
sgan = Sequential([generator, discriminator])
sgan.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss=CategoricalCrossentropy())
for epoch in range(epochs):
# 生成虚假样本
noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
fake_images = generator.predict(noise)
fake_labels = np.eye(num_classes)[np.random.choice(num_classes, size=batch_size)]
# 训练生成器
generator_loss = sgan.train_on_batch(noise, fake_labels)
# 训练判别器
idx = np.random.randint(0, X_unlabeled.shape[0], batch_size)
real_images = X_unlabeled[idx]
real_labels = np.eye(num_classes)[np.random.choice(num_classes, size=batch_size)]
d_loss_real = discriminator.train_on_batch(real_images, real_labels)
d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# 打印日志
if epoch % 100 == 0:
print(
f"Epoch {epoch}/{epochs} | D Loss: {d_loss[0]:.4f}, D Accuracy: {d_loss[1]:.4f} | G Loss: {generator_loss:.4f}")
# 数据加载示例(使用 MNIST 数据)
from keras.datasets import mnist
from keras.utils import to_categorical
(X_train, y_train), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # 归一化到 [-1, 1]
X_train = np.expand_dims(X_train, axis=-1) # 转换为 (N, 28, 28, 1) 格式
y_train = to_categorical(y_train, num_classes=num_classes)
# 拆分有标签和无标签数据
X_labeled = X_train[:1000]
y_labeled = y_train[:1000]
X_unlabeled = X_train[1000:]
# 初始化模型
generator = build_generator(latent_dim)
discriminator = build_discriminator(image_shape, num_classes)
# 训练 SGAN
train_sgan(generator, discriminator, latent_dim, X_labeled, y_labeled, X_unlabeled)
部分结果
2/2 [==============================] - 0s 21ms/step
Epoch 0/10000 | D Loss: 2.4036, D Accuracy: 0.0859 | G Loss: 2.4005
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 34ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 33ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 24ms/step
2/2 [==============================] - 0s 31ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 33ms/step
2/2 [==============================] - 0s 24ms/step
2/2 [==============================] - 0s 23ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 25ms/step
2/2 [==============================] - 0s 27ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 34ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 35ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 21ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 24ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 36ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 39ms/step
2/2 [==============================] - 0s 22ms/step
2/2 [==============================] - 0s 28ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 23ms/step
2/2 [==============================] - 0s 22ms/step
2/2 [==============================] - 0s 24ms/step
2/2 [==============================] - 0s 55ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 23ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 22ms/step
2/2 [==============================] - 0s 33ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 21ms/step
2/2 [==============================] - 0s 30ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 34ms/step
2/2 [==============================] - 0s 26ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 34ms/step
2/2 [==============================] - 0s 47ms/step
2/2 [==============================] - 0s 29ms/step
2/2 [==============================] - 0s 18ms/step
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 19ms/step
2/2 [==============================] - 0s 31ms/step
2/2 [==============================] - 0s 19ms/step
Epoch 100/10000 | D Loss: 2.3941, D Accuracy: 0.1250 | G Loss: 2.3827
2/2 [==============================] - 0s 20ms/step
2/2 [==============================] - 0s 19ms/step
总结
半监督 GAN 的核心在于将判别器扩展为多分类器,充分利用未标注数据和生成样本的对抗训练,提升分类器性能。相比传统的 GAN 和全监督学习方法,SGAN 能在标注数据不足的情况下取得更好的分类效果。