TensorFlow: Understanding tf.contrib.rnn.DropoutWrapper (Google has patented Dropout!) and MultiRNNCell
Contents
1. Understanding tf.contrib.rnn.DropoutWrapper
1.1 Source-code walkthrough
1.2 Example application
2. Understanding tf.contrib.rnn.MultiRNNCell
2.1 Source-code walkthrough
2.2 Example application
TensorFlow official API documentation: https://tensorflow.google.cn/api_docs
1. Understanding tf.contrib.rnn.DropoutWrapper
In machine learning, a model with too many parameters trained on too few samples easily overfits, a frequent problem when training neural networks. Overfitting shows up as a small loss and high accuracy on the training data, but a large loss and low accuracy on the test data.
Overfitting is a real headache in model training. Dropout, proposed by Geoffrey Hinton in 2012, is very effective against it; a flood of Dropout variants followed, and the technique became a standard training trick among machine-learning researchers. Unexpectedly, Google applied for a patent on the technique, and the patent is now officially in force: it took effect on 2019-06-26 and expires on 2034-09-03.
Dropout means that each neural unit, every time data flows in, works normally with probability keep_prob and otherwise outputs 0. It is an effective regularization method that markedly reduces overfitting. When dropout is applied to an RNN, the recurrent part is exempt: when the state is passed from time t-1 to time t, no dropout is applied to the memory in between; dropout is applied only to the information passed between stacked cells within the same time step t. In other words, in an RNN dropout is used on the inputs, the outputs, between different recurrent layers, or in fully connected layers, and never inside the recurrent connections of a single layer.
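To make the keep_prob semantics concrete, here is a minimal sketch assuming TensorFlow 1.x (the tensor values are illustrative): each element is kept with probability keep_prob and scaled by 1/keep_prob, so the expected activation is unchanged.

import tensorflow as tf

x = tf.ones([4, 5])
# Each element survives with probability keep_prob and is scaled by
# 1/keep_prob; dropped elements become 0, so the expected sum is preserved.
y = tf.nn.dropout(x, keep_prob=0.5)

with tf.Session() as sess:
    print(sess.run(y))  # roughly half the entries are 0.0, the rest 2.0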
1.1 Source-code walkthrough
Operator adding dropout to inputs and outputs of the given cell.
tf.compat.v1.nn.rnn_cell.DropoutWrapper(
*args, **kwargs
)
Args:
- cell: an RNNCell; a projection to output_size is added to it.
- input_keep_prob: unit Tensor or float between 0 and 1, input keep probability; if it is constant and 1, no input dropout will be added.
- output_keep_prob: unit Tensor or float between 0 and 1, output keep probability; if it is constant and 1, no output dropout will be added.
- state_keep_prob: unit Tensor or float between 0 and 1, state keep probability; if it is constant and 1, no state dropout will be added. State dropout is performed on the outgoing states of the cell. Note the state components to which dropout is applied when state_keep_prob is in (0, 1) are also determined by the argument dropout_state_filter_visitor (e.g. by default dropout is never applied to the c component of an LSTMStateTuple).
- variational_recurrent: Python bool. If True, then the same dropout pattern is applied across all time steps per run call. If this parameter is set, input_size must be provided.
- input_size: (optional) (possibly nested tuple of) TensorShape objects containing the depth(s) of the input tensors expected to be passed in to the DropoutWrapper. Required and used iff variational_recurrent = True and input_keep_prob < 1.
- dtype: (optional) the dtype of the input, state, and output tensors. Required and used iff variational_recurrent = True.
- seed: (optional) integer, the randomness seed.
- dropout_state_filter_visitor: (optional), default: see below. A function that takes any hierarchical level of the state and returns a scalar or depth=1 structure of Python booleans describing which terms in the state should be dropped out. In addition, if the function returns True, dropout is applied across this sublevel; if the function returns False, dropout is not applied across this entire sublevel. Default behavior: perform dropout on all terms except the memory (c) state of LSTMCellState objects, and don't try to apply dropout to TensorArray objects:

def dropout_state_filter_visitor(s):
    if isinstance(s, LSTMCellState):
        # Never perform dropout on the c state.
        return LSTMCellState(c=False, h=True)
    elif isinstance(s, TensorArray):
        return False
    return True

- **kwargs: dict of keyword arguments for the base layer.
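As an illustration of dropout_state_filter_visitor (a sketch assuming TensorFlow 1.x; the h_only_filter name and the keep probability are made up for this example), the following wrapper applies state dropout to the hidden state h of an LSTM but never to its memory cell c, mirroring the documented default behavior:

import tensorflow as tf

def h_only_filter(s):
    # Hypothetical visitor: drop out only the h part of an LSTM state,
    # and leave the memory cell c untouched.
    if isinstance(s, tf.nn.rnn_cell.LSTMStateTuple):
        return tf.nn.rnn_cell.LSTMStateTuple(c=False, h=True)
    return True

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=64)
cell = tf.nn.rnn_cell.DropoutWrapper(
    cell,
    state_keep_prob=0.8,                        # dropout on the outgoing state
    dropout_state_filter_visitor=h_only_filter)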
Methods
get_initial_state(
    inputs=None, batch_size=None, dtype=None
)
zero_state(
    batch_size, dtype
)
1.2 Example application
Related article: TF之LSTM: Multi-class classification on the MNIST handwritten-digit dataset with a multi-layer LSTM
lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)  # define one LSTM cell; only hidden_size is needed, and it automatically matches the dimensions of the input X
lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)  # add a dropout layer; usually only output_keep_prob is set
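Putting the two lines above in context, here is a minimal end-to-end sketch assuming TensorFlow 1.x (the shapes and placeholder names are illustrative, not from the original article). Feeding keep_prob=0.5 during training and 1.0 at test time switches dropout on and off:

import tensorflow as tf
from tensorflow.contrib import rnn

time_steps, input_dim, hidden_size = 28, 28, 128
keep_prob = tf.placeholder(tf.float32, [])  # e.g. 0.5 for training, 1.0 for testing
X = tf.placeholder(tf.float32, [None, time_steps, input_dim])

lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)
lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)

# outputs: [batch, time_steps, hidden_size]; final_state: LSTMStateTuple(c, h)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X, dtype=tf.float32)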
2. Understanding tf.contrib.rnn.MultiRNNCell
2.1 Source-code walkthrough
RNN cell composed sequentially of multiple simple cells.
tf.compat.v1.nn.rnn_cell.MultiRNNCell(
cells, state_is_tuple=True
)
Args:
- cells: list of RNNCells that will be composed in this order.
- state_is_tuple: If True, accepted and returned states are n-tuples, where n = len(cells). If False, the states are all concatenated along the column axis. This latter behavior will soon be deprecated.
Methods
get_initial_state(
    inputs=None, batch_size=None, dtype=None
)
zero_state(
    batch_size, dtype
)
Return zero-filled state tensor(s).
Args:
- batch_size: int, float, or unit Tensor representing the batch size.
- dtype: the data type to use for the state.
Returns:
If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size, state_size] filled with zeros.
If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size, s] for each s in state_size.
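A quick sketch of what zero_state returns for a stacked cell (assuming TensorFlow 1.x; the layer sizes and batch size are illustrative): the initial state is a tuple with one entry per layer, and for LSTM layers each entry is an LSTMStateTuple of c and h with shape [batch_size, num_units]:

import tensorflow as tf
from tensorflow.contrib import rnn

stacked = rnn.MultiRNNCell([rnn.BasicLSTMCell(n) for n in [128, 64]], state_is_tuple=True)
init_state = stacked.zero_state(batch_size=32, dtype=tf.float32)
# init_state is a 2-tuple of LSTMStateTuple(c, h), with shapes
# [32, 128] for layer 1 and [32, 64] for layer 2.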
2.2 Example application
Related article: DL之LSTM: An overview of the LSTM paper (principles, key steps, RNN/LSTM/GRU comparison, single-layer vs. multi-layer LSTM) with a detailed application guide
from tensorflow.contrib.rnn import BasicLSTMCell, MultiRNNCell

num_units = [128, 64]                                     # hidden size of each layer
cells = [BasicLSTMCell(num_units=n) for n in num_units]   # one fresh cell per layer
stacked_rnn_cell = MultiRNNCell(cells)                    # stack into a 2-layer RNN cell
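A slightly fuller sketch combining both wrappers (assuming TensorFlow 1.x; the placeholder shapes and the make_cell helper are illustrative). Each layer gets its own freshly constructed cell, since a single cell object must not be reused across layers, and dropout is applied between the stacked layers:

import tensorflow as tf
from tensorflow.contrib.rnn import BasicLSTMCell, DropoutWrapper, MultiRNNCell

keep_prob = tf.placeholder(tf.float32, [])
X = tf.placeholder(tf.float32, [None, 28, 28])  # [batch, time, features]

def make_cell(n):
    # Hypothetical helper: one fresh LSTM cell per layer, with dropout on its output.
    return DropoutWrapper(BasicLSTMCell(num_units=n), output_keep_prob=keep_prob)

stacked_rnn_cell = MultiRNNCell([make_cell(n) for n in [128, 64]])
outputs, state = tf.nn.dynamic_rnn(stacked_rnn_cell, X, dtype=tf.float32)
# outputs: [batch, time, 64], the top layer's hidden state at every step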