dugu9sword / torch_random_fields

A highly optimized library for building Markov random fields with PyTorch.


torch_random_fields's Introduction


torch_random_fields is a PyTorch library for building Markov random fields (MRFs) with complex topologies [1] [2]. It is optimized for batched training on the GPU.

The key features include:

  • Easy to plug into your research code
  • Batched GPU acceleration for any random field with arbitrary binary or ternary connections
  • Fast training and inference with top-K logits, so there is no need to worry about large label spaces
  • Support for context-aware transition matrices and low-rank factorization

You may cite this project as:

@inproceedings{wang2022regularized,
  title={Regularized Molecular Conformation Fields},
  author={Lihao Wang and Yi Zhou and Yiqun Wang and Xiaoqing Zheng and Xuanjing Huang and Hao Zhou},
  booktitle={Advances in Neural Information Processing Systems},
  editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
  year={2022},
  url={https://openreview.net/forum?id=7XCFxnG8nGS}
}

Cases

Linear-Chain CRF

Check out the tutorial.

The well-known linear-chain CRF, heavily used in sequence labeling (POS tagging, chunking, NER, etc.), is supported.

[figure: linear-chain CRF]
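
For intuition, here is a minimal, self-contained sketch of Viterbi decoding for a linear-chain CRF in plain PyTorch. The function and tensor names (viterbi_decode, unaries, transitions) are illustrative, not the library's actual API.

import torch

def viterbi_decode(unaries, transitions):
    # unaries: (batch, seq_len, L) emission scores; transitions: (L, L).
    # Returns the highest-scoring label sequence, shape (batch, seq_len).
    batch, seq_len, num_labels = unaries.shape
    score = unaries[:, 0]                               # (batch, L)
    backptrs = []
    for t in range(1, seq_len):
        # candidate[b, i, j]: best score ending in label i, plus i -> j transition
        candidate = score.unsqueeze(2) + transitions.unsqueeze(0)
        best, idx = candidate.max(dim=1)                # best previous label per j
        score = best + unaries[:, t]
        backptrs.append(idx)                            # (batch, L)
    # Trace the best path backwards through the stored pointers.
    best_last = score.argmax(dim=1)                     # (batch,)
    path = [best_last]
    for idx in reversed(backptrs):
        best_last = idx.gather(1, best_last.unsqueeze(1)).squeeze(1)
        path.append(best_last)
    return torch.stack(path[::-1], dim=1)

unaries = torch.randn(2, 5, 4)
transitions = torch.randn(4, 4)
print(viterbi_decode(unaries, transitions))

Batching the recursion this way is what makes GPU decoding fast: each step is a single (batch, L, L) max-reduction rather than a per-sequence loop.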

Top-K Skip-Chain CRF

Check out the tutorial.

torch_random_fields supports random fields with arbitrary topology. To be more precise, connections must be binary, although in some cases ternary connections are also supported (yes, I am lazy).

Here we show a case of a dynamic skip-chain CRF, where:

  • Some nodes (e.g., two nodes carrying the same word) are connected, skipping over the linear-chain connections [3]
  • Only the top-3 labels of each node are kept, which greatly speeds up training and inference [4]; see the sketch below

[figure: top-K skip-chain CRF]
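
The pruning step behind the top-K trick is easy to sketch in plain PyTorch: keep the K best labels per node, then gather the K x K transition block for every edge. All names below are illustrative, not the library's API.

import torch

def topk_prune(unaries, transitions, k=3):
    # unaries: (B, T, L) node scores; transitions: (L, L) label transitions.
    topk_scores, topk_ids = unaries.topk(k, dim=-1)   # (B, T, k)
    prev_ids = topk_ids[:, :-1]                       # kept labels at step t
    next_ids = topk_ids[:, 1:]                        # kept labels at step t + 1
    # Gather the k x k transition block for every edge: (B, T-1, k, k).
    edge = transitions[prev_ids.unsqueeze(-1), next_ids.unsqueeze(-2)]
    return topk_scores, topk_ids, edge

unaries = torch.randn(2, 6, 50)          # 50 labels, only the top 3 are kept
transitions = torch.randn(50, 50)
scores, ids, edge = topk_prune(unaries, transitions)
print(scores.shape, edge.shape)          # (2, 6, 3) and (2, 5, 3, 3)

All downstream message passing then runs over k x k tables instead of L x L, which is where the speedup comes from.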

Ising/Potts Model

The Ising model (and its generalization, the Potts model) is widely used in statistical physics and computational biology [5]. In this case the random variables form a grid, but they can also be fully connected.

[figure: Ising/Potts model on a grid]
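
As a minimal illustration, the Potts energy of a grid assignment can be written as follows. The unary fields h and coupling matrix J are made-up names for this sketch, not the library's API.

import torch

def potts_energy(states, h, J):
    # states: (H, W) integer grid; h: (H, W, q) unary fields; J: (q, q) couplings.
    # E(s) = -sum_i h_i(s_i) - sum over adjacent pairs <ij> of J(s_i, s_j).
    unary = h.gather(-1, states.unsqueeze(-1)).squeeze(-1).sum()
    right = J[states[:, :-1], states[:, 1:]].sum()   # horizontal neighbor pairs
    down = J[states[:-1, :], states[1:, :]].sum()    # vertical neighbor pairs
    return -(unary + right + down)

q = 3
states = torch.randint(0, q, (8, 8))
h = torch.randn(8, 8, q)
J = torch.randn(q, q)
print(potts_energy(states, h, J))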

Features

Learning

  • Linear-Chain CRF:
    • maximum likelihood estimation
    • structured perceptron
    • piecewise training
    • pseudo-likelihood (sketched below)
  • General CRF:
    • structured perceptron
    • piecewise training
    • pseudo-likelihood
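
As an example of one of these objectives, here is a hedged sketch of pseudo-likelihood training for a linear-chain CRF: each node is classified given its observed (gold) neighbors, so no sequence-level partition function is needed. Tensor names are illustrative, not the library's API.

import torch
import torch.nn.functional as F

def pseudo_likelihood_loss(unaries, transitions, labels):
    # unaries: (B, T, L); transitions: (L, L); labels: (B, T) gold labels.
    B, T, L = unaries.shape
    logits = unaries.clone()
    # Condition each node on its observed neighbors: add the transition score
    # from the gold left label into every candidate label ...
    logits[:, 1:] += transitions[labels[:, :-1]]
    # ... and from every candidate label into the gold right label.
    logits[:, :-1] += transitions[:, labels[:, 1:]].permute(1, 2, 0)
    return F.cross_entropy(logits.reshape(-1, L), labels.reshape(-1))

unaries = torch.randn(2, 7, 5, requires_grad=True)
transitions = torch.randn(5, 5, requires_grad=True)
labels = torch.randint(0, 5, (2, 7))
loss = pseudo_likelihood_loss(unaries, transitions, labels)
loss.backward()   # gradients flow into both unaries and transitions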

Inference

  • Linear-Chain CRF:
    • Viterbi decoding
    • batch loopy belief propagation
    • batch mean field variational inference (sketched below)
  • General CRF:
    • batch loopy belief propagation
    • batch naive mean field variational inference
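
A hedged sketch of (naive) mean-field inference on a chain, in the same illustrative notation as above: every node's belief is repeatedly refreshed from the expected transition scores under its neighbors' current beliefs.

import torch

def mean_field(unaries, transitions, iters=10):
    # unaries: (B, T, L); transitions: (L, L). Returns beliefs of shape (B, T, L).
    q = torch.softmax(unaries, dim=-1)
    for _ in range(iters):
        msg = unaries.clone()
        # Expected transition score under the left neighbor's current belief ...
        msg[:, 1:] += q[:, :-1] @ transitions
        # ... and under the right neighbor's current belief.
        msg[:, :-1] += q[:, 1:] @ transitions.T
        q = torch.softmax(msg, dim=-1)
    return q

unaries = torch.randn(2, 6, 4)
transitions = torch.randn(4, 4)
print(mean_field(unaries, transitions).sum(-1))   # every belief sums to 1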

Sampling

  • Gibbs sampling (sketched below)
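
And a hedged sketch of a Gibbs sweep over the Potts grid from the earlier example: each site is resampled from its conditional distribution given its current neighbors. Again, the names are illustrative, not the library's sampler.

import torch

def gibbs_sweep(states, h, J):
    # One in-place sweep over an (H, W) Potts grid with unary fields
    # h: (H, W, q) and couplings J: (q, q), matching the energy above.
    H, W, q = h.shape
    for i in range(H):
        for j in range(W):
            logits = h[i, j].clone()
            if i > 0:
                logits += J[states[i - 1, j]]       # coupling to the site above
            if i < H - 1:
                logits += J[:, states[i + 1, j]]    # coupling to the site below
            if j > 0:
                logits += J[states[i, j - 1]]       # coupling to the left
            if j < W - 1:
                logits += J[:, states[i, j + 1]]    # coupling to the right
            probs = torch.softmax(logits, dim=-1)
            states[i, j] = torch.multinomial(probs, 1).item()
    return states

states = torch.randint(0, 3, (8, 8))
h, J = torch.randn(8, 8, 3), torch.randn(3, 3)
for _ in range(5):
    states = gibbs_sweep(states, h, J)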

Acknowledgement

Some implementations are borrowed, with modifications, from these great projects:

Reference

[1] An Introduction to Conditional Random Fields (Sutton and McCallum, 2010)

[2] Graphical Models, Exponential Families, and Variational Inference (Wainwright and Jordan, 2008)

[3] A Skip-Chain Conditional Random Field for Ranking Meeting Utterances by Importance (Galley, 2006)

[4] Fast Structured Decoding for Sequence Models (Sun, 2020)

[5] Improved contact prediction in proteins: Using pseudolikelihoods to infer Potts models (Ekeberg, 2013)


torch_random_fields's Issues

Question about LinearChainCRF._compute_score

Hello,

Thank you for sharing this code! I am interested in learning from and adapting this code for my own project, which will need to run belief propagation efficiently. I was working through the tutorial (tests/test_linear_chain_crf.ipynb) just now and noticed a discrepancy between the edge score computation in forward and in _compute_score of LinearChainCRF.

In the forward function, the score of an edge/state transition is computed with the bilinear formula E1 @ edge_wise @ E2, where E1 and E2 contain the static state embeddings (of the top-K states) at the current and next timestep, and edge_wise carries the influence of the input sequence. On the other hand, in _compute_score the edge scores are computed as E1 @ E2, without considering the influence of the input sequence (via edge_wise). Does this correspond to some sort of approximation, or was edge_wise meant to be multiplied in but accidentally left out?
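
For concreteness, the bilinear score in question can be sketched with torch.einsum; the shapes here are illustrative, not the exact ones in the source.

import torch

B, T, K, D = 2, 5, 3, 8
E1 = torch.randn(B, T - 1, K, D)         # top-K state embeddings at step t
E2 = torch.randn(B, T - 1, K, D)         # top-K state embeddings at step t + 1
edge_wise = torch.randn(B, T - 1, D, D)  # input-dependent transition factor

# E1 @ edge_wise @ E2^T for every edge: (B, T-1, K, K) scores
scores = torch.einsum("btkd,btde,btle->btkl", E1, edge_wise, E2)
print(scores.shape)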

I tried changing _compute_score to also use E1 @ edge_wise @ E2 as the score, and it seems to improve the accuracy trajectories of perceptron and exact-likelihood training (the only two training methods that call _compute_score) on the tutorial's toy dataset. Results are attached below.

Any advice is appreciated!

Thanks,
Brian

Original code results:

====== cost ======
iteration            100    200    300    400    500
-----------------  -----  -----  -----  -----  -----
piecewise          0.282  0.250  0.247  0.248  0.248
pseudo-likelihood  0.237  0.243  0.242  0.243  0.244
perceptron         0.551  0.555  0.565  0.551  0.551
exact-likelihood   1.176  1.188  1.528  1.514  1.644
====== accu ======
iteration            100    200    300    400    500
-----------------  -----  -----  -----  -----  -----
piecewise          0.872  0.888  0.917  0.935  0.949
pseudo-likelihood  0.856  0.874  0.891  0.908  0.907
perceptron         0.254  0.361  0.540  0.650  0.704
exact-likelihood   0.179  0.123  0.076  0.714  0.835

After changing to E1 @ edge_wise @ E2:

====== cost ======
iteration            100    200    300    400    500
-----------------  -----  -----  -----  -----  -----
piecewise          0.267  0.244  0.242  0.243  0.243
pseudo-likelihood  0.227  0.231  0.236  0.236  0.237
perceptron         0.604  0.617  0.614  0.612  0.613
exact-likelihood   1.091  1.146  1.152  1.163  1.179
====== accu ======
iteration            100    200    300    400    500
-----------------  -----  -----  -----  -----  -----
piecewise          0.568  0.908  0.920  0.924  0.939
pseudo-likelihood  0.867  0.878  0.898  0.908  0.908
perceptron         0.742  0.839  0.903  0.910  0.924
exact-likelihood   0.836  0.876  0.904  0.922  0.950
