当前位置: 首页 > article >正文

pytorch基于GloVe实现的词嵌入

PyTorch 实现 GloVe(Global Vectors for Word Representation) 的完整代码,使用 中文语料 进行训练,包括 共现矩阵构建、模型定义、训练和测试


 1. GloVe 介绍

基于词的共现信息(不像 Word2Vec 使用滑动窗口预测)
 适合较大规模的数据(比 Word2Vec 更稳定)
学习出的词向量能捕捉语义信息(如类比关系)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
from scipy.sparse import coo_matrix

# ========== 1. 数据预处理 ==========
corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
]

# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# 计算共现矩阵
window_size = 2
co_occurrence = Counter()

for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
        center_word = indices[center_idx]
        for offset in range(-window_size, window_size + 1):
            context_idx = center_idx + offset
            if 0 <= context_idx < len(indices) and context_idx != center_idx:
                context_word = indices[context_idx]
                co_occurrence[(center_word, context_word)] += 1

# 转换为稀疏矩阵
rows, cols, values = zip(*[(c[0], c[1], v) for c, v in co_occurrence.items()])
X = coo_matrix((values, (rows, cols)), shape=(len(vocab), len(vocab)))


# ========== 2. 定义 GloVe 模型 ==========
class GloVe(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(GloVe, self).__init__()
        self.w_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 中心词嵌入
        self.c_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 上下文词嵌入
        self.w_bias = nn.Embedding(vocab_size, 1)  # 中心词偏置
        self.c_bias = nn.Embedding(vocab_size, 1)  # 上下文词偏置
        nn.init.xavier_uniform_(self.w_embeddings.weight)
        nn.init.xavier_uniform_(self.c_embeddings.weight)

    def forward(self, center, context, co_occur):
        w_emb = self.w_embeddings(center)
        c_emb = self.c_embeddings(context)
        w_bias = self.w_bias(center).squeeze()
        c_bias = self.c_bias(context).squeeze()
        dot_product = (w_emb * c_emb).sum(dim=1)
        loss = (dot_product + w_bias + c_bias - torch.log(co_occur + 1e-8)) ** 2
        return loss.mean()


# 初始化模型
embedding_dim = 10
model = GloVe(len(vocab), embedding_dim)

# ========== 3. 训练 GloVe ==========
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100

# 转换数据
co_occurrence_tensor = torch.tensor(X.data, dtype=torch.float)
pairs = list(zip(X.row, X.col, co_occurrence_tensor))

for epoch in range(num_epochs):
    total_loss = 0
    np.random.shuffle(pairs)
    for center, context, co_occur in pairs:
        optimizer.zero_grad()
        loss = model(
            torch.tensor([center], dtype=torch.long),
            torch.tensor([context], dtype=torch.long),
            torch.tensor([co_occur], dtype=torch.float)  # 修正数据类型
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

# ========== 4. 获取词向量 ==========
word_vectors = model.w_embeddings.weight.data.numpy()


# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):
    if word not in word2idx:
        return "单词不在词汇表中"

    word_vec = word_vectors[word2idx[word]].reshape(1, -1)
    similarities = np.dot(word_vectors, word_vec.T).squeeze()
    similar_idx = similarities.argsort()[::-1][1:top_n + 1]
    return [(idx2word[idx], similarities[idx]) for idx in similar_idx]


# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

数据预处理

  • 分词(使用 jieba.cut()
  • 构建共现矩阵(计算窗口内的单词共现频率)
  • 使用稀疏矩阵存储(提高计算效率)

GloVe 模型

  • Embedding 训练词向量(中心词和上下文词分开)
  • Bias 变量 用于调整预测值
  • 损失函数 最小化 log(共现次数) 与词向量点积的差值

 计算词向量相似度

  • 使用 cosine similarity
  • 找出 top_n 最相似的单词

http://www.kler.cn/a/526840.html

相关文章:

  • 【微服务与分布式实践】探索 Eureka
  • ROS2---基础操作
  • android获取EditText内容,TextWatcher按条件触发
  • Java 分布式与微服务架构:现代企业应用开发的新范式
  • 创建 priority_queue - 进阶(内置类型)c++
  • 在AWS上使用KMS客户端密钥加密S3文件,同时支持PySpark读写和Snowflake导入
  • C++计算特定随机操作后序列元素乘积的期望
  • w182网上服装商城的设计与实现
  • 因果推断与机器学习—因果推断入门(1)
  • (动态规划路径基础 最小路径和)leetcode 64
  • 被裁与人生的意义--春节随想
  • LevelDB 源码阅读:写入键值的工程实现和优化细节
  • 云原生(五十二) | DataGrip软件使用
  • 【疑海破局】一个注解引发的线上事故
  • 基于云计算、大数据与YOLO设计的火灾/火焰目标检测
  • AJAX XML
  • 环境中的CUDA配置
  • 工业相机如何设置曝光时间
  • STM32 ADC单通道配置
  • ARIMA详细介绍
  • 性能测试 —— Tomcat监控与调优:status页监控_tomcat 自带监控
  • 运算符重载(输出运算符<<) c++
  • 使用 scikit-learn 实现简单的线性回归案例
  • 使用 Motor-CAD 脚本实现 Maxwell 电机模型的 Ansys 自动化
  • Linux网络 | 网络层IP报文解析、认识网段划分与IP地址
  • 异步编程进阶:Python 中 asyncio 的多重应用