当前位置: 首页 > article >正文

【PyTorch】循环神经网络

循环神经网络是什么

Recurrent Neural Networks
RNN:循环神经网络

  • 处理不定长输入的模型
  • 常用于NLP及时间序列任务(输入数据具有前后关系

RNN网络结构

参考资料
Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
Understanding LSTM Networks
在这里插入图片描述

RNN实现人名分类

问题定义:输入任意长度姓名(字符串),输出姓名来自哪一个国家(18类分类任务)
数据: https://download.pytorch.org/tutorial/data.zip
Jackie Chan —— 成龙
Jay Chou —— 周杰伦
Tingsong Yue —— 余霆嵩

RNN如何处理不定长输入

思考:计算机如何实现不定长字符串分类向量的映射?
Chou(字符串)→ RNN →Chinese(分类类别)

  1. 单词字符 → 数字
  2. 数字 → model
  3. 下一个字符 → 数字 → model
  4. 最后一个字符 → 数字 → model → 分类向量
# 伪代码
# Chou(字符串)→ RNN →Chinese(分类类别)
for string in [C, h, o, u]:
	1. one-hot:string → [0,0, ...., 1, ..., 0]	# 首先把每个字母转换成编码
	2. y, h = model([0,0, ...., 1, ..., 0], h)		# h就是隐藏层的状态信息

xt:时刻t的输入,shape = (1, 57)
st:时刻t的状态值,shape=(1, 128)
ot:时刻t的输出值,shape=(1, 18)
U:linear层的权重参数, shape = (128, 57)
W:linear层的权重参数, shape = (128, 128)
V:linear层的权重参数, shape = (18, 128)

代码如下:

# -*- coding: utf-8 -*-
"""
# @file name  : rnn_demo.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-12-09
# @brief      : rnn人名分类
"""
from io import open
import glob
import unicodedata
import string
import math
import os
import time
import torch.nn as nn
import torch
import random
import matplotlib.pyplot as plt
import torch.utils.data
import sys
# 获取路径
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# 选择运行设备
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")


# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters)


# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    return all_letters.find(letter)


# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor


# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor


def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]


def randomTrainingExample():
    category = randomChoice(all_categories)                 # 选类别
    line = randomChoice(category_lines[category])           # 选一个样本
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = lineToTensor(line)    # str to one-hot
    return category, line, category_tensor, line_tensor


def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


# Just return an output given a line
def evaluate(line_tensor):
    hidden = rnn.initHidden()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    return output


def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))

        # Get top N categories
        topv, topi = output.topk(n_predictions, 1, True)

        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, all_categories[category_index]))


def get_lr(iter, learning_rate):
    lr_iter = learning_rate if iter < n_iters else learning_rate*0.1
    return lr_iter

# 定义网络结构
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.u = nn.Linear(input_size, hidden_size)
        self.w = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)

        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, hidden):

        u_x = self.u(inputs)

        hidden = self.w(hidden)
        hidden = self.tanh(hidden + u_x)

        output = self.softmax(self.v(hidden))

        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()

    rnn.zero_grad()

    line_tensor = line_tensor.to(device)
    hidden = hidden.to(device)
    category_tensor = category_tensor.to(device)

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        # p.data.add_(-learning_rate, p.grad.data) # 该方法已经被弃用
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()


if __name__ == "__main__":
    print(device)

    # config
    data_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_data", "names"))
    if not os.path.exists(data_dir):
        raise Exception("\n{} 不存在,请下载 08-05-数据-20200724.zip  放到\n{}  下,并解压即可".format(
            data_dir, os.path.dirname(data_dir)))

    path_txt = os.path.join(data_dir, "*.txt")
    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)    # 52 + 5 字符总数
    print_every = 5000
    plot_every = 5000
    learning_rate = 0.005
    n_iters = 200000

    # step 1 data
    # Build the category_lines dictionary, a list of names per language
    category_lines = {}
    all_categories = []
    for filename in glob.glob(path_txt):
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        lines = readLines(filename)
        category_lines[category] = lines

    n_categories = len(all_categories)

    # step 2 model
    n_hidden = 128
    # rnn = RNN(n_letters, n_hidden, n_categories)
    rnn = RNN(n_letters, n_hidden, n_categories)

    rnn.to(device)

    # step 3 loss
    criterion = nn.NLLLoss()

    # step 4 optimize by hand

    # step 5 iteration
    current_loss = 0
    all_losses = []
    start = time.time()
    for iter in range(1, n_iters + 1):
        # sample
        category, line, category_tensor, line_tensor = randomTrainingExample()

        # training
        output, loss = train(category_tensor, line_tensor)

        current_loss += loss

        # Print iter number, loss, name and guess
        if iter % print_every == 0:
            guess, guess_i = categoryFromOutput(output)
            correct = '✓' if guess == category else '✗ (%s)' % category
            print('Iter: {:<7} time: {:>8s} loss: {:.4f} name: {:>10s}  pred: {:>8s} label: {:>8s}'.format(
                iter, timeSince(start), loss, line, guess, correct))

        # Add current loss avg to list of losses
        if iter % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0

path_model = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_state_dict.pkl"))
if not os.path.exists(path_model):
    raise Exception("\n{} 不存在,请下载 08-05-数据-20200724.zip  放到\n{}  下,并解压即可".format(
        path_model, os.path.dirname(path_model)))
torch.save(rnn.state_dict(), path_model)
plt.plot(all_losses)
plt.show()

predict('Yue Tingsong')
predict('Yue tingsong')
predict('yutingsong')

predict('test your name')

http://www.kler.cn/news/326006.html

相关文章:

  • Ubuntu磁盘不足扩容
  • 【Linux篇】常用命令及操作技巧(进阶篇 - 上)
  • 数据结构之栈和队列——LeetCode:150. 逆波兰表达式求值,224. 基本计算器,232. 用栈实现队列
  • CSS 的user-select属性,控制用户是否能够选中文本内容
  • Java知识要点及面试题
  • 确保从IP池提取的IP是可用的对于数据抓取或其他网络活动至关重要。以下是一些确保IP可用性的有效方法:
  • 创新车展模式 焕新直播生态——第十一届麓谷汽车文化节圆满收官
  • 2024前端技术发展概况
  • Linux RCE 利用打印机服务 CVE-2024-47177
  • 【Redis】初识 Redis
  • 城市空间设计对居民生活质量的影响:构建宜居城市的蓝图
  • 基于SpringBoot实现QQ邮箱发送短信功能 | 免费短信服务
  • Nacos 安全使用最佳实践 - 访问控制实践
  • Redis 基础数据改造
  • AI绘画另类人像写真教程:用SD的 Reactor 实现换脸,效果真的很逼真!请谨慎使用
  • 基于大数据的商品推荐及可视化系统
  • E34.【C语言】位段练习题
  • 使用Python实现图形学的阴影体积算法
  • 秘密武器与选择指南
  • maven给springboot项目打成jar包 maven springboot打包配置
  • 深入解析 NoSQL 数据库的分类与特点
  • Qt5和Qt6获取屏幕的宽高,有区别
  • CSS宽度和高度
  • C++的6种构造函数
  • 8月面试总结
  • sql注入工具升级:自动化时间盲注、布尔盲注
  • IT基础监控范围和对象
  • Excel 获取某列不为空的值【INDEX函数 | SMALL函数或 LARGE函数 | ROW函数 | ISBLANK 函数】
  • Linux编译与调试
  • 四非人的保研之路,2024(2025届)四非计算机的保研经验分享(西南交通、苏大nlp、西电、北邮、山软、山计、电科、厦大等)