PyTorch从零开始实现LSTM

文章目录

  • LSTM基础理论
  • 从零开始实现LSTM
  • 简洁版LSTM实现
  • 参考资料

LSTM基础理论

关于LSTM的基础理论不再赘述,可以参考资料:

  • RNN神经网络-LSTM模型结构
  • https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter06_RNN/6.8_lstm.md

从零开始实现LSTM

  1. 导入所需的依赖包,获取设备:
import torch
from torch import nn, optim
import numpy as np
import zipfile

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  1. 定义数据加载代码,这里以周杰伦的歌词数据为例:
def load_data_jay_lyrics():
    with zipfile.ZipFile('./data/jaychou_lyrics.txt.zip') as zin:
        with zin.open('jaychou_lyrics.txt') as f:
            corpus_chars = f.read().decode('utf-8')
    # corpus_chars[:40]  # '想要有直升机\n想要和你飞到宇宙去\n想要和你融化在一起\n融化在宇宙里\n我每天每天每'

    # 将换行符替换成空格;仅使用前1万个字符来训练模型
    corpus_chars = corpus_chars.replace('\n', ' ').replace('\r', ' ')
    corpus_chars = corpus_chars[0:10000]

    # 将每个字符映射成索引
    idx_to_char = list(set(corpus_chars))
    char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)])
    vocab_size = len(char_to_idx)  # 1027
    corpus_indices = [char_to_idx[char] for char in corpus_chars]
    sample = corpus_indices[:20]
    return corpus_indices, char_to_idx, idx_to_char, vocab_size
  1. 定义参数初始化函数:
def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)

    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_inputs, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))

    W_xi, W_hi, b_i = _three()  # 输入门参数
    W_xf, W_hf, b_f = _three()  # 遗忘门参数
    W_xo, W_ho, b_o = _three()  # 输出门参数
    W_xc, W_hc, b_c = _three()  # 候选记忆细胞参数

    # 输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q])


# 加载周杰伦歌词数据
corpus_indices, char_to_idx, idx_to_char, vocab_size = load_data_jay_lyrics()
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size


def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
  1. 定义LSTM模型:
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)
        F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)
        O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)
        C_tilda = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * C.tanh()
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H, C)
  1. 训练模型。
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']


train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)

这里 train_and_predict_rnn 为训练函数,可以参考
https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter06_RNN/6.5_rnn-pytorch.md
中的实现。

简洁版LSTM实现

  1. 定义RNNModel类:
# 本类已保存在d2lzh_pytorch包中方便以后使用
class RNNModel(nn.Module):
    def __init__(self, rnn_layer, vocab_size):
        super(RNNModel, self).__init__()
        self.rnn = rnn_layer
        self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) 
        self.vocab_size = vocab_size
        self.dense = nn.Linear(self.hidden_size, vocab_size)
        self.state = None

    def forward(self, inputs, state): # inputs: (batch, seq_len)
        # 获取one-hot向量表示
        X = to_onehot(inputs, self.vocab_size) # X是个list
        Y, self.state = self.rnn(torch.stack(X), state)
        # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
        # 形状为(num_steps * batch_size, vocab_size)
        output = self.dense(Y.view(-1, Y.shape[-1]))
        return output, self.state
def to_onehot(X, n_class):  
    # X shape: (batch, seq_len), output: seq_len elements of (batch, n_class)
    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]
  1. 定义LSTM模型并训练模型:
lr = 1e-2 # 注意调整学习率

lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)

model = RNNModel(lstm_layer, vocab_size)

train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                corpus_indices, idx_to_char, char_to_idx,
                                num_epochs, num_steps, lr, clipping_theta,
                                batch_size, pred_period, pred_len, prefixes)

参考资料

  • https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter06_RNN/6.8_lstm.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/771996.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android12 MultiMedia框架之MediaExtractorService

上节学到setDataSource()时会创建各种Source,source用来读取音视频源文件,读取到之后需要demux出音、视频、字幕数据流,然后再送去解码。那么负责进行demux功能的media extractor模块是在什么时候阶段创建的?这里暂时不考虑APP创建…

UE5.4新功能 - Texture Graph上手简介

TextureGraph是UE5.4还在实验(Experimental)阶段的新功能,该功能旨在材质生成方面达到类似Subtance Designer的效果,从而程序化的生成一些纹理。 本文就来简要学习一下。 1.使用UE5.4或以上版本,激活TextureGraph插件 2.内容视图中右键找到…

day11_homework_need2submit

Homework 编写—个将ts或mp4中视频文件解码到yuv的程序 yuv数据可以使用如下命令播放: ffplay -i output yuv-pix_fmt yuv420p-s 1024x436 要求: ffmpeg解析到avpacket并打印出pts和dts字段完成解码到avframe并打印任意字段完成yuv数据保存 // teminal orders on bash cd ex…

6 矩阵相关案例

矩阵计算在CUDA中的应用是并行计算领域的典型场景 ; 矩阵算法题通常涉及线性代数的基础知识,以及对数据结构和算法的深入理解。解决这类问题时,掌握一些核心思想和技巧会非常有帮助。以下是一些常见的矩阵算法题解题思想: 动态规划…

stm32——定时器级联

在STM32当中扩展定时范围:单个定时器的定时长度可能无法满足某些应用的需求。通过级联,可以实现更长时间的定时;提高定时精度:能够在长定时的基础上,通过合理配置,实现更精细的定时控制;处理复杂…

Postman工具基本使用

一、安装及基本使用 安装及基本使用参见外网文档:全网最全的 postman 工具使用教程_postman使用-CSDN博客 建议版本:11以下,比如10.x.x版本。11版本以后貌似是必须登录使用 二、禁止更新 彻底禁止postman更新 - 简书 host增加&#xff1…

vector与list的简单介绍

1. 标准库中的vector类的介绍: vector是表示大小可以变化的数组的序列容器。 就像数组一样,vector对其元素使用连续的存储位置,这意味着也可以使用指向其元素的常规指针上的偏移量来访问其元素,并且与数组中的元素一样高效。但与数…

【计算机毕业设计】026基于微信小程序的原创音乐

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

记录OSPF配置,建立邻居失败的过程

1.配置完ospf后,在路由表中不出现ospf相关信息 [SW2]ospf [SW2-ospf-1]are [SW2-ospf-1]area 0 [SW2-ospf-1-area-0.0.0.0]net [SW2-ospf-1-area-0.0.0.0]network 0.0.0.0 Jul 4 2024 22:11:58-08:00 SW2 DS/4/DATASYNC_CFGCHANGE:OID 1.3.6.1.4.1.2011.5.25 .1…

《昇思25天学习打卡营第9天|保存与加载》

文章目录 今日所学:一、构建与准备二、保存和加载模型权重三、保存和加载MindIR总结 今日所学: 在上一章节主要学习了如何调整超参数以进行网络模型训练。在这一过程中,我们通常会想要保存一些中间或最终的结果,以便进行后续的模…

《米小圈日记魔法》边看边学,轻松掌握写日记的魔法!

在当今充满数字化娱乐和信息快速变迁的时代,如何创新引导孩子们学习,特别是如何培养他们的写作能力,一直是家长和教育者们关注的焦点。今天就向大家推荐一部寓教于乐的动画片《米小圈日记魔法》,该系列动画通过其独特的故事情节和…

Elasticsearch 使用误区之二——频繁更新文档

在使用 Elasticsearch 时,频繁更新文档是一种常见误区。这不仅影响性能,还可能导致系统资源的浪费。 理解 Elasticsearch 的文档更新机制对于优化性能至关重要。 关于 Elasticsearch 更新操作,常见问题如下: ——https://t.zsxq.c…

word 转pdf 中图片不被压缩的方法

word 转pdf 中图片不被压缩的方法 法1: 调节word 选项中的图片格式为不压缩、高保真 法2: 1: word 中的图片尽可能使用高的分辨率,图片存为pnd或者 tif 格式(最高清) 2: 转化为pdf使用打印机器,参数如下…

Java面试题-锁

整体关于锁知识总结 下面是放大版: 补充:锁的粒度 忘记说了全局锁 : 1, 全局锁 flush tables with read lock ; // 对整个数据库上锁; 2, unlock tables; // 释放锁 但是我们一般不用;只有在数…

React@16.x(48)路由v5.x(13)源码(5)- 实现 Switch

目录 1&#xff0c;原生 Switch 的渲染内容2&#xff0c;实现 1&#xff0c;原生 Switch 的渲染内容 对如下代码来说&#xff1a; import { BrowserRouter as Router, Route, Switch } from "react-router-dom"; function News() {return <div className"p…

【Linux进阶】文件和目录的默认权限与隐藏权限

1.文件默认权限&#xff1a;umask OK&#xff0c;那么现在我们知道如何建立或是改变一个目录或文件的属性了&#xff0c;不过&#xff0c;你知道当你建立一个新的文件或目录时&#xff0c;它的默认权限会是什么吗&#xff1f; 呵呵&#xff0c;那就与umask这个玩意儿有关了&…

MFC+MySQL应用:配置

MFCMySQL 1. MFC UI界面生成2. 数据库和表生成创建数据库创建表添加表数据 3. VS中配置MySQL环境 1. MFC UI界面生成 链接: MFC使用方法 可以根据用户自身需求生成单文档、对话框等不同样式的UI界面。 2. 数据库和表生成 可以在workbench或者MySQL Server中创建数据库和表。…

SSM学生资助管理系统-计算机毕业设计源码30825

目 录 摘 要 1 绪论 1.1 研究背景 1.2研究意义 1.3论文结构与章节安排 2 学生资助管理系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统用例分析 2.4 …

传统数据处理系统存在的问题

传统应用的数据系统架构设计时&#xff0c;应用直接访问数据库系统。当用户访问量增加时&#xff0c;数据库无法支撑日益增长的用户请求的负载&#xff0c;从而导致数据库服务器无法及时响应用户请求&#xff0c;出现超时的错误。 出现这种情况以后&#xff0c;在系统架构上就采…

【Ubuntu24.04无显示器远控】【Todesk远程桌面黑屏】【Linux虚拟显示器】解决方案

1️⃣版本 Ubuntu 24.04Todesk 4.7.2.0xserver-xorg-video-dummy 1:0.4.0-1build1 2️⃣安装配置虚拟显示器 sudo apt install xserver-xorg-video-dummy编辑/etc/gdm3/custom.conf&#xff0c;关闭Ubuntu24.04Wayland切换为X11 WaylandEnablefalse /usr/share/X11/xorg.con…