
yjlolo / pytorch-deep-markov-model

21 stars · 3 watchers · 1 fork · 228 KB

PyTorch re-implementation of [Structured Inference Networks for Nonlinear State Space Models, AAAI 17]

License: MIT License

Python 100.00%
variational-autoencoders variational-inference sequential-data markov-model aaai pytorch-implementation reimplementation

pytorch-deep-markov-model's Introduction

pytorch-deep-markov-model

PyTorch re-implementation of the Deep Markov Model (https://arxiv.org/abs/1609.09869)

@inproceedings{10.5555/3298483.3298543,
    author = {Krishnan, Rahul G. and Shalit, Uri and Sontag, David},
    title = {Structured Inference Networks for Nonlinear State Space Models},
    year = {2017},
    publisher = {AAAI Press},
    booktitle = {Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence},
    pages = {2101–2109},
    numpages = {9},
    location = {San Francisco, California, USA},
    series = {AAAI'17}
}

Note:

  1. The metrics calculated in model/metric.py do not match those reported in the paper, most likely due to differences in parameter settings and metric calculations.
  2. The current implementation only supports the JSB polyphonic music dataset.

Under-development

Refer to the branch factorial-dmm for a model described as Factorial DMM. The other branch, refractor, aims to improve readability and offer more model options (documentation not yet updated).

Usage

Training the model with the default config.json:

python train.py -c config.json

Add the -i flag to name the experiment that is saved under saved/, as in the example below.
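For example, assuming -i takes the experiment name as its value (the name here is arbitrary), the run would be saved under saved/ with that name:

python train.py -c config.json -i dmm_jsb_baseline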

config.json

This file specifies the model, optimizer, and trainer configurations; some key parameters are explained below.

A careful fine-tuning of the parameters seems necessary to match the reported performance.

{
    "arch": {
        "type": "DeepMarkovModel",
        "args": {
            "input_dim": 88,
            "z_dim": 100,
            "emission_dim": 100,
            "transition_dim": 200,
            "rnn_dim": 600,
            "rnn_type": "lstm",
            "rnn_layers": 1,
            "rnn_bidirection": false,   // condition z_t on both directions of inputs;
                                        // manually turn off `reverse_rnn_input` if true
                                        // (this is minor and could be quickly fixed)
            "use_embedding": true,      // use an extra linear layer before the RNN
            "orthogonal_init": true,    // orthogonal initialization for the RNN
            "gated_transition": true,   // use linear/non-linear gated transition
            "train_init": false,        // make z0 trainable
            "mean_field": false,        // use mean-field posterior q(z_t | x)
            "reverse_rnn_input": true,  // condition z_t on future inputs
            "sample": true              // sample during reparameterization
        }
    },
    "optimizer": {
        "type": "Adam",
        "args": {
            "lr": 0.0008,               // default value from the author's source code
            "weight_decay": 0.0,        // debugging showed that 1.0 prevents training
            "amsgrad": true,
            "betas": [0.9, 0.999]
        }
    },
    "trainer": {
        "epochs": 3000,
        "overfit_single_batch": false,  // overfit a single batch for debugging

        "save_dir": "saved/",
        "save_period": 500,
        "verbosity": 2,

        "monitor": "min val_loss",
        "early_stop": 100,

        "tensorboard": true,

        "min_anneal_factor": 0.0,
        "anneal_update": 5000
    }
}
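The last two trainer parameters control KL annealing of the ELBO. Below is a minimal sketch of a linear schedule consistent with these parameter names (an assumption; the exact schedule implemented in train.py may differ):

def kl_anneal_factor(step, min_anneal_factor=0.0, anneal_update=5000):
    # Linearly ramp the KL weight from min_anneal_factor up to 1.0 over
    # anneal_update gradient updates, then hold it at 1.0.
    if step >= anneal_update:
        return 1.0
    return min_anneal_factor + (1.0 - min_anneal_factor) * step / anneal_update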

References

  1. Project template adapted from pytorch-template
  2. The original source code in Theano
  3. PyTorch implementation in the Pyro framework
  4. Another PyTorch implementation by @guxd

To-Do

  • fine-tune to match the performance reported in the paper
  • correct errors (if any) in the metric calculation, model/metric.py
  • optimize importance sampling

pytorch-deep-markov-model's People

Contributors

yjlolo


Forkers

valeman

pytorch-deep-markov-model's Issues

Should the transition function get the output of the combiner?

Hi @yjlolo

I started to read this paper and came across your implementation. It is quite neatly written.

As I was going through the code, I noticed the following:

mu_p, logvar_p = self.transition(z_prev)

Extracting the relevant parts here; note the comment with the prefix "KS":

for t in range(T_max):
    # q(z_t | z_{t-1}, x_{t:T})
    mu_q, logvar_q = self.combiner(h_rnn[:, t, :], z_prev,
                                   rnn_bidirection=self.rnn_bidirection)
    zt_q = self.reparameterization(mu_q, logvar_q)
    z_prev = zt_q
    # p(z_t | z_{t-1})
    # KS: You are supplying z_prev that is the output of the combiner module
    #     to the transition function
    mu_p, logvar_p = self.transition(z_prev)
    zt_p = self.reparameterization(mu_p, logvar_p)

The paper is not very clear about it, but the way I have interpreted it so far is as follows:

The transition function (model) is a sort of "learned prior". This is unlike a regular VAE, where the prior is assumed to be a standard normal. So we use p for it.

The combiner function (model) is the approximate posterior, i.e. q, and one would compute the KL divergence between these p and q. You do that as well.
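For reference, a minimal sketch of that KL term between two diagonal Gaussians parameterized by means and log-variances (the standard closed form, not code from this repository):

import torch

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ),
    # summed over the latent dimension.
    return 0.5 * torch.sum(
        logvar_p - logvar_q
        + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
        - 1.0,
        dim=-1,
    )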

The part that is not clear in your above code is why self.transition should receive an input that is the output of the combiner function. Shouldn't the transition also get the original z_prev?

If you look at another implementation of this model (in pytorch) https://github.com/guxd/deepHMM/blob/b1f596ee41f81c77ae20244c69092da184b5bcc9/models/dhmm.py#L83, the author uses the original z_prev

Honestly, the paper is not clear about it, so I am not sure what the right solution is here.
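For concreteness, a minimal sketch of the alternative ordering (hypothetical, reusing the variables from the snippet above): compute the prior from the previous step's latent before overwriting it with the new posterior sample, matching @guxd's implementation.

for t in range(T_max):
    # q(z_t | z_{t-1}, x_{t:T})
    mu_q, logvar_q = self.combiner(h_rnn[:, t, :], z_prev,
                                   rnn_bidirection=self.rnn_bidirection)
    zt_q = self.reparameterization(mu_q, logvar_q)
    # p(z_t | z_{t-1}): uses z_prev from step t-1, before it is overwritten
    mu_p, logvar_p = self.transition(z_prev)
    zt_p = self.reparameterization(mu_p, logvar_p)
    # only now advance the chain with the posterior sample
    z_prev = zt_q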

Would appreciate your guidance and insight here

Regards
Kapil

Possible to parallelize emission

Hi @yjlolo

It's me again. I wanted to discuss whether it is possible to parallelize the emission step.

Let's look at this snippet

x_recon = torch.zeros([batch_size, T_max, self.input_dim]).to(x.device)
mu_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
logvar_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
mu_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
logvar_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
z_q_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
z_p_seq = torch.zeros([batch_size, T_max, self.z_dim]).to(x.device)
for t in range(T_max):
    # q(z_t | z_{t-1}, x_{t:T})
    mu_q, logvar_q = self.combiner(h_rnn[:, t, :], z_prev,
                                   rnn_bidirection=self.rnn_bidirection)
    zt_q = self.reparameterization(mu_q, logvar_q)
    z_prev = zt_q
    # p(z_t | z_{t-1})
    mu_p, logvar_p = self.transition(z_prev)
    zt_p = self.reparameterization(mu_p, logvar_p)

    xt_recon = self.emitter(zt_q).contiguous()

    mu_q_seq[:, t, :] = mu_q
    logvar_q_seq[:, t, :] = logvar_q
    z_q_seq[:, t, :] = zt_q
    mu_p_seq[:, t, :] = mu_p
    logvar_p_seq[:, t, :] = logvar_p
    z_p_seq[:, t, :] = zt_p
    x_recon[:, t, :] = xt_recon

As per the above code, self.emitter is called inside the loop over time steps.

Here is a thought:

If the emitter (function/model) is written so that it takes (time_steps, z_dim) as the input shape instead of (z_dim), then we can take it out of the for-loop.

Since you are storing z_q per time step (i.e. z_q_seq[:, t, :] = zt_q), we could then simply supply z_q_seq to an emitter that takes input of shape (time_steps, z_dim), as sketched below.
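A minimal sketch of this idea, assuming the emitter is an MLP of nn.Linear layers (PyTorch applies nn.Linear over the last dimension, so it broadcasts over the batch and time dimensions without modification):

# Run the emitter once over the full latent sequence after the loop:
# (batch, T_max, z_dim) -> (batch, T_max, input_dim)
x_recon = self.emitter(z_q_seq)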

What do you think about this? Am I ignoring some aspect that makes the model invalid?

Regards
Kapil
