PyTorch Learning (8): Recurrent Layers - GRUCell Implementation
GRU stands for Gated Recurrent Unit; it is a variant of the LSTM. Let's start with a small GRUCell example:
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.GRUCell(10, 20)                  # input_size=10, hidden_size=20
input = Variable(torch.randn(6, 3, 10))   # sequence of 6 time steps, batch of 3
hx = Variable(torch.randn(3, 20))         # initial hidden state
output = []
for i in range(6):
    hx = rnn(input[i], hx)                # one call processes one time step
    output.append(hx)
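What a single GRUCell call computes can also be written out by hand. The sketch below follows the gate equations in the PyTorch documentation (reset gate r, update gate z, candidate state n) and reproduces one step using the cell's own weight_ih, weight_hh, bias_ih and bias_hh parameters; gru_cell_step is only an illustrative helper, not part of the library.

import torch
import torch.nn as nn
import torch.nn.functional as F

def gru_cell_step(cell, x, h):
    # project input and previous hidden state; each result is (batch, 3 * hidden_size)
    gi = F.linear(x, cell.weight_ih, cell.bias_ih)
    gh = F.linear(h, cell.weight_hh, cell.bias_hh)
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)      # reset gate
    z = torch.sigmoid(i_z + h_z)      # update gate
    n = torch.tanh(i_n + r * h_n)     # candidate hidden state
    return (1 - z) * n + z * h        # new hidden state h'

cell = nn.GRUCell(10, 20)
x = torch.randn(3, 10)
h = torch.randn(3, 20)
print(torch.allclose(gru_cell_step(cell, x, h), cell(x, h), atol=1e-6))  # expect True

The split into three chunks works because weight_ih and weight_hh store the reset, update, and new-state sub-matrices stacked in that order along the first dimension.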
In the line rnn = nn.GRUCell(10, 20), the GRUCell is initialized; GRUCell's __init__ method calls the __init__ method of the base class Module(object):
class Module(object):
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing to nest them in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()