This project forked from cpmpercussion/keras-mdn-layer

An MDN Layer for Keras using TensorFlow's distributions module

License: MIT License

Keras Mixture Density Network Layer

A mixture density network (MDN) layer for Keras using TensorFlow's distributions module. This makes it a bit simpler to experiment with neural networks that predict multiple real-valued variables that can take on multiple equally likely values.

This layer can help build MDN-RNNs similar to those used in RoboJam, Sketch-RNN, handwriting generation, and maybe even world models. You can do a lot of cool stuff with MDNs!

One benefit of this implementation is that you can predict any number of real values. TensorFlow's Mixture, Categorical, and MultivariateNormalDiag distribution functions are used to generate the loss function (the negative log of the probability density function of a mixture of multivariate normal distributions with a diagonal covariance matrix). In previous work, the loss function has often been specified by hand, which is fine for 1D or 2D prediction but becomes a bit more annoying after that.
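To make the loss concrete, here is a minimal NumPy sketch of the negative log-likelihood of a single target under a mixture of diagonal-covariance Gaussians. It is an illustration only and not the package's TensorFlow-based implementation; the function name `mixture_nll` and its argument layout are assumptions made for this example.

```python
import numpy as np

def mixture_nll(y, pis, mus, sigmas):
    """Negative log-likelihood of y under a mixture of diagonal Gaussians.

    Hypothetical helper for illustration, not part of the mdn package.
    y:      (D,)   target vector
    pis:    (K,)   mixture weights, positive and summing to 1
    mus:    (K, D) component means
    sigmas: (K, D) component standard deviations (all > 0)
    """
    y = np.asarray(y, dtype=float)
    pis = np.asarray(pis, dtype=float)
    mus = np.asarray(mus, dtype=float)
    sigmas = np.asarray(sigmas, dtype=float)
    # Log-density of each component: a diagonal Gaussian factorises into
    # independent 1D Gaussians, so the per-dimension log-densities add up.
    z = (y - mus) / sigmas
    log_comp = -0.5 * np.sum(z ** 2 + np.log(2.0 * np.pi) + 2.0 * np.log(sigmas), axis=1)
    # Log of the weighted sum of densities, via log-sum-exp for stability.
    a = np.log(pis) + log_comp
    a_max = np.max(a)
    return -(a_max + np.log(np.sum(np.exp(a - a_max))))

# One 2D standard-normal component evaluated at the origin.
print(mixture_nll([0.0, 0.0], [1.0], [[0.0, 0.0]], [[1.0, 1.0]]))
```

The per-dimension sum is what makes the diagonal-covariance case scale to any output dimension without writing the density out by hand.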

Two important functions are provided for training and prediction:

  • get_mixture_loss_func(output_dim, num_mixtures): This function generates a loss function with the correct output dimension and number of mixtures.
  • sample_from_output(params, output_dim, num_mixtures, temp=1.0): This function samples from the mixture distribution output by the model.

Installation

This project requires Python 3.6+. You can easily install this package from PyPI via pip like so:

python3 -m pip install keras-mdn-layer

And finally, import the mdn module in Python: import mdn

Alternatively, you can clone or download this repository and then install via python setup.py install, or copy the mdn folder into your own project.

Examples

Some examples are provided in the notebooks directory.

There are scripts for fitting multivalued functions, a standard MDN toy problem:

Keras MDN Demo
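For readers who want to try this kind of toy problem without opening the notebooks, the sketch below generates a one-to-many dataset in the spirit of Bishop's inverse problem. The helper name `make_inverse_sine_data`, the constants, and the noise level are assumptions for illustration and are not necessarily those used in the notebooks.

```python
import numpy as np

def make_inverse_sine_data(n_samples=1000, noise=0.05, seed=0):
    """Generate a one-to-many regression problem (hypothetical helper).

    The forward map t -> x = t + 0.3*sin(2*pi*t) + noise is a function,
    but its inverse x -> t is multivalued in the middle of the range,
    which is what a plain regression network cannot represent.
    """
    rng = np.random.default_rng(seed)
    t = rng.uniform(0.0, 1.0, size=n_samples)
    x = t + 0.3 * np.sin(2.0 * np.pi * t) + rng.uniform(-noise, noise, size=n_samples)
    # Return column vectors: network input x, regression target t.
    return x.reshape(-1, 1).astype(np.float32), t.reshape(-1, 1).astype(np.float32)

x_train, y_train = make_inverse_sine_data()
print(x_train.shape, y_train.shape)
```

With data like this, each target has a single dimension, so an MDN fitted to it would use an output dimension of 1.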

There's also a script for generating fake kanji characters:

kanji test 1

And finally, for learning how to generate musical touch-screen performances with a temporal component:

Robojam Model Examples

How to use

The MDN layer should be the last in your network and you should use get_mixture_loss_func to generate a loss function. Here's an example of a simple network with one Dense layer followed by the MDN.

import keras
import mdn

N_HIDDEN = 15  # number of hidden units in the Dense layer
N_MIXES = 10  # number of mixture components
OUTPUT_DIMS = 2  # number of real-values predicted by each mixture component

model = keras.Sequential()
model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
model.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))
model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

Fit as normal:

history = model.fit(x=x_train, y=y_train)

The predictions from the network are parameters of the mixture models, so you have to apply the sample_from_output function to generate samples.

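import numpy as np
y_test = model.predict(x_test)
# Each row of y_test holds the mixture parameters for one input; draw one sample per row.
y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=1.0)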
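To show what sampling from mixture parameters involves, here is a rough NumPy sketch. It is not the package's sample_from_output; the name `sample_mixture`, the assumed parameter layout (means, then standard deviations, then mixture logits), and the simplified temperature handling (scaling the logits only) are all assumptions for illustration.

```python
import numpy as np

def sample_mixture(params, output_dim, num_mixtures, temp=1.0, rng=None):
    """Draw one sample from a flat mixture parameter vector (illustrative only).

    Assumed layout: num_mixtures*output_dim means, then the same number of
    standard deviations, then num_mixtures unnormalised mixture logits.
    """
    rng = np.random.default_rng() if rng is None else rng
    params = np.asarray(params, dtype=float)
    n = num_mixtures * output_dim
    mus = params[:n].reshape(num_mixtures, output_dim)
    sigmas = params[n:2 * n].reshape(num_mixtures, output_dim)
    logits = params[2 * n:] / temp
    # Softmax over the temperature-scaled logits gives the mixture weights.
    weights = np.exp(logits - np.max(logits))
    weights /= weights.sum()
    # Pick a component, then sample from its diagonal Gaussian.
    k = rng.choice(num_mixtures, p=weights)
    return mus[k] + sigmas[k] * rng.standard_normal(output_dim)

# Two 2D components; the second has a much larger logit, so it dominates.
params = np.array([0, 0, 5, 5, 0.1, 0.1, 0.1, 0.1, 0.0, 10.0])
batch = np.tile(params, (4, 1))  # stand-in for a batch of model predictions
samples = np.apply_along_axis(sample_mixture, 1, batch, 2, 2, temp=1.0)
print(samples.shape)  # (4, 2)
```

Because np.apply_along_axis calls the function once per row, each prediction yields one independent sample of length output_dim.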

See the notebooks directory for examples in Jupyter notebooks!

Load/Save Model

Saving models is straightforward:

model.save('test_save.h5')

But loading requires custom_objects to be filled with the MDN layer and a loss function with the appropriate parameters:

m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES)})

Acknowledgements

References

  1. Christopher M. Bishop. 1994. Mixture Density Networks. Technical Report NCRG/94/004. Neural Computing Research Group, Aston University. http://publications.aston.ac.uk/373/
  2. Axel Brando. 2017. Mixture Density Networks (MDN) for distribution and uncertainty estimation. Master’s thesis. Universitat Politècnica de Catalunya.
  3. A. Graves. 2013. Generating Sequences With Recurrent Neural Networks. ArXiv e-prints (Aug. 2013). https://arxiv.org/abs/1308.0850
  4. David Ha and Douglas Eck. 2017. A Neural Representation of Sketch Drawings. ArXiv e-prints (April 2017). https://arxiv.org/abs/1704.03477
  5. Charles P. Martin and Jim Torresen. 2018. RoboJam: A Musical Mixture Density Network for Collaborative Touchscreen Interaction. In Evolutionary and Biologically Inspired Music, Sound, Art and Design: EvoMUSART ’18, A. Liapis et al. (Eds.). Lecture Notes in Computer Science, Vol. 10783. Springer International Publishing. DOI:10.1007/978-3-319-77583-8_11
