Press "Enter" to skip to content

利用PyTorch使用LSTM

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

nn.

 

PyTorch LSTM API文档

 

 

输入数据格式:

 

[seq_len, batch, input_size]
[num_layers, batch, hidden_size]
[num_layers, batch, hidden_size]

 

输出数据格式:

 

[seq_len, batch, hidden_size]
[num_layers, batch, hidden_size]
[num_layers, batch, hidden_size]

 

接下来看个具体的例子

 

import torch
import torch.nn as nn
lstm = nn.LSTM(input_size=100, hidden_size=20, num_layers=4)
x = torch.randn(10, 3, 100) # 一个句子10个单词,送进去3条句子,每个单词用一个100维的vector表示
out, (h, c) = lstm(x)
print(out.shape, h.shape, c.shape)
# torch.Size([10, 3, 20]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])

 

nn.LSTMCell

 

PyTorch LSTMCell API文档

 

 

和RNNCell类似,输入input_size的shape是 [batch, input_size]$h_t$和$c_t$的shape是 [batch, hidden_size]

 

看个一层的LSTM的例子

 

import torch
import torch.nn as nn
cell = nn.LSTMCell(input_size=100, hidden_size=20) # one layer LSTM
h = torch.zeros(3, 20)
c = torch.zeros(3, 20)
x = torch.randn(10, 3, 100)
for xt in x:
    h, c = cell(xt, [h, c])
print(h.shape, c.shape) # torch.Size([3, 20]) torch.Size([3, 20])

 

两层的LSTM例子

 

import torch
import torch.nn as nn
cell1 = nn.LSTMCell(input_size=100, hidden_size=30)
cell2 = nn.LSTMCell(input_size=30, hidden_size=20)
h1 = torch.zeros(3, 30)
c1 = torch.zeros(3, 30)
h2 = torch.zeros(3, 20)
c2 = torch.zeros(3, 20)
x = torch.randn(10, 3, 100)
for xt in x:
    h1, c1 = cell1(xt, [h1, c1])
    h2, c2 = cell2(h1, [h2, c2])
print(h2.shape, c2.shape) # torch.Size([3, 20]) torch.Size([3, 20])

Be First to Comment

发表评论

电子邮件地址不会被公开。 必填项已用*标注