您好，在运行这段代码时，由于 `rnn_cell.linear` 已被移除，我手动添加了一个 `linear()` 函数，但运行时却报错了……代码如下：
def linear(input, output_size, scope=None):
    """Apply an affine map ``input @ Matrix^T + Bias`` to a 2-D tensor.

    Replacement for the ``linear`` helper that was removed from
    TensorFlow's ``rnn_cell`` module.

    Args:
        input: 2-D tensor (or value convertible to one) whose second
            dimension is statically known. The name shadows the builtin
            ``input`` but is kept so existing keyword callers still work.
        output_size: Number of output features.
        scope: Variable scope name; defaults to ``"SimpleLinear"``.

    Returns:
        A 2-D tensor whose first dimension matches ``input`` and whose
        second dimension is ``output_size``.

    Raises:
        ValueError: If ``input`` is not 2-D, or its second dimension is
            unknown (``None``) or zero.
    """
    # Bug fix: the original body referenced an undefined name ``input_``
    # (NameError on first call). Convert once and use the tensor everywhere.
    input_tensor = tf.convert_to_tensor(input)
    shape = input_tensor.get_shape().as_list()
    if len(shape) != 2:
        raise ValueError("Linear is expecting 2D arguments: %s" % str(shape))
    if not shape[1]:
        raise ValueError("Linear expects shape[1] of arguments: %s" % str(shape))
    input_size = shape[1]

    # NOTE(review): tf.variable_scope / tf.get_variable are TF1 APIs; under
    # TF2 these live in tf.compat.v1 — confirm which TF version this runs on.
    with tf.variable_scope(scope or "SimpleLinear"):
        # The weight is stored as [output_size, input_size] and transposed at
        # use time. Keeping this layout and the "Matrix"/"Bias" names
        # preserves compatibility with variables/checkpoints created by
        # existing code.
        matrix = tf.get_variable(
            "Matrix", [output_size, input_size], dtype=input_tensor.dtype
        )
        bias_term = tf.get_variable(
            "Bias", [output_size], dtype=input_tensor.dtype
        )
    return tf.matmul(input_tensor, tf.transpose(matrix)) + bias_term