
Probabilistic Attention

This software project accompanies the research paper Probabilistic Attention for Interactive Segmentation and its predecessor Probabilistic Transformers.

It contains an implementation of a PyTorch module for the probabilistic attention update proposed in the above papers.

Documentation

The module runs an update of the probabilistic version of attention based on a Mixture-of-Gaussians model.
It accepts the following parameters during a forward pass:

  • q: A tensor of queries with dims N, G, C, H
  • zeta: A tensor of keys (query/key Gaussian means) with dims N, G, C, H
  • alpha: A scalar (see the special case below) or tensor of query/key Gaussian precisions with dims N, G, C, H
  • mu: A tensor of value Gaussian means with dims N, G, Cv, H
  • beta: A scalar (see the special case below) or tensor of value Gaussian precisions with dims N, G, C, H
  • pi: A tensor of mixture component priors with dims N, G, H, H
  • v_init: A tensor of initial estimates for the values with dims N, G, Cv, H (optional)
  • v_fixed: A tensor of fixed values with dims N, G, (Cv+1), H (optional). The extra (last) channel is an indicator for the fixed-value locations
  • zeta_prior_precision: A tensor of precisions for the Gaussian prior over zeta with dims N, G, C, H (optional)
  • mu_prior_precision: A tensor of precisions for the Gaussian prior over mu with dims N, G, Cv, H (optional)
  • q_pos_emb: A tensor of query positional embeddings with dims C, H, H
  • zeta_pos_emb: A tensor of key positional embeddings with dims C, H, H
  • v_pos_emb: A tensor of value positional embeddings with dims Cv, H, H
  • nonzero_wts_mask: A boolean mask with dims H, H; weight-matrix entries are set to zero where the mask value is false

And returns the following output tensor (a usage sketch follows below):

  • Updated values with dims N, G, Cv, H if no position embedding is given (v_pos_emb=None), else N, G, 2*Cv, H
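
As a rough illustration of the expected shapes, here is a minimal sketch of a forward pass. Only the tensor shapes and argument order follow the documentation above; the class name ProbabilisticAttention, its constructor, and the normalization used for pi are assumptions, not part of the documented API.

```python
import torch
from probabilisticattention import ProbabilisticAttention  # class name is an assumption

# Hypothetical sizes: batch N, groups G, query/key channels C, value channels Cv, locations H
N, G, C, Cv, H = 2, 4, 16, 32, 64

attn = ProbabilisticAttention()  # constructor arguments omitted; see the module source

q = torch.randn(N, G, C, H)       # queries
zeta = torch.randn(N, G, C, H)    # keys (query/key Gaussian means)
alpha = torch.ones(N, G, C, H)    # query/key Gaussian precisions (or a scalar, see below)
mu = torch.randn(N, G, Cv, H)     # value Gaussian means
beta = torch.ones(N, G, C, H)     # value Gaussian precisions (or a scalar, see below)
pi = torch.softmax(torch.randn(N, G, H, H), dim=-1)  # mixture component priors (illustrative normalization)

out = attn(q, zeta, alpha, mu, beta, pi)
print(out.shape)  # expected: (N, G, Cv, H), since no positional embeddings are passed
```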

Notably, this layer is equivalent to standard dot-product attention (without position embeddings) when the following settings are used (a reference computation is sketched after this list):

  • uniform_query_precision = True
  • uniform_value_precision = True
  • magnitude_priors = True
  • alpha = 1/sqrt(C) (can be a scalar to save some memory)
  • beta = 0 (can be a scalar to save some memory)
  • v_init = None
  • v_fixed = None
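
Under those settings the update reduces to the familiar scaled dot-product attention over the location axis H. The self-contained reference computation below is for comparison only; it ignores the mixture priors pi and positional embeddings and is not the module's actual implementation.

```python
import math
import torch

N, G, C, Cv, H = 2, 4, 16, 32, 64
q = torch.randn(N, G, C, H)     # queries
zeta = torch.randn(N, G, C, H)  # keys
mu = torch.randn(N, G, Cv, H)   # values

# Scaled dot-product attention over the location axis H,
# which the probabilistic update reduces to under the settings above
logits = torch.einsum('ngch,ngck->nghk', q, zeta) / math.sqrt(C)  # alpha = 1/sqrt(C) scaling
weights = torch.softmax(logits, dim=-1)
out = torch.einsum('nghk,ngck->ngch', weights, mu)
print(out.shape)  # (N, G, Cv, H)
```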

Getting Started

The module is in the file probabilisticattention.py. It can be imported like any other PyTorch layer.
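
As a further usage note, the optional v_fixed argument packs the fixed values and an indicator channel into a single tensor. The sketch below shows one way such a tensor might be assembled, assuming the last channel marks fixed locations with ones; the exact convention should be checked against the module source.

```python
import torch

N, G, Cv, H = 2, 4, 32, 64

fixed_values = torch.zeros(N, G, Cv, H)  # desired fixed values at some locations
indicator = torch.zeros(N, G, 1, H)      # 1 where a value is fixed, 0 elsewhere
indicator[..., :8] = 1.0                 # e.g. fix the first 8 locations
fixed_values[..., :8] = 1.0              # with value 1 in every channel (illustrative)

v_fixed = torch.cat([fixed_values, indicator], dim=2)  # dims N, G, (Cv+1), H as documented
```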
