Comments (5)
@zsgchinese Could you point me to the paper that describes the concat attention? I am not sure which one you are talking about.
from neural_sequence_labeling.
@IsaacChanghau http://www.emnlp2015.org/proceedings/EMNLP/pdf/EMNLP166.pdf
In Section 3.1 (Global Attention): to calculate the score you used the dot attention, and I don't know how to write the concat attention.
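For reference, Section 3.1 of the linked paper (Luong et al., 2015) defines three score functions between the current target state h_t and a source state h_s: dot, h_t^T h_s; general, h_t^T W_a h_s; and concat, v_a^T tanh(W_a [h_t; h_s]). The scores are then normalized with a softmax over source positions, and the context vector is the weighted sum of the source states. Below is a minimal NumPy sketch of all three for a single target state. The function names and shapes are illustrative choices (not the repository's code), and random matrices stand in for the learned parameters W_a and v_a.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def luong_score(h_t, h_s, method, W_a=None, v_a=None):
    """Unnormalized alignment scores of one target state against all source states.

    h_t: (d,) current target hidden state
    h_s: (T, d) source hidden states, one row per source position
    """
    if method == "dot":
        # h_t^T h_s
        return h_s @ h_t
    if method == "general":
        # h_t^T W_a h_s, with W_a of shape (d, d)
        return h_s @ (W_a.T @ h_t)
    if method == "concat":
        # v_a^T tanh(W_a [h_t; h_s]), with W_a of shape (k, 2d) and v_a of shape (k,)
        pairs = np.concatenate([np.tile(h_t, (h_s.shape[0], 1)), h_s], axis=1)  # (T, 2d)
        return np.tanh(pairs @ W_a.T) @ v_a
    raise ValueError("unknown score method: " + method)


# Usage with random stand-ins for the learned parameters.
rng = np.random.default_rng(0)
d, T, k = 4, 5, 3
h_t = rng.standard_normal(d)
h_s = rng.standard_normal((T, d))
W_a = rng.standard_normal((k, 2 * d))
v_a = rng.standard_normal(k)
a_t = softmax(luong_score(h_t, h_s, "concat", W_a=W_a, v_a=v_a))  # alignment weights, (T,)
c_t = a_t @ h_s  # global context vector, (d,)
```

The concat branch builds the stacked vector [h_t; h_s] explicitly for clarity. Splitting W_a into two blocks gives the same result with two smaller matrix products.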
@zsgchinese Hi, I was on a long vacation and just came back. I will read the paper and try to write it for your reference.
@zsgchinese Here are my thoughts on your request.
The dot attention I used is inspired by “Attention Is All You Need” (ref. https://arxiv.org/pdf/1706.03762.pdf). I think it is different from the dot method described in the paper you shared, and I also think that the seq2seq model in that paper is not suitable for the sequence labeling task here: a machine translation model takes two different inputs during training, source-language sentences (for encoding) and target-language sentences (for decoding), while a sequence labeling model accepts only a single input.
That said, considering only your request, I think there are two ways:
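For comparison, the dot attention of “Attention Is All You Need” is scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V. Used as self-attention, the queries, keys and values all come from the same sequence, which is why it works with a single input. A minimal NumPy sketch of the formula follows. It is not the repository's implementation, and the learned projections that normally produce Q, K and V are omitted for brevity.

```python
import numpy as np


def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V.

    Q: (n, d_k) queries, K: (m, d_k) keys, V: (m, d_v) values.
    Returns the attended values (n, d_v) and the attention weights (n, m).
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                          # (n, m)
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over the keys
    return weights @ V, weights


# Self-attention over one sentence: queries, keys and values are the same
# hidden states H (projection matrices omitted in this sketch).
H = np.random.default_rng(0).standard_normal((6, 8))  # (max_time, num_units)
out, w = scaled_dot_product_attention(H, H, H)
```

Unlike Luong's dot score, which compares a decoder state with encoder states, every position here attends over the positions of its own sentence.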
- If you want to write code similar to the paper “Effective Approaches to Attention-based Neural Machine Translation” you shared, which is an encoder-decoder (seq2seq) model: TensorFlow provides a comprehensive wrapper for it, and you can follow its tutorial: https://www.tensorflow.org/tutorials/seq2seq
- If not, I assume you want to use a mechanism similar to the one described in “Effective Approaches to Attention-based Neural Machine Translation” to compute each hidden output of the dynamic RNN. In this case, at each time step you need to compute the alignment weights over the source hidden states, so I think we can create an attention cell in TensorFlow to handle this. (I am not sure this idea fits your requirement, but I tested that it works.) See below:
The attention cell is built as:
```python
import tensorflow as tf
from tensorflow.python.ops.rnn_cell import LSTMCell, GRUCell, RNNCell
from model.nns import dense


class AttentionCell(RNNCell):
    """A time-major attention-based RNN cell.
    ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py"""

    def __init__(self, num_units, memory, cell_type='lstm'):
        """
        :param num_units: number of hidden units in the attention cell
        :param memory: all the source hidden states, shape = (max_time, batch_size, dim)
        :param cell_type: rnn cell type, 'lstm' or 'gru'
        """
        super(AttentionCell, self).__init__()
        self._cell = LSTMCell(num_units) if cell_type == 'lstm' else GRUCell(num_units)
        self.num_units = num_units
        self.memory = memory
        self.mem_units = memory.get_shape().as_list()[-1]

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def __call__(self, inputs, state, scope=None):
        # previous hidden state m: an LSTM state is a (c, h) tuple, a GRU state is h itself
        m = state.h if isinstance(state, tf.nn.rnn_cell.LSTMStateTuple) else state
        # tanh(memory + W m), with W m broadcast over max_time
        concat1 = tf.nn.tanh(tf.add(self.memory, dense(m, self.mem_units, use_bias=False, scope='concat')))
        # exp of the raw scores, one per source position: (max_time, batch_size)
        alphas = tf.squeeze(tf.exp(dense(concat1, hidden=1, use_bias=False, scope='raw_alphas')), axis=[-1])
        # normalize over time, i.e. a softmax along axis 0
        alphas = tf.div(alphas, tf.reduce_sum(alphas, axis=0, keep_dims=True))  # (max_time, batch_size)
        # attention-weighted sum of the source states: (batch_size, dim)
        w_context = tf.reduce_sum(tf.multiply(self.memory, tf.expand_dims(alphas, axis=-1)), axis=0)
        # run the wrapped cell, then combine its output with the context vector
        h, new_state = self._cell(inputs, state)
        concat2 = tf.concat([w_context, h], axis=-1)
        output = tf.nn.tanh(dense(concat2, self.num_units, use_bias=False, scope='dense'))
        return output, new_state
```
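To make the shapes in the attention cell easier to follow, here is a NumPy mirror of the attention math for one time step. It assumes that `dense` from `model.nns` is a plain bias-free linear projection of the last axis, and the matrices `W`, `v` and `W_c` are hypothetical stand-ins for the `concat`, `raw_alphas` and `dense` projections. It is a sketch for checking shapes, not a replacement for the TensorFlow cell.

```python
import numpy as np


def attention_step(memory, m, h, W, v, W_c):
    """Mirror of the attention math in AttentionCell.__call__ for one time step.

    memory: (max_time, batch, mem_units)        time-major source states
    m:      (batch, num_units)                  previous hidden state of the wrapped cell
    h:      (batch, num_units)                  new output of the wrapped cell
    W:      (num_units, mem_units)              stands in for the 'concat' projection
    v:      (mem_units,)                        stands in for the 'raw_alphas' projection
    W_c:    (mem_units + num_units, num_units)  stands in for the 'dense' projection
    """
    concat1 = np.tanh(memory + m @ W)                     # (max_time, batch, mem_units), broadcast over time
    raw = np.exp(concat1 @ v)                             # (max_time, batch)
    alphas = raw / raw.sum(axis=0, keepdims=True)         # normalize over time
    w_context = (memory * alphas[..., None]).sum(axis=0)  # (batch, mem_units)
    concat2 = np.concatenate([w_context, h], axis=-1)     # (batch, mem_units + num_units)
    return np.tanh(concat2 @ W_c), alphas                 # (batch, num_units), (max_time, batch)


# Usage with random stand-ins; here mem_units == num_units == units.
rng = np.random.default_rng(0)
max_time, batch, units = 6, 2, 4
memory = rng.standard_normal((max_time, batch, units))
m = rng.standard_normal((batch, units))
h = rng.standard_normal((batch, units))
W = rng.standard_normal((units, units))
v = rng.standard_normal(units)
W_c = rng.standard_normal((2 * units, units))
output, alphas = attention_step(memory, m, h, W, v, W_c)
```

Like the TensorFlow code, this exponentiates the raw scores without subtracting the maximum first. That is mathematically a softmax over time but can overflow for large scores.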
To use the attention cell, assume you have an output from an RNN layer named “context”.
The shape of “context” is (batch_size, max_time, num_units). Since the attention cell I built is time-major, you need to transpose “context” first:
```python
# ……
context = tf.transpose(context, [1, 0, 2])  # (max_time, batch_size, num_units)
att_cell = AttentionCell(num_units, context, cell_type='lstm')  # create attention cell
# using dynamic rnn to compute output
att_output, _ = tf.nn.dynamic_rnn(att_cell, context, sequence_length=self.seq_len, dtype=tf.float32, time_major=True)
# transpose att_output back to batch-major
att_output = tf.transpose(att_output, [1, 0, 2])  # shape = (batch_size, max_time, num_units)
# ……
```
This gives you the attentive RNN outputs, which you can use for further processing.
Thanks.
Also, I do not have any GPUs at hand right now, so I only tested that it compiles; I have not actually trained the model.
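One caveat that seems worth checking (an observation about the code above, not something tested in this thread): `sequence_length=self.seq_len` makes `dynamic_rnn` stop updating the state and emit zeros after each sequence ends, but the `memory` the cell attends over is still the full padded `context`. As a result, the normalization over axis 0 in `__call__` also gives some weight to padded time steps. Below is a minimal NumPy sketch of a masked normalization over time. `masked_time_softmax` and `lengths` are hypothetical names. In TensorFlow the mask would presumably come from `tf.sequence_mask`, transposed to time-major.

```python
import numpy as np


def masked_time_softmax(scores, lengths):
    """Softmax over the time axis that ignores padded steps.

    scores:  (max_time, batch) raw attention scores, time-major like AttentionCell
    lengths: (batch,) true sequence lengths, each assumed to be >= 1
    """
    max_time = scores.shape[0]
    mask = np.arange(max_time)[:, None] < lengths[None, :]  # (max_time, batch), True = real step
    scores = np.where(mask, scores, -np.inf)                # padded steps get -inf
    scores = scores - scores.max(axis=0, keepdims=True)    # stability; max is finite since lengths >= 1
    weights = np.exp(scores)                                # exp(-inf) == 0 for padded steps
    return weights / weights.sum(axis=0, keepdims=True)


# Two sequences padded to max_time = 3; the second one has only 1 real step.
scores = np.array([[1.0, 2.0],
                   [0.5, 0.0],
                   [3.0, 1.0]])
weights = masked_time_softmax(scores, np.array([3, 1]))
```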
@IsaacChanghau First of all, thank you very much for your detailed reply. I have read it carefully, and here is my understanding of your answer.
```python
concat1 = tf.nn.tanh(tf.add(self.memory, dense(m, self.mem_units, use_bias=False, scope='concat')))
```
The shape of memory (context) is (batch_size, max_time, num_units), and the shape of m is (batch_size, num_units).
The function __call__() should be called at every step of dynamic_rnn (that is, the decoder).
But is concat1 the score (as in Section 3.1, Global Attention, of the paper I referred to)? If it is, I can't find which of the three functions (dot, general, concat) you use.
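Reading the quoted line against Section 3.1 may help here. Assume `dense` from `model.nns` is a bias-free linear map, and let W, v and W_c stand for the weights of the `concat`, `raw_alphas` and `dense` projections. Under that assumption, concat1 is the tanh term inside the score. The score itself has the concat form, with the source-side block of W_a fixed to the identity and the previous hidden state m used in place of h_t. The cell's output has the same form as the paper's attentional hidden state tanh(W_c[c_t; h_t]). This is a reading of the code, not something stated in the thread. Also note that, per the cell's docstring, `memory` is the transposed, time-major context with shape (max_time, batch_size, dim).

```latex
\begin{aligned}
\operatorname{score}_{\text{concat}}(h_t, \bar{h}_s) &= v_a^\top \tanh\left(W_a [h_t; \bar{h}_s]\right) = v_a^\top \tanh\left(W_1 h_t + W_2 \bar{h}_s\right), \qquad W_a = [W_1 \;\; W_2] \\
\operatorname{score}_{\text{cell}}(m, \bar{h}_s) &= v^\top \tanh\left(W m + \bar{h}_s\right) \qquad (W_1 = W,\; W_2 = I,\; h_t \to m) \\
\alpha_s &= \frac{\exp\left(\operatorname{score}_{\text{cell}}(m, \bar{h}_s)\right)}{\sum_{s'} \exp\left(\operatorname{score}_{\text{cell}}(m, \bar{h}_{s'})\right)} \\
c &= \sum_s \alpha_s \bar{h}_s, \qquad \text{output} = \tanh\left(W_c [c; h]\right)
\end{aligned}
```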