当前位置: 首页 > article >正文

pytorch基于FastText实现词嵌入

FastText 是 Facebook AI Research 提出的 改进版 Word2Vec,可以: ✅ 利用 n-grams 处理未登录词
比 Word2Vec 更快、更准确
适用于中文等形态丰富的语言

完整的 PyTorch FastText 代码(基于中文语料),包含:

  • 数据预处理(分词 + n-grams)
  • 模型定义
  • 训练
  • 测试
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
import random

# ========== 1. 数据预处理 ==========
corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
]

# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]


# 构建 n-grams
def generate_ngrams(words, n=3):
    ngrams = []
    for word in words:
        ngrams += [word[i:i + n] for i in range(len(word) - n + 1)]
    return ngrams


# 生成 n-grams 词表
all_ngrams = set()
for sentence in tokenized_corpus:
    for word in sentence:
        all_ngrams.update(generate_ngrams(word))

# 构建词汇表
vocab = set(word for sentence in tokenized_corpus for word in sentence) | all_ngrams
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# 构建训练数据(CBOW 方式)
window_size = 2
data = []

for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
        context = []
        for offset in range(-window_size, window_size + 1):
            context_idx = center_idx + offset
            if 0 <= context_idx < len(indices) and context_idx != center_idx:
                context.append(indices[context_idx])
        if context:
            data.append((context, indices[center_idx]))  # (上下文, 目标词)


# ========== 2. 定义 FastText 模型 ==========
class FastText(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(FastText, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, context):
        context_vec = self.embeddings(context).mean(dim=1)  # 平均上下文向量
        output = self.linear(context_vec)
        return output


# 初始化模型
embedding_dim = 10
model = FastText(len(vocab), embedding_dim)

# ========== 3. 训练 FastText ==========
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100

for epoch in range(num_epochs):
    total_loss = 0
    random.shuffle(data)

    for context, target in data:
        context = torch.tensor([context], dtype=torch.long)
        target = torch.tensor([target], dtype=torch.long)

        optimizer.zero_grad()
        output = model(context)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

# ========== 4. 获取词向量 ==========
word_vectors = model.embeddings.weight.data.numpy()


# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):
    if word not in word2idx:
        return "单词不在词汇表中"

    word_vec = word_vectors[word2idx[word]].reshape(1, -1)
    similarities = np.dot(word_vectors, word_vec.T).squeeze()
    similar_idx = similarities.argsort()[::-1][1:top_n + 1]
    return [(idx2word[idx], similarities[idx]) for idx in similar_idx]


# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

1. 生成 n-grams

  • FastText 处理单词的 子词单元(n-grams)
  • 例如 "学习" 会生成 ["学习", "习学", "学"]
  • 这样即使遇到未登录词也能拆分为 n-grams 计算

2. 训练数据

  • 使用 CBOW(上下文预测中心词)
  • 窗口大小 = 2,即:
    句子: ["深度", "学习", "是", "人工智能"]
    示例: (["深度", "是"], "学习")
    

3. FastText 模型

  • 词向量是 n-grams 词向量的平均值
  • 计算公式: 
  • 这样,即使单词没见过,也能用它的 n-grams 计算词向量!

 4. 计算相似度

  • cosine similarity 找出最相似的单词
  • FastText 比 Word2Vec 更准确,因为它能利用 n-grams 捕捉词的语义信息
特性FastTextWord2VecGloVe
原理预测中心词 + n-grams预测中心词或上下文统计词共现信息
未登录词处理可处理无法处理无法处理
训练速度 快
适合领域中文、罕见词传统 NLP大规模数据

http://www.kler.cn/a/527602.html

相关文章:

  • FreeRTOS从入门到精通 第十五章(事件标志组)
  • HarmonyOS简介:HarmonyOS核心技术理念
  • SAP内向交货单详解
  • SpringBoot笔记
  • 浅谈AI的发展对IT行业的影响
  • React 的 12 个核心概念
  • java求职学习day23
  • 指针(C语言)从0到1掌握指针〕带你探究计算机神奇的秘密
  • autogen 自定义agent (1)
  • 基于排队理论的物联网发布/订阅通信系统建模与优化
  • 第二讲:类与对象(上)
  • deepseek大模型本机部署
  • OSCP:常见文件传输方法
  • OSCP 渗透测试:网络抓包工具的使用指南
  • Java多线程——对象的共享
  • DeepSeek本地部署(windows)
  • 软件测试(认识测试)
  • 无人机图传模块 wfb-ng openipc-fpv,4G
  • 【易理解】04_什么是try-catch-throw语句?
  • socket编程短平快
  • 计算机网络一点事(24)
  • 漏洞扫描工具之xray
  • 【视频+图文讲解】HTML基础2-html骨架与基本语法
  • OpenCV:Harris、Shi-Tomasi角点检测
  • 【小白学AI系列】NLP 核心知识点(六)Softmax函数介绍
  • 如何优化轮式移动机器人的运动稳定性?