当前位置: 首页 > article >正文

【机器学习】机器学习的基本分类-自监督学习-对比学习(Contrastive Learning)

对比学习是一种自监督学习方法,其目标是学习数据的表征(representation),使得在表征空间中,相似的样本距离更近,不相似的样本距离更远。通过设计对比损失函数(Contrastive Loss),模型能够有效捕捉数据的语义结构。


核心思想

对比学习的关键在于:

  1. 正样本(Positive Pair):具有相似语义或来源的样本对,例如同一图像的不同增强版本。
  2. 负样本(Negative Pair):语义不同或来源不同的样本对,例如不同图像。

通过对比正负样本对,模型能够学习区分不同数据点的特征。


方法流程

  1. 数据增强:对一个样本 x 应用两种不同的增强方法,生成 x_1, x_2​,作为正样本对。
  2. 特征提取:通过编码器(如卷积神经网络)将数据映射到潜在特征空间,得到表征 z_1, z_2
  3. 对比损失:设计损失函数,使正样本对的表征距离最小化,负样本对的表征距离最大化。

对比学习的损失函数

1. 对比损失(Contrastive Loss)

对比损失鼓励正样本对的距离更小,负样本对的距离更大。

L = \frac{1}{N} \sum_{i=1}^N \left[ y_i \cdot d(z_i, z_j)^2 + (1 - y_i) \cdot \max(0, m - d(z_i, z_j))^2 \right]

  • y_i:样本对是否为正样本(1 表示正样本,0 表示负样本)。
  • d(z_i, z_j):样本对在表征空间中的距离(通常使用欧氏距离)。
  • m:负样本对的最小距离(margin)。
2. InfoNCE 损失

用于最大化正样本对的相似性,同时将负样本对的相似性最小化。

L = - \log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{N} \exp(\text{sim}(z_i, z_k) / \tau)}

  • \text{sim}(z_i, z_j) = \frac{z_i \cdot z_j}{\|z_i\| \|z_j\|}:余弦相似度。
  • \tau:温度参数,用于控制分布的平滑程度。
  • N:批量中样本数量。

典型方法

1. SimCLR

SimCLR 是对比学习的经典方法之一:

  • 核心思想:通过数据增强生成正样本对,并利用 InfoNCE 损失函数进行优化。
  • 数据增强:随机裁剪、颜色抖动、模糊等。
2. MoCo(Momentum Contrast)

通过维护一个动态更新的“字典”,解决负样本数量不足的问题。

  • 核心思想:使用动量编码器(momentum encoder)生成更多的负样本。
3. BYOL(Bootstrap Your Own Latent)

无需显式的负样本,通过自回归(self-prediction)学习特征表征。

  • 核心思想:一个在线网络(Online Network)和一个目标网络(Target Network)协同训练。
4. SWAV(Swapping Assignments Between Views)

结合聚类和对比学习,利用图像的多视图表征。

  • 核心思想:通过在线分配伪标签,避免显式使用负样本。

示例代码:SimCLR

以下是一个实现 SimCLR 的示例代码:

import tensorflow as tf
from tensorflow.keras import layers, models


# 图像增强函数
def augment_image(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, size=(32, 32, 3))
    image = tf.image.random_brightness(image, max_delta=0.5)
    return image


# 定义编码器
def create_encoder():
    base_model = tf.keras.applications.ResNet50(include_top=False, pooling='avg', input_shape=(32, 32, 3))
    return models.Model(inputs=base_model.input, outputs=base_model.output)


# SimCLR 模型
class SimCLRModel(tf.keras.Model):
    def __init__(self, encoder, projection_dim):
        super(SimCLRModel, self).__init__()
        self.encoder = encoder
        self.projection_head = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.Dense(projection_dim)
        ])

    def call(self, x):
        features = self.encoder(x)
        projections = self.projection_head(features)
        return tf.math.l2_normalize(projections, axis=1)


# 构建模型
encoder = create_encoder()
simclr_model = SimCLRModel(encoder, projection_dim=128)


# InfoNCE 损失
def info_nce_loss(features, temperature=0.5):
    batch_size = tf.shape(features)[0]
    labels = tf.range(batch_size)
    similarity_matrix = tf.matmul(features, features, transpose_b=True)
    logits = similarity_matrix / temperature
    return tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))


# 训练
(X_train, _), _ = tf.keras.datasets.cifar10.load_data()
X_train = tf.image.resize(X_train, (32, 32)) / 255.0


def preprocess_data(image):
    return augment_image(image), augment_image(image)


train_data = tf.data.Dataset.from_tensor_slices(X_train)
train_data = train_data.map(preprocess_data).batch(32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

for epoch in range(10):
    for x1, x2 in train_data:
        with tf.GradientTape() as tape:
            z1 = simclr_model(x1)
            z2 = simclr_model(x2)
            loss = info_nce_loss(tf.concat([z1, z2], axis=0))
        gradients = tape.gradient(loss, simclr_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, simclr_model.trainable_variables))
    print(f"Epoch {epoch + 1}, Loss: {loss.numpy()}")

输出结果

Epoch 1, Loss: 3.465735912322998
Epoch 2, Loss: 3.465735912322998
Epoch 3, Loss: 3.465735912322998
Epoch 4, Loss: 3.465735912322998
Epoch 5, Loss: 3.465735912322998

对比学习的优势与挑战

优势
  1. 无需标签数据:适用于大规模无标签数据集。
  2. 高质量特征:学习的表征具有很强的迁移能力。
  3. 通用性强:适用于图像、文本、语音等多种模态。
挑战
  1. 负样本选择:负样本数量和质量对性能影响大。
  2. 计算成本:对比学习需要大量计算资源,尤其是在大规模数据上训练。
  3. 超参数调整:温度参数等对模型表现至关重要。

http://www.kler.cn/a/460837.html

相关文章:

  • CDP集群安全指南-静态数据加密
  • UE5通过蓝图节点控制材质参数
  • 高频java面试题
  • Javascript数据结构——图Graph
  • pip下载包出现SSLError
  • 2000-2020年各省财政一般预算支出面板数据
  • 解决vue-i18n在非.vue文件中,在其他js文件中无法使用的问题
  • Ubuntu 搭建SVN服务
  • 探索基于WebAssembly的下一代前端性能优化方案
  • 如何在谷歌浏览器中使用自定义CSS
  • 在pytest钩子函数中判断Android和iOS设备(方法一)
  • 【2024年-5月-28日-开源社区openEuler实践记录】走进 GCC:开源编译器的传奇之旅
  • ACE之ACE_Message_Queue
  • 《Java核心技术II》抽取子流和组合流
  • 攻破 kioprix level 4 靶机
  • C++语言编程————C++数据类型与表达式
  • 期权懂|国内场外期权都有哪些种类?
  • MybatisPlus查询更so easy
  • 数据结构与算法之动态规划: LeetCode 62. 不同路径 (Ts版)
  • 非常简单实用的前后端分离项目-仓库管理系统(Springboot+Vue)part 5(未实现预期目标)
  • Pytest 高级用法:间接参数化
  • 25考研希望渺茫,工作 VS 二战,怎么选?
  • 2024年RAG:回顾与展望
  • KEGG大更新:开启生物研究新纪元
  • 物联网技术在电商API接口中的应用实践
  • Spring Boot中使用Zookeeper实现分布式锁的案例