GithubHelp home page GithubHelp logo

henyaoyuan / pytorch-attention Goto Github PK

View Code? Open in Web Editor NEW

This project forked from thomlake/pytorch-attention

0.0 1.0 0.0 123 KB

pytorch neural network attention mechanism

License: BSD 2-Clause "Simplified" License

Python 100.00%

pytorch-attention's Introduction

def attend(query, context, value=None, score='dot', normalize='softmax',
           context_sizes=None, context_mask=None, return_weight=False
           ):
    """Attend to value (or context) by scoring each query and context.

    Args
    ----
    query: Variable of size (B, M, D1)
        Batch of M query vectors.
    context: Variable of size (B, N, D2)
        Batch of N context vectors.
    value: Variable of size (B, N, P), default=None
        If given, the output vectors will be weighted
        combinations of the value vectors.
        Otherwise, the context vectors will be used.
    score: str or callable, default='dot'
        If score == 'dot', scores are computed
        as the dot product between context and
        query vectors. This Requires D1 == D2.
        Otherwise, score should be a callable:
             query    context     score
            (B,M,D1) (B,N,D2) -> (B,M,N)
    normalize: str, default='softmax'
        One of 'softmax', 'sigmoid', or 'identity'.
        Name of function used to map scores to weights.
    context_mask: Tensor of (B, M, N), default=None
        A Tensor used to mask context. Masked
        and unmasked entries should be filled 
        appropriately for the normalization function.
    context_sizes: list[int], default=None,
        List giving the size of context for each item
        in the batch and used to compute a context_mask.
        If context_mask or context_sizes are not given,
        context is assumed to have fixed size.
    return_weight: bool, default=False
        If True, return the attention weight Tensor.

    Returns
    -------
    output: Variable of size (B, M, P)
        If return_weight is False.
    weight, output: Variable of size (B, M, N), Variable of size (B, M, P)
        If return_weight is True.
    """

Install

python setup.py install

Test

python -m pytest

Tested with pytorch 1.0.0

About

Attention is used to focus processing on a particular region of input. The attend function provided by this package implements the most common attention mechanism [1, 2, 3, 4], which produces an output by taking a weighted combination of value vectors with weights from a scoring function operating over pairs of query and context vectors.

Given query vector q, context vectors c_1,...,c_n, and value vectors v_1,...,v_n the attention score of q with c_i is given by

    s_i = f(q, c_i)

Frequently f takes the form of a dot product between query and context vectors.

    s_i = q^T c_i

The scores are passed through a normalization functions g (normally the softmax function).

    w_i = g(s_1,...,s_n)_i

Finally, the output is computed as a weighted sum of the value vectors.

    z = \sum_{i=1}^n w_i * v_i

In many applications [1, 4, 5] attention is applied to the context vectors themselves, v_i = c_i.

Sizes

This attend function provided by this package accepts batches of size B containing M query vectors of dimension D1, N context vectors of dimension D2, and optionally N value vectors of dimension P.

Variable Length

If the number of context vectors varies within a batch, a context can be ignored by forcing the corresponding weight to be zero.

In the case of the softmax, this can be achieved by adding negative infinity to the corresponding score before normalization. Similarly, for elementwise normalization functions the weights can be multiplied by an appropriate {0,1} mask after normalization.

To facilitate the above behavior, a context mask, with entries in {-inf, 0} or {0, 1} depending on the normalization function, can be passed to this function. The masks should have size (B, M, N).

Alternatively, a list can be passed giving the size of the context for each item in the batch. Appropriate masks will be created from these lists.

Note that the size of output does not depend on the number of context vectors. Because of this, context positions are truly unaccounted for in the output.

References

@article{bahdanau2014neural,
    title={Neural machine translation by jointly learning to align and translate},
    author={Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
    journal={arXiv preprint arXiv:1409.0473},
    year={2014}
}
@article{graves2014neural,
  title={Neural turing machines},
  author={Graves, Alex and Wayne, Greg and Danihelka, Ivo},
  journal={arXiv preprint arXiv:1410.5401},
  year={2014}
}
@inproceedings{sukhbaatar2015end,
    title={End-to-end memory networks},
    author={Sukhbaatar, Sainbayar and Weston, Jason and Fergus, Rob and others},
    booktitle={Advances in neural information processing systems},
    pages={2440--2448},
    year={2015}
}
@article{olah2016attention,
    title={Attention and augmented recurrent neural networks},
    author={Olah, Chris and Carter, Shan},
    journal={Distill},
    volume={1},
    number={9},
    pages={e1},
    year={2016}
}
@inproceedings{vinyals2015pointer,
    title={Pointer networks},
    author={Vinyals, Oriol and Fortunato, Meire and Jaitly, Navdeep},
    booktitle={Advances in Neural Information Processing Systems},
    pages={2692--2700},
    year={2015}
}

pytorch-attention's People

Contributors

thomlake avatar dngros avatar

Watchers

James Cloos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.