import tensorflow as tf
import texar.tf as tx

# Hyperparameters for a single-layer LSTM cell: 256 units, forget bias 0,
# dropout disabled (keep_prob = 1).
hp = {
    'type': 'LSTMCell',
    'kwargs': {
        'num_units': 256,
        'forget_bias': 0.
    },
    'dropout': {'output_keep_prob': 1},
    'num_layers': 1
}
encoder = tx.modules.UnidirectionalRNNEncoder(hparams={"rnn_cell": hp})

# Dummy batch: 32 sequences, 50 time steps, 256 features.
inputs = tf.zeros([32, 50, 256])

# BUG FIX: the original called tf.convert_to_tensor(...) without assigning the
# result, so the encoder still received the plain Python list. Assign the
# converted int32 tensor to sequence_length.
sequence_length = tf.convert_to_tensor(
    [26, 17, 13, 25, 36, 29, 25, 34, 11, 17, 10,
     22, 23, 24, 33, 18, 21, 17, 22, 20, 34, 22,
     40, 50, 19, 18, 14, 22, 14, 34, 22, 28],
    dtype=tf.int32)

_, states = encoder(inputs, sequence_length)

initializer = tf.global_variables_initializer()  # fixed typo: "initilizer"
with tf.Session() as sess:
    sess.run(initializer)
    # Fetch both state tensors in a single run() so they come from the same
    # graph execution, instead of two independent runs as before.
    state_0, state_1 = sess.run([states[0], states[1]])
    print("states[0]", state_0)
    print("states[1]", state_1)
import torch
import texar.torch as tx

# Hyperparameters for a single-layer LSTM cell: 256 units, forget bias 0,
# dropout disabled (keep_prob = 1). Mirrors the TF configuration above.
hp = {
    'type': 'LSTMCell',
    'kwargs': {
        'num_units': 256,
        'forget_bias': 0.
    },
    'dropout': {'output_keep_prob': 1},
    'num_layers': 1
}
# texar.torch requires input_size explicitly (no shape inference at build time).
encoder = tx.modules.UnidirectionalRNNEncoder(input_size=256,
                                              hparams={"rnn_cell": hp})

# Dummy batch: 32 sequences, 50 time steps, 256 features.
inputs = torch.zeros([32, 50, 256])

# IDIOM FIX: build the int32 tensor directly with torch.tensor(..., dtype=...)
# instead of torch.Tensor([...]).to(torch.int32), which first allocates a
# float32 tensor and then casts it.
sequence_length = torch.tensor(
    [26, 17, 13, 25, 36, 29, 25, 34, 11, 17, 10,
     22, 23, 24, 33, 18, 21, 17, 22, 20, 34, 22,
     40, 50, 19, 18, 14, 22, 14, 34, 22, 28],
    dtype=torch.int32)

_, states = encoder(inputs, sequence_length)
print("states[0]", states[0])
print("states[1]", states[1])
states[0] [[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]]
states[1] [[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]]
states[0] tensor([[ 0.0346, 0.0295, -0.0203, ..., -0.0030, -0.0020, -0.0020],
[ 0.0345, 0.0294, -0.0203, ..., -0.0030, -0.0020, -0.0020],
[ 0.0345, 0.0294, -0.0202, ..., -0.0030, -0.0021, -0.0020],
...,
[ 0.0346, 0.0295, -0.0203, ..., -0.0030, -0.0020, -0.0020],
[ 0.0346, 0.0295, -0.0203, ..., -0.0030, -0.0020, -0.0020],
[ 0.0346, 0.0295, -0.0203, ..., -0.0030, -0.0020, -0.0020]],
grad_fn=<StackBackward>)
states[1] tensor([[ 0.0679, 0.0591, -0.0420, ..., -0.0061, -0.0039, -0.0040],
[ 0.0679, 0.0591, -0.0420, ..., -0.0061, -0.0039, -0.0040],
[ 0.0679, 0.0591, -0.0419, ..., -0.0061, -0.0042, -0.0041],
...,
[ 0.0679, 0.0591, -0.0420, ..., -0.0061, -0.0039, -0.0040],
[ 0.0679, 0.0591, -0.0420, ..., -0.0061, -0.0039, -0.0040],
[ 0.0679, 0.0591, -0.0420, ..., -0.0061, -0.0039, -0.0040]],