当前位置: 首页 > article >正文

TextCNN:文本卷积神经网络模型

目录

  • 什么是TextCNN
  • 定义TextCNN类
  • 初始化一个model实例
  • 输出model

什么是TextCNN

  • TextCNN(Text Convolutional Neural Network)是一种用于处理文本数据的卷积神经网(CNN)。通过在文本数据上应用卷积操作来提取局部特征,这些特征可以捕捉到文本中的局部模式,如n-gram(连续的n个单词或字符)。

定义TextCNN类

import torch.nn as nn

# 它继承自 PyTorch 的 nn.Module
class TextCNN(nn.Module):
	# __init__:类的构造函数,初始化模型,包括嵌入层、卷积层、dropout层和全连接层
    def __init__(self, vocab_size, embed_dim, num_classes, num_filters, kernel_sizes):
    	# 调用父类 nn.Module 的构造函数
        super(TextCNN, self).__init__()
        # 创建一个嵌入层,将词汇表中的每个单词映射到一个embed_dim 维的向量空间。vocab_size 是词汇表的大小
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # 创建一个卷积层列表,每个卷积层使用不同的 kernel_size。in_channels 是嵌入向量的维度,out_channels 是每个卷积核输出的特征数量
        self.convs = nn.ModuleList([nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k) for k in kernel_sizes])
        # 创建一个 Dropout 层,用于在训练过程中随机丢弃 50% 的节点,以减少过拟合
        self.dropout = nn.Dropout(0.5)
        # 创建一个全连接层,将卷积层的输出连接到最终的分类结果。
        # 输入特征的数量是卷积核数量乘以每个卷积核的输出特征数量,输出特征数量是分类类别的数量
        self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)
	
	# forward:定义模型的前向传播过程。
	# x:输入数据,通常是文本的整数序列。
    def forward(self, x):
    	# 将输入数据通过嵌入层转换为嵌入向量
        x = self.embedding(x)  
        # 调整张量维度,以便卷积操作可以在嵌入向量的维度上进行
        x = x.transpose(1, 2) 
        # 对每个卷积层应用激活函数 ReLU,生成特征图
        convs = [torch.relu(conv(x)) for conv in self.convs]  
        # 对每个卷积层的输出应用最大池化,以减少特征图的维度
        pooled = [torch.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs]  
        # 将所有卷积层的最大池化结果拼接在一起,形成一个单一的特征向量
        cat = torch.cat(pooled, 1) 
        # 通过 Dropout 层和全连接层进行分类,输出最终的分类结果
        return self.fc(self.dropout(cat))

初始化一个model实例

vocab_size = 1000
embed_dim = 128
num_classes = 2
num_filters = 100
kernel_sizes = [3, 4, 5]

model = TextCNN(vocab_size, embed_dim, num_classes, num_filters, kernel_sizes)

输出model

TextCNN(
  (embedding): Embedding(8, 128)
  (convs): ModuleList(
    (0): Conv1d(128, 100, kernel_size=(3,), stride=(1,))
    (1): Conv1d(128, 100, kernel_size=(4,), stride=(1,))
    (2): Conv1d(128, 100, kernel_size=(5,), stride=(1,))
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=300, out_features=2, bias=True)
)
  • Embedding(8, 128):这是一个嵌入层,它将词汇表中的每个单词映射到一个128维的向量空间。这里的8表示词汇表的大小(即输入序列中可能的最大单词索引),128表示每个单词将被映射到的向量维度。
  • convs: ModuleList[...]:这是一个包含多个一维卷积层(Conv1d)的模块列表。每个卷积层都用于提取文本数据的不同局部特征。
  • Conv1d(128, 100, kernel_size=(3,), stride=(1,)):每个卷积层有128个输入通道(与嵌入层的输出维度相同)和100个输出通道(即100个滤波器)。kernel_size=3表示每个滤波器的窗口大小为3个词。stride=1表示滤波器在文本序列上滑动的步长为1
  • Dropout(p=0.5, inplace=False):这是一个Dropout层,它在训练过程中随机丢弃50%的节点,以减少过拟合。inplace=False表示Dropout操作不会在原地修改输入张量。
  • fc: Linear(in_features=300, out_features=2, bias=True):这是一个全连接层,它将卷积层和Dropout层的输出转换为最终的分类结果。in_features=300表示全连接层的输入特征数量(这是由卷积层的数量和每个卷积层的输出特征数量决定的,即3个卷积层各100个特征)。out_features=2表示输出特征的数量,这通常与分类任务的类别数相对应(在这个例子中,可能是二分类问题)。bias=True表示全连接层的权重将包含偏置项。

http://www.kler.cn/news/306770.html

相关文章:

  • 【安全漏洞】Java-WebSocket 信任管理漏洞
  • 拓扑排序专题篇
  • 前端基础知识(HTML+CSS+JavaScript)
  • 828 华为云征文|华为 Flexus 云服务器搭建萤火商城 2.0
  • 【Go - 类型断言】
  • Ubuntu下Git使用教程:从入门到实践
  • Java怎么把多个对象的list的数据合并
  • [Android][Reboot/Shutdown] 重启/关机 分析
  • bibtex是什么
  • WPF的**逻辑树**和**可视树**。
  • 2024年数学建模比赛题目及解题代码
  • 初识Linux · 进程(3)
  • 软考架构-面向服务的架构风格
  • 电子废物检测回收系统源码分享
  • STM32点亮第一个LED
  • starUML使用说明文档[简单易懂/清晰明了]||好上手
  • Netty笔记03-组件Channel
  • Android中的Context
  • 接口测试从入门到精通项目实战
  • 【Android 13源码分析】WindowContainer窗口层级-3-实例分析
  • MTC完成右臂抓取放置任务\\放置姿态设置
  • 【SQL】百题计划:SQL判断条件OR的使用。
  • 如何为子域名配置 Nginx 反向代理到 Flask 应用
  • IEEE会议论文引用格式
  • 在 Android 中,事件的分发机制
  • 淘宝商品详情API返回值中的预售与定制信息解析
  • xtu oj 折纸
  • [网络]从零开始的计算机网络基础知识讲解
  • eureka.client.service-url.defaultZone的坑
  • 数据库系统 第50节 数据库灾难恢复计划