如何构建TCN网络提取序列特征
- 原理可以看一下别人写的帖子
https://blog.csdn.net/weixin_39910711/article/details/124678538
TCN网络主要用于提取序列特征:它可以改变每个时间步的特征维度(通道数),同时保持序列长度不变。
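具体代码如下(注意:实际运行时需要把 `import torch`、`from torch import nn`、`from torch.nn.utils import weight_norm` 放在文件开头、类定义之前,否则在定义类时会因为找不到 `nn` 而报 NameError):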
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the final axis of a tensor.

    Used after a Conv1d that was padded on both sides: removing the trailing
    padded steps restores the original sequence length.
    """

    def __init__(self, chomp_size):
        """
        Args:
            chomp_size: int, number of trailing steps to remove from the last
                axis. ``0`` means "remove nothing".
        """
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Remove the trailing padding and return a tensor that is contiguous
        in memory.

        Args:
            x: tensor whose last axis is the (padded) sequence axis.
        Returns: ``x`` without its last ``chomp_size`` steps.
        """
        # ``x[..., :-0]`` is ``x[..., :0]``, i.e. an EMPTY tensor. A chomp
        # size of zero (kernel_size == 1 -> padding == 0) must be a no-op.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[..., :-self.chomp_size].contiguous()
class TemporalResidualBlock(nn.Module):
    """Residual block made of two weight-normalised, dilated Conv1d layers.

    Each convolution is followed by ``Chomp1d(padding)``, an ELU and a
    Dropout. The block returns ``ELU(net(x) + res)``, where ``res`` is ``x``
    itself, or a kernel-size-1 convolution of ``x`` when the input and output
    channel counts differ.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        """
        Args:
            n_inputs: int, number of input channels; for a (seq_len, dim)
                sequence this is the feature dimension of each time step.
            n_outputs: int, number of output channels (the new feature
                dimension).
            kernel_size: int, kernel width along the sequence axis.
            stride: stride passed to both convolutions.
            dilation: dilation passed to both convolutions.
            padding: padding passed to both convolutions and chomped off
                again afterwards; expected to be (kernel_size - 1) * dilation
                so that, with stride 1, the sequence length is preserved.
            dropout: dropout probability used after each convolution.
        """
        super(TemporalResidualBlock, self).__init__()
        # conv1: n_inputs -> n_outputs
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.elu1 = nn.ELU()
        self.dropout1 = nn.Dropout(dropout)
        # conv2: n_outputs -> n_outputs
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.elu2 = nn.ELU()
        self.dropout2 = nn.Dropout(dropout)
        # conv block
        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.elu1, self.dropout1,
            self.conv2, self.chomp2, self.elu2, self.dropout2
        )
        # Kernel-size-1 convolution that maps the residual branch to
        # n_outputs channels; not needed when the channel counts already match.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.elu = nn.ELU()
        # weight initialisation
        self.init_weights()

    @staticmethod
    def _normal_init(conv, std=0.01):
        """Fill the effective weight of ``conv`` with samples from N(0, std)."""
        if hasattr(conv, "weight_g") and hasattr(conv, "weight_v"):
            # Wrapped by the legacy ``torch.nn.utils.weight_norm``: ``weight``
            # is recomputed from ``weight_g`` / ``weight_v`` before every
            # forward pass, so writing to ``conv.weight`` alone is discarded.
            # Initialise the direction tensor and set the magnitude to its
            # per-output-channel norm, which makes the effective weight equal
            # to the freshly drawn ``weight_v``.
            v = conv.weight_v.data
            v.normal_(0, std)
            conv.weight_g.data.copy_(v.pow(2).sum(dim=(1, 2), keepdim=True).sqrt())
            # Keep the cached ``weight`` attribute in sync until the next
            # forward pass recomputes it.
            conv.weight.data.copy_(v)
        else:
            conv.weight.data.normal_(0, std)

    def init_weights(self):
        """
        Initialise the weights of both convolutions (and of the residual
        1x1 convolution, if present) from a normal distribution with mean 0
        and std 0.01.
        Returns:
        """
        self._normal_init(self.conv1)
        self._normal_init(self.conv2)
        if self.downsample is not None:
            self._normal_init(self.downsample)

    def forward(self, x):
        """
        Apply the two-convolution branch and add the residual connection.
        Args:
            x: size of (b, n_inputs, seq_len)
        Returns: out size is (b, n_outputs, seq_len)
        """
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.elu(out + res)
class TemporalConvNet(nn.Module):
    """Stack of TemporalResidualBlock layers with dilation doubling per level."""

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        """
        Args:
            num_inputs: number of input channels of the TCN, i.e. the
                dimension of the vector at each time step of the sequence.
            num_channels: list, output channel count of each residual level.
            kernel_size: kernel width handed to every residual block.
            dropout: dropout probability handed to every residual block.
        """
        super().__init__()
        # widths[j] -> widths[j + 1] are the in/out channels of level j.
        widths = [num_inputs, *num_channels]
        blocks = []
        for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Dilation factor 1, 2, 4, ...
            dilation = 1 << level
            blocks.append(
                TemporalResidualBlock(
                    c_in,
                    c_out,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """
        Run the input through every residual level in order. The sequence
        length is kept while the channel dimension may grow or shrink.
        Args:
            x: size of (b, input_channels, seq_len)
        Returns: size of (b, out_channels, seq_len)
        """
        return self.network(x)
class MyTCN(nn.Module):
    """ Classification tasks: a TCN feature extractor followed by a linear head. """

    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        """
        Args:
            input_size: feature dimension of the vector at each time step.
            output_size: dimension of the task output vector.
            num_channels: list, output channel count of each TCN residual
                level.
            kernel_size: kernel width forwarded to the TCN.
            dropout: dropout probability forwarded to the TCN.
        """
        super().__init__()
        self.tcn = TemporalConvNet(
            input_size,
            num_channels,
            kernel_size=kernel_size,
            dropout=dropout,
        )
        # The head consumes the channels produced by the last TCN level.
        self.linear = nn.Linear(num_channels[-1], output_size)

    def forward(self, x):
        """
        Args:
            x: size of (B, input_size, seq_len)
        Returns: size of (B, output_size)
        """
        features = self.tcn(x)
        # Only the features of the final time step are fed to the linear head.
        last_step = features[..., -1]
        return self.linear(last_step)
if __name__ == '__main__':
    # NOTE(review): the classes defined earlier in this file use ``nn`` and
    # ``weight_norm``; these imports presumably need to live at the top of
    # the file for the script to run as posted - confirm.
    import torch
    from torch import nn
    from torch.nn.utils import weight_norm
    from torchinfo import summary

    # Demo: 80 input features, 8 residual levels (seven with 40 channels and
    # a last one with 25), kernel size 5, and a 25-dimensional output.
    mytcn = MyTCN(
        input_size=80,
        output_size=25,
        num_channels=[40] * 7 + [25],
        kernel_size=5,
        dropout=0.5,
    )
    # Summarise the model for an input of shape (512, 80, 40), listing
    # modules down to nesting depth 5.
    summary(mytcn, input_size=(512, 80, 40), depth=5)
输出模型结构如下:
这里是构建了一个简单的分类任务,通过指定num_channels这个列表参数,可以指定每层残差网络输出的特征维度,从而改变序列的特征维度,而不改变序列长度。对于一些时序任务可以作为特征提取器。
MyTCN [512, 25] –
├─TemporalConvNet: 1-1 [512, 25, 40] –
│ └─Sequential: 2-1 [512, 25, 40] –
│ │ └─TemporalResidualBlock: 3-1 [512, 40, 40] –
│ │ │ └─Sequential: 4-1 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-1 [512, 40, 44] 16,080
│ │ │ │ └─Chomp1d: 5-2 [512, 40, 40] –
│ │ │ │ └─ELU: 5-3 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-4 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-5 [512, 40, 44] 8,080
│ │ │ │ └─Chomp1d: 5-6 [512, 40, 40] –
│ │ │ │ └─ELU: 5-7 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-8 [512, 40, 40] –
│ │ │ └─Conv1d: 4-2 [512, 40, 40] 3,240
│ │ │ └─ELU: 4-3 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-2 [512, 40, 40] –
│ │ │ └─Sequential: 4-4 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-9 [512, 40, 48] 8,080
│ │ │ │ └─Chomp1d: 5-10 [512, 40, 40] –
│ │ │ │ └─ELU: 5-11 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-12 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-13 [512, 40, 48] 8,080
│ │ │ │ └─Chomp1d: 5-14 [512, 40, 40] –
│ │ │ │ └─ELU: 5-15 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-16 [512, 40, 40] –
│ │ │ └─ELU: 4-5 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-3 [512, 40, 40] –
│ │ │ └─Sequential: 4-6 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-17 [512, 40, 56] 8,080
│ │ │ │ └─Chomp1d: 5-18 [512, 40, 40] –
│ │ │ │ └─ELU: 5-19 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-20 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-21 [512, 40, 56] 8,080
│ │ │ │ └─Chomp1d: 5-22 [512, 40, 40] –
│ │ │ │ └─ELU: 5-23 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-24 [512, 40, 40] –
│ │ │ └─ELU: 4-7 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-4 [512, 40, 40] –
│ │ │ └─Sequential: 4-8 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-25 [512, 40, 72] 8,080
│ │ │ │ └─Chomp1d: 5-26 [512, 40, 40] –
│ │ │ │ └─ELU: 5-27 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-28 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-29 [512, 40, 72] 8,080
│ │ │ │ └─Chomp1d: 5-30 [512, 40, 40] –
│ │ │ │ └─ELU: 5-31 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-32 [512, 40, 40] –
│ │ │ └─ELU: 4-9 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-5 [512, 40, 40] –
│ │ │ └─Sequential: 4-10 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-33 [512, 40, 104] 8,080
│ │ │ │ └─Chomp1d: 5-34 [512, 40, 40] –
│ │ │ │ └─ELU: 5-35 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-36 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-37 [512, 40, 104] 8,080
│ │ │ │ └─Chomp1d: 5-38 [512, 40, 40] –
│ │ │ │ └─ELU: 5-39 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-40 [512, 40, 40] –
│ │ │ └─ELU: 4-11 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-6 [512, 40, 40] –
│ │ │ └─Sequential: 4-12 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-41 [512, 40, 168] 8,080
│ │ │ │ └─Chomp1d: 5-42 [512, 40, 40] –
│ │ │ │ └─ELU: 5-43 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-44 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-45 [512, 40, 168] 8,080
│ │ │ │ └─Chomp1d: 5-46 [512, 40, 40] –
│ │ │ │ └─ELU: 5-47 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-48 [512, 40, 40] –
│ │ │ └─ELU: 4-13 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-7 [512, 40, 40] –
│ │ │ └─Sequential: 4-14 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-49 [512, 40, 296] 8,080
│ │ │ │ └─Chomp1d: 5-50 [512, 40, 40] –
│ │ │ │ └─ELU: 5-51 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-52 [512, 40, 40] –
│ │ │ │ └─Conv1d: 5-53 [512, 40, 296] 8,080
│ │ │ │ └─Chomp1d: 5-54 [512, 40, 40] –
│ │ │ │ └─ELU: 5-55 [512, 40, 40] –
│ │ │ │ └─Dropout: 5-56 [512, 40, 40] –
│ │ │ └─ELU: 4-15 [512, 40, 40] –
│ │ └─TemporalResidualBlock: 3-8 [512, 25, 40] –
│ │ │ └─Sequential: 4-16 [512, 25, 40] –
│ │ │ │ └─Conv1d: 5-57 [512, 25, 552] 5,050
│ │ │ │ └─Chomp1d: 5-58 [512, 25, 40] –
│ │ │ │ └─ELU: 5-59 [512, 25, 40] –
│ │ │ │ └─Dropout: 5-60 [512, 25, 40] –
│ │ │ │ └─Conv1d: 5-61 [512, 25, 552] 3,175
│ │ │ │ └─Chomp1d: 5-62 [512, 25, 40] –
│ │ │ │ └─ELU: 5-63 [512, 25, 40] –
│ │ │ │ └─Dropout: 5-64 [512, 25, 40] –
│ │ │ └─Conv1d: 4-17 [512, 25, 40] 1,025
│ │ │ └─ELU: 4-18 [512, 25, 40] –
├─Linear: 1-2 [512, 25] 650