gokyeongryeol / neural-process-family

PyTorch implementation of the Neural Process Family (CNP, NP, ANP, Meta-Fun) for functions and images

Python 100.00%

Neural Process Family

PyTorch implementation of the Neural Process Family (CNP, NP, ANP, Meta-Fun) for both regression and classification.

How the studies have evolved

The Neural Process Family is designed as a scalable neural-network alternative to the Gaussian process. Specifically, given a set of input-output pairs {Cx, Cy} (the context) and some input Tx, the model is expected to credibly estimate the corresponding output Ty. Hence, it follows the few-shot learning problem setting and uses an encoder-decoder pipeline. A permutation-invariant set encoding r is first extracted from {Cx, Cy}, and Tx is then fed to the decoder along with r to estimate the parameters of the distribution of Ty.
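The encoder-decoder pipeline above can be sketched in PyTorch as follows. This is a minimal, hypothetical CNP-style module, not the code in model.py: the class name `DeterministicEncoderDecoder`, the layer sizes, mean pooling as the permutation-invariant aggregator, and the `0.1 + 0.9 * softplus` parameterization of the standard deviation are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeterministicEncoderDecoder(nn.Module):
    """CNP-style encoder-decoder; layer sizes are hypothetical."""

    def __init__(self, x_dim=1, y_dim=1, r_dim=64, hidden=64):
        super().__init__()
        # Encoder: maps each context pair (x_i, y_i) to an embedding r_i.
        self.encoder = nn.Sequential(
            nn.Linear(x_dim + y_dim, hidden), nn.ReLU(), nn.Linear(hidden, r_dim)
        )
        # Decoder: maps (target x, r) to the mean and raw scale of target y.
        self.decoder = nn.Sequential(
            nn.Linear(x_dim + r_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * y_dim)
        )

    def encode(self, cx, cy):
        # cx: (B, Nc, x_dim), cy: (B, Nc, y_dim) -> r: (B, r_dim)
        r_i = self.encoder(torch.cat([cx, cy], dim=-1))
        # Averaging over the set axis makes r invariant to the order of the context.
        return r_i.mean(dim=1)

    def forward(self, cx, cy, tx):
        r = self.encode(cx, cy)
        # Share the single set encoding r with every target input.
        r = r.unsqueeze(1).expand(-1, tx.size(1), -1)
        mu, raw_sigma = self.decoder(torch.cat([tx, r], dim=-1)).chunk(2, dim=-1)
        # Softplus with a floor keeps the predicted standard deviation positive.
        sigma = 0.1 + 0.9 * F.softplus(raw_sigma)
        return torch.distributions.Normal(mu, sigma)


torch.manual_seed(0)
model = DeterministicEncoderDecoder()
cx, cy = torch.randn(2, 10, 1), torch.randn(2, 10, 1)
tx, ty = torch.randn(2, 5, 1), torch.randn(2, 5, 1)
pred = model(cx, cy, tx)
# CNP-style objective: negative log-likelihood of the target outputs.
loss = -pred.log_prob(ty).mean()
print(pred.mean.shape)  # torch.Size([2, 5, 1])
```

Mean pooling is the simplest symmetric aggregator; any permutation-invariant reduction (sum, max, attention) would preserve the property that r does not depend on the order of {Cx, Cy}.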

Conditional Neural Process (CNP) was the first instantiation of this line of work; it is trained to maximize the conditional log-likelihood of the target outputs. Neural Process (NP) is a simple extension of CNP that incorporates a stochastic latent variable, following the Variational AutoEncoder (VAE). Since the marginal likelihood is no longer tractable, variational inference is applied and the model is trained by maximizing the Evidence Lower Bound (ELBO). Based on the Kolmogorov Extension Theorem, NP is proven to define a stochastic process. However, because approximate inference in such probabilistic models is inherently difficult, NP tends to underfit, and many complex modules are required to avoid this.
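To make the ELBO concrete, here is a hedged sketch of the NP objective. It assumes the common NP training recipe in which one latent encoder produces both q(z | C), used as a stand-in for the conditional prior, and q(z | T) over the full target set, which contains the context. `LatentEncoder`, `LatentDecoder`, `negative_elbo`, and all sizes are hypothetical names chosen for this example, not taken from model.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


class LatentEncoder(nn.Module):
    """Maps a set of (x, y) pairs to a diagonal Gaussian over the latent z."""

    def __init__(self, x_dim=1, y_dim=1, z_dim=16, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim + y_dim, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, 2 * z_dim)

    def forward(self, x, y):
        # Mean-pool over the set axis so the distribution ignores element order.
        pooled = self.net(torch.cat([x, y], dim=-1)).mean(dim=1)
        mu, raw = self.head(pooled).chunk(2, dim=-1)
        return Normal(mu, 0.1 + 0.9 * torch.sigmoid(raw))


class LatentDecoder(nn.Module):
    """Maps (target x, sampled z) to a Gaussian over target y."""

    def __init__(self, x_dim=1, y_dim=1, z_dim=16, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(x_dim + z_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, 2 * y_dim))

    def forward(self, tx, z):
        z = z.unsqueeze(1).expand(-1, tx.size(1), -1)
        mu, raw = self.net(torch.cat([tx, z], dim=-1)).chunk(2, dim=-1)
        return Normal(mu, 0.1 + 0.9 * F.softplus(raw))


def negative_elbo(encoder, decoder, cx, cy, tx, ty):
    """-ELBO = KL(q(z|T) || q(z|C)) - E_{q(z|T)}[log p(ty | tx, z)].
    Assumes the target set T already contains the context points."""
    prior = encoder(cx, cy)        # q(z | C) plays the role of the conditional prior
    posterior = encoder(tx, ty)    # q(z | T), with C a subset of T
    z = posterior.rsample()        # reparameterized single-sample estimate, as in a VAE
    log_lik = decoder(tx, z).log_prob(ty).sum(dim=(1, 2))   # sum over points and y dims
    kl = kl_divergence(posterior, prior).sum(dim=-1)        # sum over latent dims
    return (kl - log_lik).mean(), kl


torch.manual_seed(0)
enc, dec = LatentEncoder(), LatentDecoder()
x = torch.rand(4, 20, 1) * 4 - 2
y = torch.sin(x)
loss, kl = negative_elbo(enc, dec, x[:, :8], y[:, :8], x, y)
loss.backward()
```

The KL term is what separates this loss from the CNP objective: it pulls the posterior computed from all targets toward the distribution computed from the context alone, so z can be sampled from q(z | C) at test time.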

Attentive Neural Process (ANP) is one of the follow-up studies based on this idea: it uses the multi-head attention and self-attention introduced in the Transformer to model the dependencies between set elements. Meta-Fun further connects the framework to functional gradient descent, which implicitly relaxes the representation to an infinite-dimensional space.
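The attention ANP adds to the deterministic path can be sketched with `torch.nn.MultiheadAttention`: self-attention among the context embeddings, then cross-attention in which each target input Tx queries the context inputs Cx (keys) and reads off the self-attended embeddings (values). This yields a target-specific representation instead of a single averaged r. The module name, the sizes, and the shared `x_embed` projection for queries and keys are assumptions, and the ANP latent path is omitted.

```python
import torch
import torch.nn as nn


class AttentiveDeterministicPath(nn.Module):
    """ANP-style deterministic path with self-attention and cross-attention."""

    def __init__(self, x_dim=1, y_dim=1, d_model=64, num_heads=4):
        super().__init__()
        self.pair_embed = nn.Linear(x_dim + y_dim, d_model)   # embeds (x_i, y_i)
        self.x_embed = nn.Linear(x_dim, d_model)              # shared query/key projection of x
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, cx, cy, tx):
        s = self.pair_embed(torch.cat([cx, cy], dim=-1))      # (B, Nc, d_model)
        # Self-attention: each context element is updated using all the others.
        s, _ = self.self_attn(s, s, s)
        # Cross-attention: queries from Tx, keys from Cx, values from the context embeddings.
        q, k = self.x_embed(tx), self.x_embed(cx)
        r_star, weights = self.cross_attn(q, k, s)            # (B, Nt, d_model), (B, Nt, Nc)
        return r_star, weights


torch.manual_seed(0)
path = AttentiveDeterministicPath()
cx, cy, tx = torch.randn(2, 10, 1), torch.randn(2, 10, 1), torch.randn(2, 5, 1)
r_star, attn = path(cx, cy, tx)
print(r_star.shape, attn.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 10])
```

Each row of `attn` (averaged over heads) sums to 1 over the context points, and permuting the context leaves `r_star` unchanged, so the set-level invariance of the pipeline is preserved while each target gets its own weighting of the context.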

Implementation

See data_loader.py for how the datasets for 1D Gaussian process regression and 2D image completion are prepared.
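The sketch below shows one plausible way to build both kinds of tasks; it is not taken from data_loader.py. The hypothetical `sample_gp_task` draws 1D functions from a GP prior with an RBF kernel (kernel hyperparameters, input range, and context-size range are assumptions), and the hypothetical `image_to_task` turns a batch of images into sets of (pixel coordinate, pixel value) pairs, with coordinates scaled to [-1, 1] and a random subset of pixels as the context.

```python
import torch


def sample_gp_task(batch_size=16, num_points=50, length_scale=0.4, output_scale=1.0,
                   jitter=1e-4, min_context=3, max_context=20):
    """Draws 1D functions from a GP prior with an RBF kernel and splits context/target."""
    # Work in float64 so the Cholesky factorization is numerically stable.
    x = torch.rand(batch_size, num_points, 1, dtype=torch.float64) * 4 - 2   # x ~ U(-2, 2)
    sq_dist = (x - x.transpose(1, 2)) ** 2                                   # (B, N, N)
    K = output_scale ** 2 * torch.exp(-0.5 * sq_dist / length_scale ** 2)
    K = K + jitter * torch.eye(num_points, dtype=torch.float64)
    L = torch.linalg.cholesky(K)
    y = L @ torch.randn(batch_size, num_points, 1, dtype=torch.float64)     # y ~ N(0, K)
    x, y = x.float(), y.float()
    n_ctx = int(torch.randint(min_context, max_context + 1, (1,)))
    # The first n_ctx points form the context; all points are targets.
    return x[:, :n_ctx], y[:, :n_ctx], x, y


def image_to_task(images, num_context):
    """Converts images (B, C, H, W) into (coordinate, pixel value) sets."""
    B, C, H, W = images.shape
    rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    coords = torch.stack([rows, cols], dim=-1).reshape(1, H * W, 2).float()
    coords = coords / torch.tensor([H - 1.0, W - 1.0]) * 2 - 1        # scale to [-1, 1]
    tx = coords.expand(B, -1, -1)
    ty = images.permute(0, 2, 3, 1).reshape(B, H * W, C)               # row-major pixel order
    idx = torch.randperm(H * W)[:num_context]                           # random context pixels
    return tx[:, idx], ty[:, idx], tx, ty


torch.manual_seed(0)
cx, cy, tx, ty = sample_gp_task()
img_cx, img_cy, img_tx, img_ty = image_to_task(torch.rand(2, 3, 32, 32), num_context=100)
```

Treating an image as a function from 2D coordinates to pixel values is what lets the same set-based models handle both tasks; only x_dim and y_dim change.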

See model.py to compare how the models feed the set data forward. Although the Neural Process Family is designed mainly for regression problems, we also extend it to classification problems using the linear classifier proposed in several meta-learning algorithms such as VERSA, LEO, and CNAP. Specifically, a linear classifier for a certain class label (e.g. 1) is constructed by processing the subset of the context inputs with that label, {Cx | Cy = 1}.
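A hedged sketch of the class-conditional construction described above: each class's context inputs are embedded, mean-pooled, and mapped to one row of the weight matrix plus a bias, so the classifier is generated from {Cx | Cy = c} rather than learned separately for each task. `AmortizedLinearClassifier`, its feature extractor, and the single linear `weight_gen` layer are hypothetical choices; the cited methods differ in how they generate the weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AmortizedLinearClassifier(nn.Module):
    """Generates one linear-classifier row per class from that class's context inputs."""

    def __init__(self, x_dim=32, feat_dim=64):
        super().__init__()
        self.feature = nn.Sequential(nn.Linear(x_dim, feat_dim), nn.ReLU())
        # Maps a pooled class representation to a weight vector and a bias.
        self.weight_gen = nn.Linear(feat_dim, feat_dim + 1)

    def forward(self, cx, cy, tx, num_classes):
        # cx: (Nc, x_dim), cy: (Nc,) integer labels, tx: (Nt, x_dim).
        # Every class must have at least one context example, or its mean is NaN.
        c_feat = self.feature(cx)
        rows = [self.weight_gen(c_feat[cy == c].mean(dim=0)) for c in range(num_classes)]
        params = torch.stack(rows)                # (num_classes, feat_dim + 1)
        W, b = params[:, :-1], params[:, -1]
        return self.feature(tx) @ W.T + b         # logits: (Nt, num_classes)


torch.manual_seed(0)
clf = AmortizedLinearClassifier()
cx = torch.randn(15, 32)                          # 5-way, 3-shot context
cy = torch.arange(5).repeat_interleave(3)
tx, ty = torch.randn(7, 32), torch.randint(0, 5, (7,))
logits = clf(cx, cy, tx, num_classes=5)
loss = F.cross_entropy(logits, ty)
```

Because each weight row is a permutation-invariant function of its class subset, reordering the context does not change the logits, mirroring the set encoding used for regression.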

See main.py for the hyperparameter settings and the overall training process.

As a pilot study, the models are evaluated on Gaussian process regression and CIFAR10 image completion. The following is a qualitative analysis of their performance. (For a detailed empirical comparison among the Neural Process Family, refer to the paper.)

While ANP converges much faster than NP and handles underfitting better, its interpolations and extrapolations turn out to be wiggly.

Predictions become more accurate as the number of context points increases.