
Comments (5)

26hzhang commented on June 1, 2024

@zsgchinese Could you refer me to a related paper for concat attention? I am not sure which one you are talking about.

zsgchinese commented on June 1, 2024

@IsaacChanghau http://www.emnlp2015.org/proceedings/EMNLP/pdf/EMNLP166.pdf
In Section 3.1 (Global Attention), when computing the score, you used the dot attention; I don't know how to write the concat attention.

26hzhang commented on June 1, 2024

@zsgchinese Hi, I was on a long vacation and just came back. I will read the paper and try to write it for your reference.

26hzhang commented on June 1, 2024

@zsgchinese Here are my thoughts on your request.

The dot-attention I used is inspired by “Attention Is All You Need” (ref. https://arxiv.org/pdf/1706.03762.pdf). I think it is different from the dot method described in the paper you shared with me, and I also think the seq2seq model of that paper is not suitable for the sequence labeling task here, since the machine translation task has two different inputs during training, source language sentences (for encoding) and target language sentences (for decoding), while the sequence labeling task only accepts a single input.
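
To make the difference concrete, here is a minimal sketch (my own illustration, not code from this repo) contrasting the scaled dot-product self-attention of “Attention Is All You Need” with the dot score of Luong et al.; the tensor shapes are assumptions for illustration:

import tensorflow as tf

def scaled_dot_product_attention(q, k, v):
    # self-attention as in "Attention Is All You Need":
    # q, k, v: (batch_size, max_time, dim); every position attends over all positions of the same sequence
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d_k)  # (batch_size, max_time, max_time)
    weights = tf.nn.softmax(scores)  # softmax over the last axis
    return tf.matmul(weights, v)  # (batch_size, max_time, dim)

def luong_dot_score(h_t, memory):
    # Luong et al.'s "dot" score: score(h_t, h_s) = h_t^T h_s, one decoder state against all source states
    # h_t: (batch_size, dim), memory: (batch_size, max_time, dim)
    return tf.reduce_sum(tf.expand_dims(h_t, 1) * memory, axis=-1)  # (batch_size, max_time)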

But, considering only your request, I think there are two ways:

  1. If you want to write code similar to the paper “Effective Approaches to Attention-based Neural Machine Translation” you shared with me, which is an encoder-decoder (seq2seq) model, TensorFlow provides a comprehensive wrapper for it; you can follow its tutorial: https://www.tensorflow.org/tutorials/seq2seq (a small sketch with these wrappers is given right after this list).

  2. If not, I assume you want to use a mechanism similar to the one described in “Effective Approaches to Attention-based Neural Machine Translation” to compute each hidden output of the dynamic RNN. In this case, for each time step, you need to compute the alignment weights against the source hidden states, so I am thinking we can create an attention cell in TensorFlow to tackle this issue (not sure if this idea fits your requirement, but I tested that it works). See the attention cell below.
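
Regarding option 1, here is a minimal sketch (my own illustration, with hypothetical encoder_outputs and source_lengths tensors) of TensorFlow's built-in attention wrappers; BahdanauAttention uses the additive (concat-style) score, while LuongAttention uses the multiplicative one:

import tensorflow as tf

# encoder_outputs: (batch_size, max_time, dim), source_lengths: (batch_size,) -- assumed to already exist
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(  # additive / concat-style score
    num_units=128, memory=encoder_outputs, memory_sequence_length=source_lengths)
decoder_cell = tf.contrib.rnn.LSTMCell(128)
# wrap the decoder cell so that each decoding step attends over the encoder outputs
att_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism, attention_layer_size=128)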

The attention cell is built as:

import tensorflow as tf
from tensorflow.python.ops.rnn_cell import LSTMCell, GRUCell, RNNCell
from model.nns import dense


class AttentionCell(RNNCell):
    """A time-major Attention based RNN cell
    ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py"""
    def __init__(self, num_units, memory, cell_type='lstm'):
        """
        :param num_units: number of hidden units in attention cell
        :param memory: all the source hidden state, shape = (max_time, batch_size, dim)
        :param cell_type: rnn cell type
        """
        super(AttentionCell, self).__init__()
        self._cell = LSTMCell(num_units) if cell_type == 'lstm' else GRUCell(num_units)
        self.num_units = num_units
        self.memory = memory
        self.mem_units = memory.get_shape().as_list()[-1]

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def __call__(self, inputs, state, scope=None):
        c, m = state  # c is previous cell state, m is previous hidden state (this unpacking assumes an LSTM state tuple)
        # alignment: tanh(memory + W * m), the projected previous hidden state is broadcast over the time axis
        concat1 = tf.nn.tanh(tf.add(self.memory, dense(m, self.mem_units, use_bias=False, scope='concat')))
        # project each position to a scalar score and normalize over the time axis (softmax)
        alphas = tf.squeeze(tf.exp(dense(concat1, hidden=1, use_bias=False, scope='raw_alphas')), axis=[-1])
        alphas = tf.div(alphas, tf.reduce_sum(alphas, axis=0, keep_dims=True))  # (max_time, batch_size)
        # weighted context vector: sum of the memory weighted by the alignment weights
        w_context = tf.reduce_sum(tf.multiply(self.memory, tf.expand_dims(alphas, axis=-1)), axis=0)
        h, new_state = self._cell(inputs, state)
        # combine the context with the current hidden output, similar to Luong et al.'s attentional hidden state
        concat2 = tf.concat([w_context, h], axis=-1)
        output = tf.nn.tanh(dense(concat2, self.num_units, use_bias=False, scope='dense'))
        return output, new_state

To use the attention cell, assume you have an output from an RNN layer, named “context”:
The shape of “context” is (batch_size, max_time, num_units). Since the attention cell I built is time-major, you need to transpose “context” first.

# ……
context = tf.transpose(context, [1, 0, 2])  # (max_time, batch_size, num_units)
att_cell = AttentionCell(num_units, context, cell_type='lstm')  # create attention cell
# using dynamic rnn to compute output
att_output, _ = tf.nn.dynamic_rnn(att_cell, context, sequence_length=self.seq_len, dtype=tf.float32, time_major=True)
# transpose att_output back to batch-major
att_output = tf.transpose(att_output, [1, 0, 2])  # shape = (batch_size, max_time, num_units)
# ……

Then you can obtain the attentive RNN outputs and use them for further processing.
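
For example, here is a minimal sketch (my own addition, not from the repo) of projecting the attentive outputs to per-token label scores with a CRF layer on top, which is a common setup for sequence labeling; num_labels and self.labels are hypothetical:

# project attentive outputs to per-token label scores
logits = tf.layers.dense(att_output, num_labels)  # (batch_size, max_time, num_labels)
# CRF log-likelihood over the gold label sequences (self.labels: (batch_size, max_time), int32)
log_likelihood, trans_params = tf.contrib.crf.crf_log_likelihood(logits, self.labels, self.seq_len)
loss = tf.reduce_mean(-log_likelihood)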

Thanks.

Also, I do not have any GPUs at hand currently, so I have only tested that it compiles; I have not really trained the model comprehensively.

zsgchinese commented on June 1, 2024

@IsaacChanghau First of all, really thanks for your detailed reply. I have read it carefully, and I will say something about your answer based on my understanding.
concat1 = tf.nn.tanh(tf.add(self.memory, dense(m, self.mem_units, use_bias=False, scope='concat')))
The shape of memory (context) is (batch_size, max_time, num_units), and the shape of m is (batch_size, num_units).
The function __call__() should be called at every step of dynamic_rnn (that is, the decoder).
But is concat1 the score (as in Section 3.1, Global Attention, of the paper I referred to)? If it is, I cannot tell which method you used among the three score functions (dot, general, concat).
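
For reference, here is a minimal sketch (my own illustration, not from the thread) of the three Luong et al. score functions mentioned above, with h_t a decoder state and memory the source hidden states; the weight layers are hypothetical learned parameters:

import tensorflow as tf

def luong_score(h_t, memory, method='concat'):
    # h_t: (batch_size, dim), memory: (batch_size, max_time, dim); returns scores of shape (batch_size, max_time)
    dim = memory.get_shape().as_list()[-1]
    if method == 'dot':        # score = h_t^T h_s
        return tf.reduce_sum(tf.expand_dims(h_t, 1) * memory, axis=-1)
    elif method == 'general':  # score = h_t^T W_a h_s
        return tf.reduce_sum(tf.expand_dims(tf.layers.dense(h_t, dim, use_bias=False), 1) * memory, axis=-1)
    else:                      # 'concat': score = v_a^T tanh(W_a [h_t; h_s]), written as W_1 h_t + W_2 h_s
        hidden = tf.nn.tanh(tf.layers.dense(memory, dim, use_bias=False)
                            + tf.expand_dims(tf.layers.dense(h_t, dim, use_bias=False), 1))
        return tf.squeeze(tf.layers.dense(hidden, 1, use_bias=False), axis=-1)

The concat branch has the same tanh-then-project structure as the concat1/alphas computation in the AttentionCell above, though that cell applies the learned projection only to m, not to the memory.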
