kkang2097 / simplex-gp

This project forked from activatedgeek/simplex-gp

Lattice kernels for scalable Gaussian processes in GPyTorch (Simplex-GPs)

Home Page: https://go.sanyamkapoor.com/simplex-gp

License: Apache License 2.0

C++ 2.72% Python 6.77% Cuda 2.04% Jupyter Notebook 88.47%

Simplex-GPs

This repository hosts the code for SKIing on Simplices: Kernel Interpolation on the Permutohedral Lattice for Scalable Gaussian Processes (Simplex-GPs) by Sanyam Kapoor, Marc Finzi, Ke Alexander Wang, Andrew Gordon Wilson.

The Idea

Fast matrix-vector multiplies (MVMs) are the cornerstone of modern scalable Gaussian processes. By building upon the approximation proposed by Structured Kernel Interpolation (SKI), and leveraging advances in fast high-dimensional image filtering, Simplex-GPs approximate the computation of the kernel matrices by tiling the space using a sparse permutohedral lattice, instead of a rectangular grid.

The matrix-vector product implied by the kernel operations in SKI is now approximated in three stages: splat (projecting onto the permutohedral lattice), blur (a discrete convolution over neighbouring lattice points, applied as a matrix-vector product), and slice (projecting back into the original space).

This alleviates the curse of dimensionality associated with SKI operations, allowing them to scale beyond ~5 dimensions, and provides competitive advantages in terms of runtime and memory costs, at little expense of downstream performance. See our manuscript for complete details.
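To make the splat, blur and slice stages concrete, here is a minimal NumPy sketch in one dimension. In 1-D a simplex is just an interval, so barycentric interpolation onto the lattice becomes ordinary linear interpolation onto a uniform grid. This is a simplified SKI-style illustration, not the package's implementation: the blur step here multiplies by the dense RBF kernel between grid nodes rather than a sparse lattice stencil, and the function names `rbf`, `interp_weights` and `splat_blur_slice_mvm` are hypothetical.

```python
import numpy as np

def rbf(a, b, lengthscale):
  """Exact RBF kernel matrix k(a_i, b_j) = exp(-(a_i - b_j)^2 / (2 l^2)) for 1-D inputs."""
  return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

def interp_weights(x, grid):
  """Linear-interpolation weights of each point in x onto a uniform 1-D grid.

  Point i contributes w[i, 0] to grid node idx[i] and w[i, 1] to node idx[i] + 1.
  """
  h = grid[1] - grid[0]
  pos = (x - grid[0]) / h
  idx = np.clip(np.floor(pos).astype(int), 0, len(grid) - 2)
  frac = pos - idx
  return idx, np.stack([1.0 - frac, frac], axis=1)

def splat_blur_slice_mvm(x, v, grid, lengthscale):
  """Approximate K(x, x) @ v as W @ (K_grid @ (W^T @ v))."""
  idx, w = interp_weights(x, grid)
  # Splat: scatter-add each value onto its two neighbouring grid nodes (W^T v).
  splatted = np.zeros(len(grid))
  np.add.at(splatted, idx, w[:, 0] * v)
  np.add.at(splatted, idx + 1, w[:, 1] * v)
  # Blur: apply the kernel between grid nodes (dense here for simplicity).
  blurred = rbf(grid, grid, lengthscale) @ splatted
  # Slice: interpolate the blurred grid values back to the inputs (W @ blurred).
  return w[:, 0] * blurred[idx] + w[:, 1] * blurred[idx + 1]

if __name__ == "__main__":
  rng = np.random.default_rng(0)
  x = rng.uniform(0.0, 1.0, 200)
  grid = np.linspace(0.0, 1.0, 101)
  v = rng.standard_normal(200)
  approx = splat_blur_slice_mvm(x, v, grid, lengthscale=0.2)
  exact = rbf(x, x, 0.2) @ v
  print("relative error:", np.linalg.norm(approx - exact) / np.linalg.norm(exact))
```

The cost of splat and slice is linear in the number of points; all the kernel work happens on the grid. In higher dimensions a rectangular grid grows exponentially, which is the problem the sparse permutohedral lattice is meant to avoid.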

Usage

The lattice kernels are packaged as GPyTorch modules and can be used as fast approximations to either the RBFKernel or the MaternKernel. The corresponding replacement modules are RBFLattice and MaternLattice.

Switching to the RBFLattice kernel requires changing a single line of code:

import gpytorch as gp
from gpytorch_lattice_kernel import RBFLattice

class SimplexGPModel(gp.models.ExactGP):
  def __init__(self, train_x, train_y):
    likelihood = gp.likelihoods.GaussianLikelihood()
    super().__init__(train_x, train_y, likelihood)

    self.mean_module = gp.means.ConstantMean()
    self.covar_module = gp.kernels.ScaleKernel(
      # Previously: gp.kernels.RBFKernel(ard_num_dims=train_x.size(-1))
      RBFLattice(ard_num_dims=train_x.size(-1), order=1)
    )

  def forward(self, x):
    mean_x = self.mean_module(x)
    covar_x = self.covar_module(x)
    return gp.distributions.MultivariateNormal(mean_x, covar_x)

The GPyTorch Regression Tutorial provides a simpler example on toy data, where this kernel can be used as a drop-in replacement.

Install

To use the kernel in your code, install the package as:

pip install gpytorch-lattice-kernel

NOTE: The kernel is compiled lazily from source using CMake. If the compilation fails, you may need to install a more recent version of CMake. Additionally, ninja is required for compilation. One way to install both is:

conda install -c conda-forge cmake ninja
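Because the extension is only built on first use, it can help to confirm beforehand that CMake and ninja are visible on PATH. The sketch below uses only the standard library; the helper name `check_build_tools` is hypothetical and not part of the package, and it checks presence only, not whether the CMake version is recent enough.

```python
import shutil

def check_build_tools(tools=("cmake", "ninja")):
  """Map each build tool name to its resolved path on PATH, or None if it is missing."""
  return {tool: shutil.which(tool) for tool in tools}

if __name__ == "__main__":
  missing = [name for name, path in check_build_tools().items() if path is None]
  if missing:
    print("Missing build tools:", ", ".join(missing))
  else:
    print("cmake and ninja found on PATH")
```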

Local Setup

For a local development setup, create the conda environment:

$ conda env create -f environment.yml

Remember to add the root of the project to PYTHONPATH if it is not already there:

$ export PYTHONPATH="$(pwd):${PYTHONPATH}"

Test

To verify that the code works as expected, a simple test script is provided that checks the training marginal likelihood achieved by both Simplex-GPs and Exact-GPs. Run it as:

python tests/train_snelson.py

The Snelson 1-D toy dataset is used. A copy is available in snelson.csv.

Results

The proposed kernel can be used with GPyTorch as usual. An example script to reproduce the results is:

python experiments/train_simplexgp.py --dataset=elevators --data-dir=<path/to/uci/data/mat/files>

We use Fire to handle CLI arguments. All arguments of the main function are therefore valid arguments to the CLI.

All figures in the paper can be reproduced via notebooks.

NOTE: The UCI dataset mat files are available here.

License

Apache 2.0

Contributors

activatedgeek, mfinzi, keawang
