当前位置: 首页 > article >正文

如何构建TCN网络提取序列特征

  1. 原理可以看一下别人写的帖子
    https://blog.csdn.net/weixin_39910711/article/details/124678538

TCN网络主要用于提取序列特征:它可以改变特征维度(通道数),同时保持序列长度不变。
具体代码如下:

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the final (time) dimension.

    In this file it follows a ``Conv1d`` that was padded by ``padding`` on the
    time axis and is constructed with that same ``padding`` value.
    """

    def __init__(self, chomp_size):
        """
        Args:
            chomp_size: int, number of trailing time steps to remove.
        """
        super(Chomp1d, self).__init__()

        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Remove the trailing padding and return a tensor that is contiguous in memory.

        Args:
            x: tensor whose last dimension is the one to trim.

        Returns: ``x`` without its last ``chomp_size`` entries on the last dimension.
        """
        # ``x[..., :-0]`` is ``x[..., :0]`` (an empty tensor), so a zero chomp
        # (e.g. kernel_size == 1 -> padding == 0) must bypass the slice.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[..., :-self.chomp_size].contiguous()


class TemporalResidualBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        """
        Residual block built from two weight-normalised Conv1d layers.

        Each convolution is followed by Chomp1d(padding), ELU and Dropout.

        Args:
            n_inputs: int, number of input channels; for a (seq_len, dim) sequence
                this is the feature dimension dim.
            n_outputs: int, number of output channels, i.e. the output feature
                dimension dim'.
            kernel_size: int, kernel width k; each filter covers k * n_inputs
                values and slides along seq_len.
            stride: stride of both convolutions.
                NOTE(review): the residual addition in forward() presumably only
                lines up for stride=1 (the only value TemporalConvNet passes).
            dilation: dilation of both convolutions.
            padding: padding given to both convolutions and chomped off again
                afterwards; expected to be (k - 1) * dilation.
            dropout: dropout probability used after each activation.
        """
        super(TemporalResidualBlock, self).__init__()
        # conv1: n_inputs -> n_outputs
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.elu1 = nn.ELU()
        self.dropout1 = nn.Dropout(dropout)

        # conv2: n_outputs -> n_outputs
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.elu2 = nn.ELU()
        self.dropout2 = nn.Dropout(dropout)

        # The two conv stages chained into one module.
        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.elu1, self.dropout1,
            self.conv2, self.chomp2, self.elu2, self.dropout2
        )

        # 1x1 convolution on the skip path, only needed when the channel count changes.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.elu = nn.ELU()

        # Weight initialisation.
        self.init_weights()

    def init_weights(self):
        """
        Initialise the convolution weights from a normal distribution N(0, 0.01).

        weight_norm re-parameterises ``weight`` as ``weight_g`` (magnitude) and
        ``weight_v`` (direction) and recomputes ``weight`` before every forward
        pass, so writing to ``conv.weight`` alone would be discarded. The
        direction tensor is therefore initialised and the magnitude is set to
        its per-output-channel norm, which makes the effective weight equal to
        the sampled ``weight_v``.

        Returns: None

        """
        for conv in (self.conv1, self.conv2):
            if hasattr(conv, 'weight_v') and hasattr(conv, 'weight_g'):
                conv.weight_v.data.normal_(0, 0.01)
                # One norm per output channel (dim 0), matching weight_g's layout.
                channel_norms = conv.weight_v.data.flatten(1).norm(dim=1)
                conv.weight_g.data.copy_(channel_norms.view_as(conv.weight_g))
            else:
                # No g/v parameters present: keep the original behaviour.
                conv.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """
        Apply the conv stack and add the (optionally projected) input as a residual.

        Args:
            x: size of (b, n_inputs, seq_len)

        Returns: out size is (b, n_outputs, seq_len)

        """
        out = self.net(x)
        # Project the input with the 1x1 conv only when the channel counts differ.
        res = x if self.downsample is None else self.downsample(x)
        return self.elu(out + res)


class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        """
        Stack one TemporalResidualBlock per entry of ``num_channels``.

        Args:
            num_inputs: number of input channels of the TCN, i.e. the size of the
                feature vector at each time step of the input sequence
            num_channels: list, output channel count of each residual level
            kernel_size: kernel width shared by every level
            dropout: dropout probability shared by every level
        """
        super(TemporalConvNet, self).__init__()
        # Level i reads what level i - 1 produced; the first level reads the raw input.
        input_widths = [num_inputs] + list(num_channels[:-1])

        blocks = []
        for level, (width_in, width_out) in enumerate(zip(input_widths, num_channels)):
            # Dilation doubles per level: 1, 2, 4, ...
            dilation = 2 ** level
            blocks.append(
                TemporalResidualBlock(
                    width_in,
                    width_out,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """
        Run the stacked residual levels; sequence length is kept while the
        channel dimension may grow or shrink.

        Args:
            x:  size of (b, input_channels, seq_len)

        Returns: size of (b, out_channels, seq_len)

        """
        return self.network(x)


class MyTCN(nn.Module):
    """ TCN feature extractor followed by a linear head, for classification tasks """
    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        """
        Build the TCN backbone and the linear output layer.

        Args:
            input_size: feature dimension of the vector at each time step
            output_size: dimension of the task output vector
            num_channels: list, output channel count of each TCN residual level
            kernel_size: kernel width forwarded to the TCN
            dropout: dropout probability forwarded to the TCN
        """
        super(MyTCN, self).__init__()
        backbone_width = num_channels[-1]
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(backbone_width, output_size)

    def forward(self, x):
        """
        Encode the sequence and classify from the features of the last time step.

        Args:
            x: size of (B, input_size, seq_len)

        Returns: size of (B, output_size)

        """
        features = self.tcn(x)
        # Keep only the final position of the last (time) dimension.
        last_step = features.select(-1, -1)
        return self.linear(last_step)

if __name__ == '__main__':
    # NOTE(review): the classes above use `nn` when they are defined, so when this
    # snippet is assembled into a real module these imports belong at the top of
    # the file, before the class definitions.
    from torch import nn
    import torch
    from torchinfo import summary
    from torch.nn.utils import weight_norm

    # 80 input features, 25 outputs, eight residual levels, kernel size 5, dropout 0.5.
    mytcn = MyTCN(80, 25, [40, 40, 40, 40, 40, 40, 40, 25], 5, 0.5)
    # Print the layer-by-layer structure for an input of size (512, 80, 40).
    summary(mytcn, input_size=(512, 80, 40), depth=5)

输出模型结构如下:
这里构建了一个简单的分类模型。通过 num_channels 这个列表参数,可以指定每一层残差块输出的特征维度,从而改变序列的特征维度而不改变序列长度;因此在一些时序任务中,TCN 可以作为特征提取器使用。

MyTCN [512, 25] –
├─TemporalConvNet: 1-1 [512, 25, 40] –
│ └─Sequential: 2-1 [512, 25, 40] –
│ │ └─TemporalResidualBlock: 3-1 [512, 40, 40] –
│ │ │ └─Sequential: 4-1 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-1 [512, 40, 44] 16,080
│ │ │ │ └─Chomp1d: 5-2 [512, 40, 40] –
│ │ │ │ └─ELU: 5-3 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-4 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-5 [512, 40, 44] 8,080
│ │ │ │ └─Chomp1d: 5-6 [512, 40, 40] –
│ │ │ │ └─ELU: 5-7 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-8 [512, 40, 40] –
│ │ │ └─Conv1d: 4-2 [512, 40, 40] 3,240
│ │ │ └─ELU: 4-3 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-2 [512, 40, 40] –
│ │ │ └─Sequential: 4-4 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-9 [512, 40, 48] 8,080
│ │ │ │ └─Chomp1d: 5-10 [512, 40, 40] –
│ │ │ │ └─ELU: 5-11 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-12 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-13 [512, 40, 48] 8,080
│ │ │ │ └─Chomp1d: 5-14 [512, 40, 40] –
│ │ │ │ └─ELU: 5-15 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-16 [512, 40, 40] –
│ │ │ └─ELU: 4-5 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-3 [512, 40, 40] –
│ │ │ └─Sequential: 4-6 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-17 [512, 40, 56] 8,080
│ │ │ │ └─Chomp1d: 5-18 [512, 40, 40] –
│ │ │ │ └─ELU: 5-19 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-20 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-21 [512, 40, 56] 8,080
│ │ │ │ └─Chomp1d: 5-22 [512, 40, 40] –
│ │ │ │ └─ELU: 5-23 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-24 [512, 40, 40] –
│ │ │ └─ELU: 4-7 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-4 [512, 40, 40] –
│ │ │ └─Sequential: 4-8 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-25 [512, 40, 72] 8,080
│ │ │ │ └─Chomp1d: 5-26 [512, 40, 40] –
│ │ │ │ └─ELU: 5-27 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-28 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-29 [512, 40, 72] 8,080
│ │ │ │ └─Chomp1d: 5-30 [512, 40, 40] –
│ │ │ │ └─ELU: 5-31 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-32 [512, 40, 40] –
│ │ │ └─ELU: 4-9 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-5 [512, 40, 40] –
│ │ │ └─Sequential: 4-10 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-33 [512, 40, 104] 8,080
│ │ │ │ └─Chomp1d: 5-34 [512, 40, 40] –
│ │ │ │ └─ELU: 5-35 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-36 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-37 [512, 40, 104] 8,080
│ │ │ │ └─Chomp1d: 5-38 [512, 40, 40] –
│ │ │ │ └─ELU: 5-39 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-40 [512, 40, 40] –
│ │ │ └─ELU: 4-11 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-6 [512, 40, 40] –
│ │ │ └─Sequential: 4-12 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-41 [512, 40, 168] 8,080
│ │ │ │ └─Chomp1d: 5-42 [512, 40, 40] –
│ │ │ │ └─ELU: 5-43 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-44 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-45 [512, 40, 168] 8,080
│ │ │ │ └─Chomp1d: 5-46 [512, 40, 40] –
│ │ │ │ └─ELU: 5-47 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-48 [512, 40, 40] –
│ │ │ └─ELU: 4-13 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-7 [512, 40, 40] –
│ │ │ └─Sequential: 4-14 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-49 [512, 40, 296] 8,080
│ │ │ │ └─Chomp1d: 5-50 [512, 40, 40] –
│ │ │ │ └─ELU: 5-51 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-52 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-53 [512, 40, 296] 8,080
│ │ │ │ └─Chomp1d: 5-54 [512, 40, 40] –
│ │ │ │ └─ELU: 5-55 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-56 [512, 40, 40] –
│ │ │ └─ELU: 4-15 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-8 [512, 25, 40] –
│ │ │ └─Sequential: 4-16 [512, 25, 40] –
│ │ │ │ └─Conv1d: 5-57 [512, 25, 552] 5,050
│ │ │ │ └─Chomp1d: 5-58 [512, 25, 40] –
│ │ │ │ └─ELU: 5-59 [512, 25, 40] –
│ │ │ │ └─Dropout: 5-60 [512, 25, 40] –
│ │ │ │ └─Conv1d: 5-61 [512, 25, 552] 3,175
│ │ │ │ └─Chomp1d: 5-62 [512, 25, 40] –
│ │ │ │ └─ELU: 5-63 [512, 25, 40] –
│ │ │ │ └─Dropout: 5-64 [512, 25, 40] –
│ │ │ └─Conv1d: 4-17 [512, 25, 40] 1,025
│ │ │ └─ELU: 4-18 [512, 25, 40] –
├─Linear: 1-2 [512, 25] 650


http://www.kler.cn/a/228808.html

相关文章:

  • 【部署】将项目部署到云服务器
  • 基于springboot的口腔管理平台
  • 人工智能之数学基础:线性代数中的线性相关和线性无关
  • mysql_real_connect的概念和使用案例
  • 20250118-读取并显示彩色图像以及提取彩色图像的 R、G、B 分量
  • 网络安全---CMS指纹信息实战
  • NLP任务之Named Entity Recognition
  • 自然语言处理中所有任务的概括
  • vue-element-admin npm install 失败解决
  • 代码随想录算法训练营Day49|121. 买卖股票的最佳时机、122.买卖股票的最佳时机II
  • 【IMAX6U移植OpenCV】
  • 15.1 项目实践_OA系统
  • 【RT-DETR有效改进】UNetv2提出的一种SDI多层次特征融合模块(细节高效涨点)
  • 浅谈QT的几种线程的使用和区别。
  • 如何部署Linux AMH服务器管理面板并结合内网穿透远程访问
  • 【AI数字人-论文】Geneface论文
  • H5调用安卓原生相机API案例
  • Java学习day29:线程池Pool中创建线程方式(面试必考!)
  • 《热辣滚烫》预售狂潮来袭,贾玲、马丽、杨紫三大女神联袂出演。
  • (4)【Python数据分析进阶】Machine-Learning模型与算法应用-回归、分类模型汇总
  • Java实现线程安全的几种方式:常量/数据私有/互斥同步/非阻塞同步
  • 【数据结构 10】位图
  • jmeter-问题一:关于线程组,线程数,用户数详解
  • 5分钟快速掌握 XML (Extensible Markup Language)
  • 【51单片机】开发板&开发软件(Keil5&STC-ISP)简介&下载安装破译传送门(1)
  • QT styleSheet——控件设置样式表