wesselb / stheno

Gaussian process modelling in Python

License: MIT License

stheno's Introduction

Stheno is an implementation of Gaussian process modelling in Python. See also Stheno.jl.

Check out our post about linear models with Stheno and JAX.

Nonlinear Regression in 20 Seconds

>>> import numpy as np

>>> from stheno import GP, EQ

>>> x = np.linspace(0, 2, 10)           # Some points to predict at

>>> y = x ** 2                          # Some observations

>>> f = GP(EQ())                        # Construct Gaussian process.

>>> f_post = f | (f(x), y)              # Compute the posterior.

>>> pred = f_post(np.array([1, 2, 3]))  # Predict!

>>> pred.mean
<dense matrix: shape=3x1, dtype=float64
 mat=[[1.   ]
      [4.   ]
      [8.483]]>

>>> pred.var
<dense matrix: shape=3x3, dtype=float64
 mat=[[ 8.032e-13  7.772e-16 -4.577e-09]
      [ 7.772e-16  9.999e-13  2.773e-10]
      [-4.577e-09  2.773e-10  3.313e-03]]>

These custom matrix types are there to accelerate the underlying linear algebra. To get vanilla NumPy/AutoGrad/TensorFlow/PyTorch/JAX arrays, use B.dense:

>>> from lab import B

>>> B.dense(pred.mean)
array([[1.00000068],
       [3.99999999],
       [8.4825932 ]])

>>> B.dense(pred.var)
array([[ 8.03246358e-13,  7.77156117e-16, -4.57690943e-09],
       [ 7.77156117e-16,  9.99866856e-13,  2.77333267e-10],
       [-4.57690943e-09,  2.77333267e-10,  3.31283378e-03]])

Moar?! Then read on!

Installation

pip install stheno

Manual

Note: here is a nicely rendered and more readable version of the docs.

AutoGrad, TensorFlow, PyTorch, or JAX? Your Choice!

from stheno.autograd import GP, EQ
from stheno.tensorflow import GP, EQ
from stheno.torch import GP, EQ
from stheno.jax import GP, EQ
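
These backend-specific modules are meant to be used with the corresponding framework's arrays. As a minimal sketch, mirroring the PyTorch example further below (illustrative only):

import torch

from stheno.torch import GP, EQ

f = GP(EQ())                  # Prior with an EQ kernel.
x = torch.linspace(0, 2, 50)  # Inputs as a PyTorch tensor.
y = f(x, 0.1).sample()        # Sample noisy observations; the result is a PyTorch tensor.
logpdf = f(x, 0.1).logpdf(y)  # Log-density of the observations, usable with PyTorch's autograd.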

Model Design

The basic building block is f = GP(mean, kernel, measure=prior), which takes in a mean, a kernel, and a measure; the mean defaults to zero. The mean and kernel of a GP can be extracted with f.mean and f.kernel. The measure should be thought of as a big joint distribution that assigns a mean and a kernel to every variable f. A measure can be created with prior = Measure(). A GP f can have different means and kernels under different measures: for example, under some prior measure, f can have an EQ() kernel, but under some posterior measure, f has a kernel that is determined by the posterior distribution of a GP. We will see later how posterior measures can be constructed. The measure with which f = GP(kernel, measure=prior) is constructed can be extracted with f.measure == prior. If the keyword argument measure is not set, then a new measure is automatically created, which can afterwards be extracted with f.measure.
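
For instance, a minimal sketch of constructing a GP and extracting these attributes (outputs are illustrative):

>>> prior = Measure()

>>> f = GP(EQ(), measure=prior)

>>> f.kernel
EQ()

>>> f.mean
0

>>> f.measure == prior
True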

Definition, where prior = Measure():

f = GP(kernel)

f = GP(mean, kernel)

f = GP(kernel, measure=prior)

f = GP(mean, kernel, measure=prior)

GPs that are associated to the same measure can be combined into new GPs, which is the primary mechanism used to build cool models.

Here's an example model:

>>> prior = Measure()

>>> f1 = GP(lambda x: x ** 2, EQ(), measure=prior)

>>> f1
GP(<lambda>, EQ())

>>> f2 = GP(Linear(), measure=prior)

>>> f2
GP(0, Linear())

>>> f_sum = f1 + f2

>>> f_sum
GP(<lambda>, EQ() + Linear())

>>> f_sum + GP(EQ())  # Not valid: `GP(EQ())` belongs to a new measure!
AssertionError: Processes GP(<lambda>, EQ() + Linear()) and GP(0, EQ()) are associated to different measures.

To avoid setting the keyword measure for every GP that you create, you can enter a measure as a context:

>>> with Measure() as prior:
        f1 = GP(lambda x: x ** 2, EQ())
        f2 = GP(Linear())
        f_sum = f1 + f2

>>> prior == f1.measure == f2.measure == f_sum.measure
True

Compositional Design

  • Add and subtract GPs and other objects.

    Example:

    >>> GP(EQ(), measure=prior) + GP(Exp(), measure=prior)
    GP(0, EQ() + Exp())
    
    >>> GP(EQ(), measure=prior) + GP(EQ(), measure=prior)
    GP(0, 2 * EQ())
    
    >>> GP(EQ()) + 1
    GP(1, EQ())
    
    >>> GP(EQ()) + 0
    GP(0, EQ())
    
    >>> GP(EQ()) + (lambda x: x ** 2)
    GP(<lambda>, EQ())
    
    >>> GP(2, EQ(), measure=prior) - GP(1, EQ(), measure=prior)
    GP(1, 2 * EQ())
  • Multiply GPs and other objects.

    Warning: The product of two GPs is not a Gaussian process. Stheno approximates the resulting process by moment matching.

    Example:

    >>> GP(1, EQ(), measure=prior) * GP(1, Exp(), measure=prior)
    GP(<lambda> + <lambda> + -1 * 1, <lambda> * Exp() + <lambda> * EQ() + EQ() * Exp())
    
    >>> 2 * GP(EQ())
    GP(2, 2 * EQ())
    
    >>> 0 * GP(EQ())
    GP(0, 0)
    
    >>> (lambda x: x) * GP(EQ())
    GP(0, <lambda> * EQ())
  • Shift GPs.

    Example:

    >>> GP(EQ()).shift(1)
    GP(0, EQ() shift 1) 
  • Stretch GPs.

    Example:

    >>> GP(EQ()).stretch(2)
    GP(0, EQ() > 2)
  • Select particular input dimensions.

    Example:

    >>> GP(EQ()).select(1, 3)
    GP(0, EQ() : [1, 3])
  • Transform the input.

    Example:

    >>> GP(EQ()).transform(f)
    GP(0, EQ() transform f)
  • Numerically take the derivative of a GP. The argument specifies which dimension to take the derivative with respect to.

    Example:

    >>> GP(EQ()).diff(1)
    GP(0, d(1) EQ())
  • Construct a finite difference estimate of the derivative of a GP. See Measure.diff_approx for a description of the arguments.

    Example:

    >>> GP(EQ()).diff_approx(deriv=1, order=2)
    GP(50000000.0 * (0.5 * EQ() + 0.5 * ((-0.5 * (EQ() shift (0.0001414213562373095, 0))) shift (0, -0.0001414213562373095)) + 0.5 * ((-0.5 * (EQ() shift (0, 0.0001414213562373095))) shift (-0.0001414213562373095, 0))), 0)
  • Construct the Cartesian product of a collection of GPs.

    Example:

    >>> prior = Measure()
    
    >>> f1, f2 = GP(EQ(), measure=prior), GP(EQ(), measure=prior)
    
    >>> cross(f1, f2)
    GP(MultiOutputMean(0, 0), MultiOutputKernel(EQ(), EQ()))

Displaying GPs

GPs have a display method that accepts a formatter.

Example:

>>> print(GP(2.12345 * EQ()).display(lambda x: f"{x:.2f}"))
GP(2.12 * EQ(), 0)

Properties of GPs

Properties of kernels can be queried on GPs directly.

Example:

>>> GP(EQ()).stationary
True

Naming GPs

It is possible to give a name to a GP. Names must be strings. A measure then behaves like a two-way dictionary between GPs and their names.

Example:

>>> prior = Measure()

>>> p = GP(EQ(), name="name", measure=prior)

>>> p.name
'name'

>>> p.name = "alternative_name"

>>> prior["alternative_name"]
GP(0, EQ())

>>> prior[p]
'alternative_name'

Finite-Dimensional Distributions

Simply call a GP to construct a finite-dimensional distribution at some inputs. You can give a second argument, which specifies the variance of additional additive noise. After constructing a finite-dimensional distribution, you can compute the mean, the variance, sample, or compute a logpdf.

Definition, where f is a GP:

f(x)         # No additional noise

f(x, noise)  # Additional noise with variance `noise`

Things you can do with a finite-dimensional distribution:

  • Use f(x).mean to compute the mean.

  • Use f(x).var to compute the variance.

  • Use f(x).mean_var to simultaneously compute the mean and variance. This can be substantially more efficient than first calling f(x).mean and then f(x).var.

  • Use Normal.sample to sample.

    Definition:

    f(x).sample()                # Produce one sample.
    
    f(x).sample(n)               # Produce `n` samples.
    
    f(x).sample(noise=noise)     # Produce one sample with additional noise variance `noise`.
    
    f(x).sample(n, noise=noise)  # Produce `n` samples with additional noise variance `noise`.
  • Use f(x).logpdf(y) to compute the logpdf of some data y.

  • Use means, variances = f(x).marginals() to efficiently compute the marginal means and marginal variances.

    Example:

    >>> f(x).marginals()
    (array([0., 0., 0.]), array([1., 1., 1.]))
  • Use means, lowers, uppers = f(x).marginal_credible_bounds() to efficiently compute the means and the marginal lower and upper 95% central credible region bounds.

    Example:

    >>> f(x).marginal_credible_bounds()
    (array([0., 0., 0.]), array([-1.96, -1.96, -1.96]), array([1.96, 1.96, 1.96]))
  • Use Measure.logpdf to compute the joint logpdf of multiple observations (a short sketch of this and Measure.sample follows after this list).

    Definition, where prior = Measure():

    prior.logpdf(f(x), y)
    
    prior.logpdf((f1(x1), y1), (f2(x2), y2), ...)
  • Use Measure.sample to jointly sample multiple observations.

    Definition, where prior = Measure():

    sample = prior.sample(f(x))
    
    sample1, sample2, ... = prior.sample(f1(x1), f2(x2), ...)
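
A minimal sketch of the two Measure methods above, jointly sampling two processes under the same measure and then computing their joint logpdf (illustrative):

prior = Measure()
f1 = GP(EQ(), measure=prior)
f2 = GP(EQ(), measure=prior)

x = np.array([0., 1., 2.])

# Jointly sample noisy observations of both processes.
y1, y2 = prior.sample(f1(x, 0.1), f2(x, 0.1))

# Compute the joint logpdf of both observations.
joint_logpdf = prior.logpdf((f1(x, 0.1), y1), (f2(x, 0.1), y2))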

Example:

>>> prior = Measure()

>>> f = GP(EQ(), measure=prior)

>>> x = np.array([0., 1., 2.])

>>> f(x)       # FDD without noise.
<FDD:
 process=GP(0, EQ()),
 input=array([0., 1., 2.]),
 noise=<zero matrix: shape=3x3, dtype=float64>>

>>> f(x, 0.1)  # FDD with noise.
<FDD:
 process=GP(0, EQ()),
 input=array([0., 1., 2.]),
 noise=<diagonal matrix: shape=3x3, dtype=float64
        diag=[0.1 0.1 0.1]>>

>>> f(x).mean
array([[0.],
       [0.],
       [0.]])

>>> f(x).var
<dense matrix: shape=3x3, dtype=float64
 mat=[[1.    0.607 0.135]
      [0.607 1.    0.607]
      [0.135 0.607 1.   ]]>
       
>>> y1 = f(x).sample()

>>> y1
array([[-0.45172746],
       [ 0.46581948],
       [ 0.78929767]])
       
>>> f(x).logpdf(y1)
-2.811609567720761

>>> y2 = f(x).sample(2)

>>> y2
array([[-0.43771276, -2.36741858],
       [ 0.86080043, -1.22503079],
       [ 2.15779126, -0.75319405]])

>>> f(x).logpdf(y2)
array([-4.82949038, -5.40084225])

Prior and Posterior Measures

Conditioning a prior measure on observations gives a posterior measure. To condition a measure on observations, use Measure.__or__.

Definition, where prior = Measure() and f* are GPs:

post = prior | (f(x, [noise]), y)

post = prior | ((f1(x1, [noise1]), y1), (f2(x2, [noise2]), y2), ...)

You can then obtain a posterior process with post(f) and a finite-dimensional distribution under the posterior with post(f(x)). Alternatively, the posterior of a process f can be obtained by conditioning f directly.

Definition, where f* are GPs:

f_post = f | (f(x, [noise]), y)

f_post = f | ((f1(x1, [noise1]), y1), (f2(x2, [noise2]), y2), ...)

Let's consider an example. First, build a model and sample some values.

>>> prior = Measure()

>>> f = GP(EQ(), measure=prior)

>>> x = np.array([0., 1., 2.])

>>> y = f(x).sample()

Then compute the posterior measure.

>>> post = prior | (f(x), y)

>>> post(f)
GP(PosteriorMean(), PosteriorKernel())

>>> post(f).mean(x)
<dense matrix: shape=3x1, dtype=float64
 mat=[[ 0.412]
      [-0.811]
      [-0.933]]>

>>> post(f).kernel(x)
<dense matrix: shape=3x3, dtype=float64
 mat=[[1.e-12 0.e+00 0.e+00]
      [0.e+00 1.e-12 0.e+00]
      [0.e+00 0.e+00 1.e-12]]>

>>> post(f(x))
<FDD:
 process=GP(PosteriorMean(), PosteriorKernel()),
 input=array([0., 1., 2.]),
 noise=<zero matrix: shape=3x3, dtype=float64>>

>>> post(f(x)).mean
<dense matrix: shape=3x1, dtype=float64
 mat=[[ 0.412]
      [-0.811]
      [-0.933]]>

>>> post(f(x)).var
<dense matrix: shape=3x3, dtype=float64
 mat=[[1.e-12 0.e+00 0.e+00]
      [0.e+00 1.e-12 0.e+00]
      [0.e+00 0.e+00 1.e-12]]>

We can also obtain the posterior by conditioning f directly:

>>> f_post = f | (f(x), y)

>>> f_post
GP(PosteriorMean(), PosteriorKernel())

>>> f_post.mean(x)
<dense matrix: shape=3x1, dtype=float64
 mat=[[ 0.412]
      [-0.811]
      [-0.933]]>

>>> f_post.kernel(x)
<dense matrix: shape=3x3, dtype=float64
 mat=[[1.e-12 0.e+00 0.e+00]
      [0.e+00 1.e-12 0.e+00]
      [0.e+00 0.e+00 1.e-12]]>

>>> f_post(x)
<FDD:
 process=GP(PosteriorMean(), PosteriorKernel()),
 input=array([0., 1., 2.]),
 noise=<zero matrix: shape=3x3, dtype=float64>>

>>> f_post(x).mean
<dense matrix: shape=3x1, dtype=float64
 mat=[[ 0.412]
      [-0.811]
      [-0.933]]>

>>> f_post(x).var
<dense matrix: shape=3x3, dtype=float64
 mat=[[1.e-12 0.e+00 0.e+00]
      [0.e+00 1.e-12 0.e+00]
      [0.e+00 0.e+00 1.e-12]]>

We can further extend our model by building on the posterior.

>>> g = GP(Linear(), measure=post)

>>> f_sum = post(f) + g

>>> f_sum
GP(PosteriorMean(), PosteriorKernel() + Linear())

However, what we cannot do is mix the prior and the posterior.

>>> f + g
AssertionError: Processes GP(0, EQ()) and GP(0, Linear()) are associated to different measures.

Inducing Points

Stheno supports pseudo-point approximations of posterior distributions with various approximation methods:

  1. The Variational Free Energy (VFE; Titsias, 2009) approximation. To use the VFE approximation, use PseudoObs.

  2. The Fully Independent Training Conditional (FITC; Snelson & Ghahramani, 2006) approximation. To use the FITC approximation, use PseudoObsFITC.

  3. The Deterministic Training Conditional (DTC; Csato & Opper, 2002; Seeger et al., 2003) approximation. To use the DTC approximation, use PseudoObsDTC.

The VFE approximation (PseudoObs) is the recommended approximation. The following definitions and examples will use the VFE approximation with PseudoObs, but every instance of PseudoObs can be swapped out for PseudoObsFITC or PseudoObsDTC, as sketched after the definitions below.

Definition:

obs = PseudoObs(
    u(z),               # FDD of inducing points
    (f(x, [noise]), y)  # Observed data
)
                
obs = PseudoObs(u(z), f(x, [noise]), y)

obs = PseudoObs(u(z), (f1(x1, [noise1]), y1), (f2(x2, [noise2]), y2), ...)

obs = PseudoObs((u1(z1), u2(z2), ...), f(x, [noise]), y)

obs = PseudoObs((u1(z1), u2(z2), ...), (f1(x1, [noise1]), y1), (f2(x2, [noise2]), y2), ...)
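
The FITC and DTC approximations use exactly the same call signatures; as a sketch, only the constructor changes:

obs_vfe  = PseudoObs(u(z), (f(x, noise), y))      # VFE (recommended)
obs_fitc = PseudoObsFITC(u(z), (f(x, noise), y))  # FITC
obs_dtc  = PseudoObsDTC(u(z), (f(x, noise), y))   # DTC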

The approximate posterior measure can be constructed with prior | obs where prior = Measure() is the measure of your model. To quantify the quality of the approximation, you can compute the ELBO with obs.elbo(prior).

Let's consider an example. First, build a model and sample some noisy observations.

>>> prior = Measure()

>>> f = GP(EQ(), measure=prior)

>>> x_obs = np.linspace(0, 10, 2000)

>>> y_obs = f(x_obs, 1).sample()

Ouch, computing the logpdf is quite slow:

>>> %timeit f(x_obs, 1).logpdf(y_obs)
219 ms ± 35.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Let's try to use inducing points to speed this up.

>>> x_ind = np.linspace(0, 10, 100)

>>> u = f(x_ind)   # FDD of inducing points.

>>> %timeit PseudoObs(u, f(x_obs, 1), y_obs).elbo(prior)
9.8 ms ± 181 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Much better. And the approximation is good:

>>> PseudoObs(u, f(x_obs, 1), y_obs).elbo(prior) - f(x_obs, 1).logpdf(y_obs)
-3.537934389896691e-10

We finally construct the approximate posterior measure:

>>> post_approx = prior | PseudoObs(u, f(x_obs, 1), y_obs)

>>> post_approx(f(x_obs)).mean
<dense matrix: shape=2000x1, dtype=float64
 mat=[[0.469]
      [0.468]
      [0.467]
      ...
      [1.09 ]
      [1.09 ]
      [1.091]]>

Kernels and Means

See MLKernels.
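
Kernels and means from MLKernels compose with the operations used throughout this README and can be evaluated directly. A minimal sketch:

import numpy as np

from stheno import EQ, Linear

# Compose a kernel: scale, stretch, and sum, as in the examples above.
k = 2 * EQ().stretch(2) + Linear()

# Evaluating the kernel at some inputs gives a (structured) covariance matrix.
K = k(np.linspace(0, 2, 10))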

Batched Computation

Stheno supports batched computation. See MLKernels for a description of how means and kernels work with batched computation.

Example:

>>> f = GP(EQ())

>>> x = np.random.randn(16, 100, 1)

>>> y = f(x, 1).sample()

>>> logpdf = f(x, 1).logpdf(y)

>>> y.shape
(16, 100, 1)

>>> f(x, 1).logpdf(y).shape
(16,)

Important Remarks

Stheno uses LAB to provide an implementation that is backend agnostic. Moreover, Stheno uses an extension of LAB to accelerate linear algebra with structured linear algebra primitives. You will encounter these primitives:

>>> k = 2 * Delta()

>>> x = np.linspace(0, 5, 10)

>>> k(x)
<diagonal matrix: shape=10x10, dtype=float64
 diag=[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]>

If you're using LAB to further process these matrices, then there is absolutely no need to worry: these structured matrix types know how to add, multiply, and do other linear algebra operations.

>>> import lab as B

>>> B.matmul(k(x), k(x))
<diagonal matrix: shape=10x10, dtype=float64
 diag=[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]>

If you're not using LAB, you can convert these structured primitives to regular NumPy/TensorFlow/PyTorch/JAX arrays by calling B.dense (B is from LAB):

>>> import lab as B

>>> B.dense(k(x))
array([[2., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 2., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 2., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 2., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 2., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 2., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 2., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 2.]])

Furthermore, before computing a Cholesky decomposition, Stheno always adds a minuscule diagonal to prevent the decomposition from failing when the matrix is not quite positive definite due to numerical noise. You can change the magnitude of this diagonal by changing B.epsilon:

>>> import lab as B

>>> B.epsilon = 1e-12   # Default regularisation

>>> B.epsilon = 1e-8    # Strong regularisation

Examples

The examples make use of Varz and some utilities from WBML.

Simple Regression

Prediction

import matplotlib.pyplot as plt
from wbml.plot import tweak

from stheno import B, GP, EQ

# Define points to predict at.
x = B.linspace(0, 10, 100)
x_obs = B.linspace(0, 7, 20)

# Construct a prior.
f = GP(EQ().periodic(5.0))

# Sample a true, underlying function and noisy observations.
f_true, y_obs = f.measure.sample(f(x), f(x_obs, 0.5))

# Now condition on the observations to make predictions.
f_post = f | (f(x_obs, 0.5), y_obs)
mean, lower, upper = f_post(x).marginal_credible_bounds()

# Plot result.
plt.plot(x, f_true, label="True", style="test")
plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()
plt.savefig("readme_example1_simple_regression.png")
plt.show()

Hyperparameter Optimisation with Varz

Prediction

import lab as B
import matplotlib.pyplot as plt
import torch
from varz import Vars, minimise_l_bfgs_b, parametrised, Positive
from wbml.plot import tweak

from stheno.torch import EQ, GP

# Increase regularisation because PyTorch defaults to 32-bit floats.
B.epsilon = 1e-6

# Define points to predict at.
x = torch.linspace(0, 2, 100)
x_obs = torch.linspace(0, 2, 50)

# Sample a true, underlying function and observations with observation noise `0.05`.
f_true = torch.sin(5 * x)
y_obs = torch.sin(5 * x_obs) + 0.05**0.5 * torch.randn(50)


def model(vs):
    """Construct a model with learnable parameters."""
    p = vs.struct  # Varz handles positivity (and other) constraints.
    kernel = p.variance.positive() * EQ().stretch(p.scale.positive())
    return GP(kernel), p.noise.positive()


@parametrised
def model_alternative(vs, scale: Positive, variance: Positive, noise: Positive):
    """Equivalent to :func:`model`, but with `@parametrised`."""
    kernel = variance * EQ().stretch(scale)
    return GP(kernel), noise


vs = Vars(torch.float32)
f, noise = model(vs)

# Condition on observations and make predictions before optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_before = f, noise
pred_before = f_post(x, noise).marginal_credible_bounds()


def objective(vs):
    f, noise = model(vs)
    evidence = f(x_obs, noise).logpdf(y_obs)
    return -evidence


# Learn hyperparameters.
minimise_l_bfgs_b(objective, vs)

f, noise = model(vs)

# Condition on observations and make predictions after optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_after = f, noise
pred_after = f_post(x, noise).marginal_credible_bounds()


def plot_prediction(prior, pred):
    f, noise = prior
    mean, lower, upper = pred
    plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    plt.plot(x, f_true, label="True", style="test")
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    plt.ylim(-2, 2)
    plt.text(
        0.02,
        0.02,
        f"var = {f.kernel.factor(0):.2f}, "
        f"scale = {f.kernel.factor(1).stretches[0]:.2f}, "
        f"noise = {noise:.2f}",
        transform=plt.gca().transAxes,
    )
    tweak()


# Plot result.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title("Before optimisation")
plot_prediction(prior_before, pred_before)
plt.subplot(1, 2, 2)
plt.title("After optimisation")
plot_prediction(prior_after, pred_after)
plt.savefig("readme_example12_optimisation_varz.png")
plt.show()

Hyperparameter Optimisation with PyTorch

Prediction

import lab as B
import matplotlib.pyplot as plt
import torch
from wbml.plot import tweak

from stheno.torch import EQ, GP

# Increase regularisation because PyTorch defaults to 32-bit floats.
B.epsilon = 1e-6

# Define points to predict at.
x = torch.linspace(0, 2, 100)
x_obs = torch.linspace(0, 2, 50)

# Sample a true, underlying function and observations with observation noise `0.05`.
f_true = torch.sin(5 * x)
y_obs = torch.sin(5 * x_obs) + 0.05**0.5 * torch.randn(50)


class Model(torch.nn.Module):
    """A GP model with learnable parameters."""

    def __init__(self, init_var=0.3, init_scale=1, init_noise=0.2):
        super().__init__()
        # Ensure that the parameters are positive and make them learnable.
        self.log_var = torch.nn.Parameter(torch.log(torch.tensor(init_var)))
        self.log_scale = torch.nn.Parameter(torch.log(torch.tensor(init_scale)))
        self.log_noise = torch.nn.Parameter(torch.log(torch.tensor(init_noise)))

    def construct(self):
        self.var = torch.exp(self.log_var)
        self.scale = torch.exp(self.log_scale)
        self.noise = torch.exp(self.log_noise)
        kernel = self.var * EQ().stretch(self.scale)
        return GP(kernel), self.noise


model = Model()
f, noise = model.construct()

# Condition on observations and make predictions before optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_before = f, noise
pred_before = f_post(x, noise).marginal_credible_bounds()

# Perform optimisation.
opt = torch.optim.Adam(model.parameters(), lr=5e-2)
for _ in range(1000):
    opt.zero_grad()
    f, noise = model.construct()
    loss = -f(x_obs, noise).logpdf(y_obs)
    loss.backward()
    opt.step()

f, noise = model.construct()

# Condition on observations and make predictions after optimisation.
f_post = f | (f(x_obs, noise), y_obs)
prior_after = f, noise
pred_after = f_post(x, noise).marginal_credible_bounds()


def plot_prediction(prior, pred):
    f, noise = prior
    mean, lower, upper = pred
    plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    plt.plot(x, f_true, label="True", style="test")
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    plt.ylim(-2, 2)
    plt.text(
        0.02,
        0.02,
        f"var = {f.kernel.factor(0):.2f}, "
        f"scale = {f.kernel.factor(1).stretches[0]:.2f}, "
        f"noise = {noise:.2f}",
        transform=plt.gca().transAxes,
    )
    tweak()


# Plot result.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title("Before optimisation")
plot_prediction(prior_before, pred_before)
plt.subplot(1, 2, 2)
plt.title("After optimisation")
plot_prediction(prior_after, pred_after)
plt.savefig("readme_example13_optimisation_torch.png")
plt.show()

Decomposition of Prediction

Prediction

import matplotlib.pyplot as plt
from wbml.plot import tweak

from stheno import Measure, GP, EQ, RQ, Linear, Delta, Exp, B

B.epsilon = 1e-10

# Define points to predict at.
x = B.linspace(0, 10, 200)
x_obs = B.linspace(0, 7, 50)


with Measure() as prior:
    # Construct a latent function consisting of four different components.
    f_smooth = GP(EQ())
    f_wiggly = GP(RQ(1e-1).stretch(0.5))
    f_periodic = GP(EQ().periodic(1.0))
    f_linear = GP(Linear())
    f = f_smooth + f_wiggly + f_periodic + 0.2 * f_linear

    # Let the observation noise consist of a bit of exponential noise.
    e_indep = GP(Delta())
    e_exp = GP(Exp())
    e = e_indep + 0.3 * e_exp

    # Sum the latent function and observation noise to get a model for the observations.
    y = f + 0.5 * e

# Sample a true, underlying function and observations.
(
    f_true_smooth,
    f_true_wiggly,
    f_true_periodic,
    f_true_linear,
    f_true,
    y_obs,
) = prior.sample(f_smooth(x), f_wiggly(x), f_periodic(x), f_linear(x), f(x), y(x_obs))

# Now condition on the observations and make predictions for the latent function and
# its various components.
post = prior | (y(x_obs), y_obs)

pred_smooth = post(f_smooth(x))
pred_wiggly = post(f_wiggly(x))
pred_periodic = post(f_periodic(x))
pred_linear = post(f_linear(x))
pred_f = post(f(x))


# Plot results.
def plot_prediction(x, f, pred, x_obs=None, y_obs=None):
    plt.plot(x, f, label="True", style="test")
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = pred.marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()


plt.figure(figsize=(10, 6))

plt.subplot(3, 1, 1)
plt.title("Prediction")
plot_prediction(x, f_true, pred_f, x_obs, y_obs)

plt.subplot(3, 2, 3)
plt.title("Smooth Component")
plot_prediction(x, f_true_smooth, pred_smooth)

plt.subplot(3, 2, 4)
plt.title("Wiggly Component")
plot_prediction(x, f_true_wiggly, pred_wiggly)

plt.subplot(3, 2, 5)
plt.title("Periodic Component")
plot_prediction(x, f_true_periodic, pred_periodic)

plt.subplot(3, 2, 6)
plt.title("Linear Component")
plot_prediction(x, f_true_linear, pred_linear)

plt.savefig("readme_example2_decomposition.png")
plt.show()

Learn a Function, Incorporating Prior Knowledge About Its Form

Prediction

import matplotlib.pyplot as plt
import tensorflow as tf
import wbml.out as out
from varz.spec import parametrised, Positive
from varz.tensorflow import Vars, minimise_l_bfgs_b
from wbml.plot import tweak

from stheno.tensorflow import B, Measure, GP, EQ, Delta

# Define points to predict at.
x = B.linspace(tf.float64, 0, 5, 100)
x_obs = B.linspace(tf.float64, 0, 3, 20)


@parametrised
def model(
    vs,
    u_var: Positive = 0.5,
    u_scale: Positive = 0.5,
    noise: Positive = 0.5,
    alpha: Positive = 1.2,
):
    with Measure():
        # Random fluctuation:
        u = GP(u_var * EQ().stretch(u_scale))
        # Construct model.
        f = u + (lambda x: x**alpha)
    return f, noise


# Sample a true, underlying function and observations.
vs = Vars(tf.float64)
f_true = x**1.8 + B.sin(2 * B.pi * x)
f, y = model(vs)
post = f.measure | (f(x), f_true)
y_obs = post(f(x_obs)).sample()


def objective(vs):
    f, noise = model(vs)
    evidence = f(x_obs, noise).logpdf(y_obs)
    return -evidence


# Learn hyperparameters.
minimise_l_bfgs_b(objective, vs, jit=True)
f, noise = model(vs)

# Print the learned parameters.
out.kv("Prior", f.display(out.format))
vs.print()

# Condition on the observations to make predictions.
f_post = f | (f(x_obs, noise), y_obs)
mean, lower, upper = f_post(x).marginal_credible_bounds()

# Plot result.
plt.plot(x, B.squeeze(f_true), label="True", style="test")
plt.scatter(x_obs, B.squeeze(y_obs), label="Observations", style="train", s=20)
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig("readme_example3_parametric.png")
plt.show()

Multi-Output Regression

Prediction

import matplotlib.pyplot as plt
from wbml.plot import tweak

from stheno import B, Measure, GP, EQ, Delta


class VGP:
    """A vector-valued GP."""

    def __init__(self, ps):
        self.ps = ps

    def __add__(self, other):
        return VGP([f + g for f, g in zip(self.ps, other.ps)])

    def lmatmul(self, A):
        m, n = A.shape
        ps = [0 for _ in range(m)]
        for i in range(m):
            for j in range(n):
                ps[i] += A[i, j] * self.ps[j]
        return VGP(ps)


# Define points to predict at.
x = B.linspace(0, 10, 100)
x_obs = B.linspace(0, 10, 10)

# Model parameters:
m = 2
p = 4
H = B.randn(p, m)


with Measure() as prior:
    # Construct latent functions.
    us = VGP([GP(EQ()) for _ in range(m)])

    # Construct multi-output prior.
    fs = us.lmatmul(H)

    # Construct noise.
    e = VGP([GP(0.5 * Delta()) for _ in range(p)])

    # Construct observation model.
    ys = e + fs

# Sample a true, underlying function and observations.
samples = prior.sample(*(p(x) for p in fs.ps), *(p(x_obs) for p in ys.ps))
fs_true, ys_obs = samples[:p], samples[p:]

# Compute the posterior and make predictions.
post = prior.condition(*((p(x_obs), y_obs) for p, y_obs in zip(ys.ps, ys_obs)))
preds = [post(p(x)) for p in fs.ps]


# Plot results.
def plot_prediction(x, f, pred, x_obs=None, y_obs=None):
    plt.plot(x, f, label="True", style="test")
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = pred.marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()


plt.figure(figsize=(10, 6))
for i in range(4):
    plt.subplot(2, 2, i + 1)
    plt.title(f"Output {i + 1}")
    plot_prediction(x, fs_true[i], preds[i], x_obs, ys_obs[i])
plt.savefig("readme_example4_multi-output.png")
plt.show()

Approximate Integration

Prediction

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import wbml.plot

from stheno.tensorflow import B, Measure, GP, EQ, Delta

# Define points to predict at.
x = B.linspace(tf.float64, 0, 10, 200)
x_obs = B.linspace(tf.float64, 0, 10, 10)

with Measure() as prior:
    # Construct a model.
    f = 0.7 * GP(EQ()).stretch(1.5)
    e = 0.2 * GP(Delta())

    # Construct derivatives.
    df = f.diff()
    ddf = df.diff()
    dddf = ddf.diff() + e

# Fix the integration constants.
zero = B.cast(tf.float64, 0)
one = B.cast(tf.float64, 1)
prior = prior | ((f(zero), one), (df(zero), zero), (ddf(zero), -one))

# Sample observations.
y_obs = B.sin(x_obs) + 0.2 * B.randn(*x_obs.shape)

# Condition on the observations to make predictions.
post = prior | (dddf(x_obs), y_obs)

# And make predictions.
pred_iiif = post(f)(x)
pred_iif = post(df)(x)
pred_if = post(ddf)(x)
pred_f = post(dddf)(x)


# Plot result.
def plot_prediction(x, f, pred, x_obs=None, y_obs=None):
    plt.plot(x, f, label="True", style="test")
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = pred.marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    wbml.plot.tweak()


plt.figure(figsize=(10, 6))

plt.subplot(2, 2, 1)
plt.title("Function")
plot_prediction(x, np.sin(x), pred_f, x_obs=x_obs, y_obs=y_obs)

plt.subplot(2, 2, 2)
plt.title("Integral of Function")
plot_prediction(x, -np.cos(x), pred_if)

plt.subplot(2, 2, 3)
plt.title("Second Integral of Function")
plot_prediction(x, -np.sin(x), pred_iif)

plt.subplot(2, 2, 4)
plt.title("Third Integral of Function")
plot_prediction(x, np.cos(x), pred_iiif)

plt.savefig("readme_example5_integration.png")
plt.show()

Bayesian Linear Regression

Prediction

import matplotlib.pyplot as plt
import wbml.out as out
from wbml.plot import tweak

from stheno import B, Measure, GP

B.epsilon = 1e-10  # Very slightly regularise.

# Define points to predict at.
x = B.linspace(0, 10, 200)
x_obs = B.linspace(0, 10, 10)

with Measure() as prior:
    # Construct a linear model.
    slope = GP(1)
    intercept = GP(5)
    f = slope * (lambda x: x) + intercept

# Sample a slope, intercept, underlying function, and observations.
true_slope, true_intercept, f_true, y_obs = prior.sample(
    slope(0), intercept(0), f(x), f(x_obs, 0.2)
)

# Condition on the observations to make predictions.
post = prior | (f(x_obs, 0.2), y_obs)
mean, lower, upper = post(f(x)).marginal_credible_bounds()

out.kv("True slope", true_slope[0, 0])
out.kv("Predicted slope", post(slope(0)).mean[0, 0])
out.kv("True intercept", true_intercept[0, 0])
out.kv("Predicted intercept", post(intercept(0)).mean[0, 0])

# Plot result.
plt.plot(x, f_true, label="True", style="test")
plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig("readme_example6_blr.png")
plt.show()

GPAR

Prediction

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from varz.spec import parametrised, Positive
from varz.tensorflow import Vars, minimise_l_bfgs_b
from wbml.plot import tweak

from stheno.tensorflow import B, GP, EQ

# Define points to predict at.
x = B.linspace(tf.float64, 0, 10, 200)
x_obs1 = B.linspace(tf.float64, 0, 10, 30)
inds2 = np.random.permutation(len(x_obs1))[:10]
x_obs2 = B.take(x_obs1, inds2)

# Construct functions to predict and observations.
f1_true = B.sin(x)
f2_true = B.sin(x) ** 2

y1_obs = B.sin(x_obs1) + 0.1 * B.randn(*x_obs1.shape)
y2_obs = B.sin(x_obs2) ** 2 + 0.1 * B.randn(*x_obs2.shape)


@parametrised
def model(
    vs,
    var1: Positive = 1,
    scale1: Positive = 1,
    noise1: Positive = 0.1,
    var2: Positive = 1,
    scale2: Positive = 1,
    noise2: Positive = 0.1,
):
    # Build layers:
    f1 = GP(var1 * EQ().stretch(scale1))
    f2 = GP(var2 * EQ().stretch(scale2))
    return (f1, noise1), (f2, noise2)


def objective(vs):
    (f1, noise1), (f2, noise2) = model(vs)
    x1 = x_obs1
    x2 = B.stack(x_obs2, B.take(y1_obs, inds2), axis=1)
    evidence = f1(x1, noise1).logpdf(y1_obs) + f2(x2, noise2).logpdf(y2_obs)
    return -evidence


# Learn hyperparameters.
vs = Vars(tf.float64)
minimise_l_bfgs_b(objective, vs)

# Compute posteriors.
(f1, noise1), (f2, noise2) = model(vs)
x1 = x_obs1
x2 = B.stack(x_obs2, B.take(y1_obs, inds2), axis=1)
f1_post = f1 | (f1(x1, noise1), y1_obs)
f2_post = f2 | (f2(x2, noise2), y2_obs)

# Predict first output.
mean1, lower1, upper1 = f1_post(x).marginal_credible_bounds()

# Predict second output with Monte Carlo.
samples = [
    f2_post(B.stack(x, f1_post(x).sample()[:, 0], axis=1)).sample()[:, 0]
    for _ in range(100)
]
mean2 = np.mean(samples, axis=0)
lower2 = np.percentile(samples, 2.5, axis=0)
upper2 = np.percentile(samples, 100 - 2.5, axis=0)

# Plot result.
plt.figure()

plt.subplot(2, 1, 1)
plt.title("Output 1")
plt.plot(x, f1_true, label="True", style="test")
plt.scatter(x_obs1, y1_obs, label="Observations", style="train", s=20)
plt.plot(x, mean1, label="Prediction", style="pred")
plt.fill_between(x, lower1, upper1, style="pred")
tweak()

plt.subplot(2, 1, 2)
plt.title("Output 2")
plt.plot(x, f2_true, label="True", style="test")
plt.scatter(x_obs2, y2_obs, label="Observations", style="train", s=20)
plt.plot(x, mean2, label="Prediction", style="pred")
plt.fill_between(x, lower2, upper2, style="pred")
tweak()

plt.savefig("readme_example7_gpar.png")
plt.show()

A GP-RNN Model

Prediction

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from varz.spec import parametrised, Positive
from varz.tensorflow import Vars, minimise_adam
from wbml.net import rnn as rnn_constructor
from wbml.plot import tweak

from stheno.tensorflow import B, Measure, GP, EQ

# Increase regularisation because we are dealing with `tf.float32`s.
B.epsilon = 1e-6

# Construct points to predict at.
x = B.linspace(tf.float32, 0, 1, 100)[:, None]
inds_obs = B.range(0, int(0.75 * len(x)))  # Train on the first 75% only.
x_obs = B.take(x, inds_obs)

# Construct function and observations.
#   Draw random modulation functions.
a_true = GP(1e-2 * EQ().stretch(0.1))(x).sample()
b_true = GP(1e-2 * EQ().stretch(0.1))(x).sample()
#   Construct the true, underlying function.
f_true = (1 + a_true) * B.sin(2 * np.pi * 7 * x) + b_true
#   Add noise.
y_true = f_true + 0.1 * B.randn(*f_true.shape)

# Normalise and split.
f_true = (f_true - B.mean(y_true)) / B.std(y_true)
y_true = (y_true - B.mean(y_true)) / B.std(y_true)
y_obs = B.take(y_true, inds_obs)


@parametrised
def model(vs, a_scale: Positive = 0.1, b_scale: Positive = 0.1, noise: Positive = 0.01):
    # Construct an RNN.
    f_rnn = rnn_constructor(
        output_size=1, widths=(10,), nonlinearity=B.tanh, final_dense=True
    )

    # Set the weights for the RNN.
    num_weights = f_rnn.num_weights(input_size=1)
    weights = Vars(tf.float32, source=vs.get(shape=(num_weights,), name="rnn"))
    f_rnn.initialise(input_size=1, vs=weights)

    with Measure():
        # Construct GPs that modulate the RNN.
        a = GP(1e-2 * EQ().stretch(a_scale))
        b = GP(1e-2 * EQ().stretch(b_scale))

        # GP-RNN model:
        f_gp_rnn = (1 + a) * (lambda x: f_rnn(x)) + b

    return f_rnn, f_gp_rnn, noise, a, b


def objective_rnn(vs):
    f_rnn, _, _, _, _ = model(vs)
    return B.mean((f_rnn(x_obs) - y_obs) ** 2)


def objective_gp_rnn(vs):
    _, f_gp_rnn, noise, _, _ = model(vs)
    evidence = f_gp_rnn(x_obs, noise).logpdf(y_obs)
    return -evidence


# Pretrain the RNN.
vs = Vars(tf.float32)
minimise_adam(objective_rnn, vs, rate=5e-3, iters=1000, trace=True, jit=True)

# Jointly train the RNN and GPs.
minimise_adam(objective_gp_rnn, vs, rate=1e-3, iters=1000, trace=True, jit=True)

_, f_gp_rnn, noise, a, b = model(vs)

# Condition.
post = f_gp_rnn.measure | (f_gp_rnn(x_obs, noise), y_obs)

# Predict and plot results.
plt.figure(figsize=(10, 6))

plt.subplot(2, 1, 1)
plt.title("$(1 + a)\\cdot {}$RNN${} + b$")
plt.plot(x, f_true, label="True", style="test")
plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
mean, lower, upper = post(f_gp_rnn(x)).marginal_credible_bounds()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.subplot(2, 2, 3)
plt.title("$a$")
mean, lower, upper = post(a(x)).marginal_credible_bounds()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.subplot(2, 2, 4)
plt.title("$b$")
mean, lower, upper = post(b(x)).marginal_credible_bounds()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig(f"readme_example8_gp-rnn.png")
plt.show()

Approximate Multiplication Between GPs

Prediction

import matplotlib.pyplot as plt
from wbml.plot import tweak

from stheno import B, Measure, GP, EQ

# Define points to predict at.
x = B.linspace(0, 10, 100)

with Measure() as prior:
    f1 = GP(3, EQ())
    f2 = GP(3, EQ())

    # Compute the approximate product.
    f_prod = f1 * f2

# Sample two functions.
s1, s2 = prior.sample(f1(x), f2(x))

# Predict.
f_prod_post = f_prod | ((f1(x), s1), (f2(x), s2))
mean, lower, upper = f_prod_post(x).marginal_credible_bounds()

# Plot result.
plt.plot(x, s1, label="Sample 1", style="train")
plt.plot(x, s2, label="Sample 2", style="train", ls="--")
plt.plot(x, s1 * s2, label="True product", style="test")
plt.plot(x, mean, label="Approximate posterior", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig("readme_example9_product.png")
plt.show()

Sparse Regression

Prediction

import matplotlib.pyplot as plt
import wbml.out as out
from wbml.plot import tweak

from stheno import B, GP, EQ, PseudoObs

# Define points to predict at.
x = B.linspace(0, 10, 100)
x_obs = B.linspace(0, 7, 50_000)
x_ind = B.linspace(0, 10, 20)

# Construct a prior.
f = GP(EQ().periodic(2 * B.pi))

# Sample a true, underlying function and observations.
f_true = B.sin(x)
y_obs = B.sin(x_obs) + B.sqrt(0.5) * B.randn(*x_obs.shape)

# Compute a pseudo-point approximation of the posterior.
obs = PseudoObs(f(x_ind), (f(x_obs, 0.5), y_obs))

# Compute the ELBO.
out.kv("ELBO", obs.elbo(f.measure))

# Compute the approximate posterior.
f_post = f | obs

# Make predictions with the approximate posterior.
mean, lower, upper = f_post(x).marginal_credible_bounds()

# Plot result.
plt.plot(x, f_true, label="True", style="test")
plt.scatter(
    x_obs,
    y_obs,
    label="Observations",
    style="train",
    c="tab:green",
    alpha=0.35,
)
plt.scatter(
    x_ind,
    obs.mu(f.measure)[:, 0],
    label="Inducing Points",
    style="train",
    s=20,
)
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig("readme_example10_sparse.png")
plt.show()

Smoothing with Nonparametric Basis Functions

Prediction

import matplotlib.pyplot as plt
from wbml.plot import tweak

from stheno import B, Measure, GP, EQ

# Define points to predict at.
x = B.linspace(0, 10, 100)
x_obs = B.linspace(0, 10, 20)

with Measure() as prior:
    w = lambda x: B.exp(-(x**2) / 0.5)  # Basis function
    b = [(w * GP(EQ())).shift(xi) for xi in x_obs]  # Weighted basis functions
    f = sum(b)

# Sample a true, underlying function and observations.
f_true, y_obs = prior.sample(f(x), f(x_obs, 0.2))

# Condition on the observations to make predictions.
post = prior | (f(x_obs, 0.2), y_obs)

# Plot result.
for i, bi in enumerate(b):
    mean, lower, upper = post(bi(x)).marginal_credible_bounds()
    kw_args = {"label": "Basis functions"} if i == 0 else {}
    plt.plot(x, mean, style="pred2", **kw_args)
plt.plot(x, f_true, label="True", style="test")
plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
mean, lower, upper = post(f(x)).marginal_credible_bounds()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig("readme_example11_nonparametric_basis.png")
plt.show()

stheno's People

Contributors

patel-zeel, wesselb


stheno's Issues

CI fails without any edit

Description of the bug

I have just forked Stheno and run CI. It failed without my having made any changes.

Documentation about Multi-Output Regression

Hi @wesselb,

I am trying to use your example of Multi-Output Regression with some data I have. I don't understand how to correctly give the data to the VGP and then make a prediction.
My input data x_obs are not all the same, so it's not exactly like the example. I have nine x observations [x1,x2,x3,x4,x5,x6,x7,x8,x9] with their y observations [y1,y2,y3,y4,y5,y6,y7,y8,y9].
Also, with the example provided, is it possible to optimize some hyperparameters if the VGP had some?

Here is the code I was trying to use, with 3 different outputs to simulate data. Thank you in advance for your help.

import matplotlib.pyplot as plt
import numpy as np
from wbml.plot import tweak
from stheno import B, Measure, GP, EQ, Delta, Matern52

class VGP:
    """A vector-valued GP."""

    def __init__(self, ps):
        self.ps = ps

    def __add__(self, other):
        return VGP([f + g for f, g in zip(self.ps, other.ps)])

    def lmatmul(self, A):
        m, n = A.shape
        ps = [0 for _ in range(m)]
        for i in range(m):
            for j in range(n):
                ps[i] += A[i, j] * self.ps[j]
        return VGP(ps)

# Define points to predict at.
x = B.linspace(0, 10, 5)

# Create some sample data.
x1 = np.atleast_2d(np.linspace(0, 10, 5)).T
x2 = np.atleast_2d(np.linspace(0, 9, 5)).T
x3 = np.atleast_2d(np.linspace(0, 7, 5)).T
y1 = np.atleast_2d(np.linspace(0, 10, 5)).T
y2 = np.atleast_2d(np.linspace(0, 10, 5)).T
y3 = np.atleast_2d(np.linspace(0, 10, 5)).T

x_obs = [x1,x2,x3]
y_obs = [y1,y2,y3]

# Model parameters:
m = 3
p = 3
H = B.randn(p, m)

with Measure() as prior:
    # Construct latent functions.
    us = VGP([GP(Matern52()) for _ in range(m)])
    # Construct multi-output prior.
    fs = us.lmatmul(H)
    # Construct noise.
    e = VGP([GP(0 * Delta()) for _ in range(p)])
    # Construct observation model.
    ys = e + fs

# Sample a true, underlying function and observations.
samples = prior.sample(*(p(x) for p in zip(fs.ps)), *(p(x_obs) for p, x_obs in zip(ys.ps, x_obs)))
fs_true, ys_obs = samples[:p], samples[p:]

# Compute the posterior and make predictions.
post = prior.condition(*((p(x_obs), y_obs) for p, y_obs, x_obs in zip(ys.ps, ys_obs, x_obs)))
preds = [post(p(x)) for p in fs.ps]

# Plot results.
def plot_prediction(x, f, pred, x_obs=None, y_obs=None):
    plt.plot(x, f, label="True", style="test")
    if x_obs is not None:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    mean, lower, upper = pred.marginal_credible_bounds()
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    tweak()

plt.figure(figsize=(10, 6))
for i in range(3):
    plt.subplot(3, 1, i + 1)
    plt.title(f"Output {i + 1}")
    plot_prediction(x, fs_true[i], preds[i], x_obs, ys_obs[i])
plt.show()

B.jit_to_numpy prevents JAX transforms in AbstractObservations.__init__

For example, this code doesn't work:

import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
from lab import B

def compute(y):
    f = GP(EQ())
    f = f | (f(jnp.array([0., 1.])), y)
    return B.dense(f(0).mean)[0][0]

jax.vmap(compute)(jnp.array([[1., 2.], [3., 4.]]))

Converting to NumPy arrays is also slow, I guess?

Can this be changed so that it calls jax.lax.cond under the hood instead?

problem with running examples

Hi Wessel!

When I run "readme_example8_gp-rnn.py", I get this error:

TypeError: Can not convert a NoneType into a Tensor or Operation.

What should I do about this problem?

Any help would be greatly appreciated.

TypeError: __init__() got multiple values for argument 'measure'

Hi @wesselb

I was having some issues with Stheno. The minimal code was failing with the error

In [1]: from stheno import GP

In [2]: f = GP(1)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-47d7da13e50a> in <module>
----> 1 f = GP(1)

/usr/local/anaconda3/lib/python3.8/site-packages/plum/function.cpython-38-darwin.so in plum.function._BoundFunction.__call__()

/usr/local/anaconda3/lib/python3.8/site-packages/plum/function.cpython-38-darwin.so in plum.function.Function.__call__()

/usr/local/anaconda3/lib/python3.8/site-packages/stheno/model/gp.py in __init__(self, kernel, measure, name)
    100     @_dispatch
    101     def __init__(self, kernel, measure=None, name=None):
--> 102         self.__init__(ZeroMean(), kernel, measure=measure, name=name)
    103 
    104     @_dispatch

/usr/local/anaconda3/lib/python3.8/site-packages/plum/function.cpython-38-darwin.so in plum.function._BoundFunction.__call__()

/usr/local/anaconda3/lib/python3.8/site-packages/plum/function.cpython-38-darwin.so in plum.function.Function.__call__()

TypeError: __init__() got multiple values for argument 'measure'

I thus purged my Anaconda environment and installed Stheno again to no effect.

This is the output I get from running pip install --upgrade --upgrade-strategy eager stheno

Successfully installed algebra-1.1.0 backends-1.3.6 backends-matrix-1.1.2 
cftime-1.5.0 charset-normalizer-2.0.4 fdm-0.4.1 idna-3.2 matplotlib-3.4.3 
mlkernels-0.3.0 netCDF4-1.5.7 numpy-1.21.2 pandas-1.3.2 plum-dispatch-1.5.3 
python-slugify-5.0.2 requests-2.26.0 scipy-1.7.1 setuptools-57.4.0 sklearn-0.0 
stheno-1.1.3 text-unidecode-1.3 varz-0.7.2 wbml-0.3.4 xarray-0.19.0

Underfitting model with `@parametrised`

Description of the bug

Hi @wesselb,

I am not sure whether this issue is better suited to varz. I found abnormal behaviour in the optimization process when using two equivalent variants of defining the model.

  • Variant 1: As per #18 - the predictive variance seems properly calibrated in the final fit (screenshot omitted).

  • Variant 2: Enabling the @parametrised version of the code in #18 - the predictive variance seems to underfit in the final fit (screenshot omitted).

I tried disabling the fixed random seed and still experienced the same phenomenon.

Description of your environment

Package Version Location


absl-py 0.13.0
alabaster 0.7.12
algebra 1.1.0
argon2-cffi 21.1.0
astunparse 1.6.3
attrs 21.2.0
autograd 1.3
Babel 2.9.1
backcall 0.2.0
backends 1.4.13
backends-matrix 1.2.6
black 21.9b0
bleach 4.1.0
cachetools 4.2.2
certifi 2021.5.30
cffi 1.14.6
cftime 1.5.1
charset-normalizer 2.0.6
clang 5.0
click 8.0.2
cloudpickle 2.0.0
coverage 5.5
coveralls 3.2.0
cycler 0.10.0
debugpy 1.5.0
decorator 5.1.0
defusedxml 0.7.1
dm-tree 0.1.6
docopt 0.6.2
docutils 0.17.1
entrypoints 0.3
fdm 0.4.1
flatbuffers 1.12
future 0.18.2
gast 0.4.0
google-auth 1.32.0
google-auth-oauthlib 0.4.4
google-pasta 0.2.0
gpytorch 1.5.1
grpcio 1.34.1
h5py 3.1.0
idna 3.2
imagesize 1.2.0
iniconfig 1.1.1
ipykernel 6.4.1
ipython 7.28.0
ipython-genutils 0.2.0
ipywidgets 7.6.5
jax 0.2.22
jaxlib 0.1.72
jedi 0.18.0
Jinja2 3.0.2
joblib 1.1.0
jsonschema 4.1.0
jupyter 1.0.0
jupyter-client 7.0.6
jupyter-console 6.4.0
jupyter-core 4.8.1
jupyterlab-pygments 0.1.2
jupyterlab-widgets 1.0.2
keras 2.6.0
keras-nightly 2.5.0.dev2021032900
Keras-Preprocessing 1.1.2
kiwisolver 1.3.2
Markdown 3.3.4
MarkupSafe 2.0.1
matplotlib 3.4.3
matplotlib-inline 0.1.3
mistune 0.8.4
mlkernels 0.3.4
mpmath 1.2.1
mypy-extensions 0.4.3
nbclient 0.5.4
nbconvert 6.2.0
nbformat 5.1.3
nest-asyncio 1.5.1
netCDF4 1.5.7
notebook 6.4.4
numpy 1.21.4
oauthlib 3.1.1
opt-einsum 3.3.0
packaging 21.0
pandas 1.3.3
pandocfilters 1.5.0
parso 0.8.2
pathspec 0.9.0
pexpect 4.8.0
pickleshare 0.7.5
Pillow 8.3.2
pip 21.2.4
platformdirs 2.4.0
pluggy 1.0.0
plum-dispatch 1.5.4
prometheus-client 0.11.0
prompt-toolkit 3.0.20
protobuf 3.17.3
ptyprocess 0.7.0
py 1.10.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycparser 2.20
Pygments 2.10.0
pyparsing 2.4.7
pyrsistent 0.18.0
pytest 6.2.5
pytest-cov 3.0.0
python-dateutil 2.8.2
python-slugify 5.0.2
pytz 2021.3
pyzmq 22.3.0
qtconsole 5.1.1
QtPy 1.11.2
redeal 0.2.0
regdata 1.0.2
regex 2021.10.8
requests 2.26.0
requests-oauthlib 1.3.0
rsa 4.7.2
scikit-learn 1.0
scipy 1.7.1
Send2Trash 1.8.0
setuptools 58.0.4
setuptools-scm 6.3.2
setuptools-scm-git-archive 1.1
six 1.15.0
sklearn 0.0
snowballstemmer 2.1.0
Sphinx 4.2.0
sphinx-rtd-theme 1.0.0
sphinxcontrib-applehelp 1.0.2
sphinxcontrib-devhelp 1.0.2
sphinxcontrib-htmlhelp 2.0.0
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 1.0.3
sphinxcontrib-serializinghtml 1.1.5
stheno 1.1.8.dev21+gaa3ce24 /home/patel_zeel/stheno
sympy 1.9
tensorboard 2.6.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.0
tensorflow 2.5.0
tensorflow-estimator 2.5.0
tensorflow-probability 0.13.0
termcolor 1.1.0
terminado 0.12.1
testpath 0.5.0
text-unidecode 1.3
threadpoolctl 3.0.0
toml 0.10.2
tomli 1.2.1
torch 1.9.1
tornado 6.1
traitlets 5.1.0
typing_extensions 4.0.0
urllib3 1.26.7
varz 0.7.4
wbml 0.3.14
wcwidth 0.2.5
webencodings 0.5.1
Werkzeug 2.0.2
wheel 0.37.0
widgetsnbextension 3.5.1
wrapt 1.12.1
xarray 0.19.0

Import error

import stheno gives error "ImportError: /data/hpcdata/users/risno/miniconda3/envs/stheno/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.21' not found (required by /data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/google/protobuf/pyext/_message.cpython-36m-x86_64-linux-gnu.so)"

Steps taken

On a fresh conda installation, I did:

conda install -c anaconda gcc
conda install -c anaconda gfortran_linux-64
conda install tensorflow-gpu  # this installs 2.0
pip install stheno
python -c "import stheno"

The first 3 lines attempt to follow the installation instructions. Perhaps this is the wrong way to install gcc and gfortran? The installation instructions suggest conda install gcc, but this gives a PackagesNotFoundError.

Full error message:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/stheno/__init__.py", line 11, in <module>
    from .mean import *
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/stheno/mean.py", line 7, in <module>
    import tensorflow as tf
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/tensorflow/__init__.py", line 98, in <module>
    from tensorflow_core import *
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/tensorflow_core/__init__.py", line 40, in <module>
    from tensorflow.python.tools import module_util as _module_util
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/tensorflow/__init__.py", line 50, in __getattr__
    module = self._load()
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/tensorflow/__init__.py", line 44, in _load
    module = _importlib.import_module(self.__name__)
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/tensorflow_core/python/__init__.py", line 52, in <module>
    from tensorflow.core.framework.graph_pb2 import *
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/tensorflow_core/core/framework/graph_pb2.py", line 7, in <module>
    from google.protobuf import descriptor as _descriptor
  File "/data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/google/protobuf/descriptor.py", line 47, in <module>
    from google.protobuf.pyext import _message
ImportError: /data/hpcdata/users/risno/miniconda3/envs/stheno/bin/../lib/libstdc++.so.6: version `GLIBCXX_3.4.21' not found (required by /data/hpcdata/users/risno/miniconda3/envs/stheno/lib/python3.6/site-packages/google/protobuf/pyext/_message.cpython-36m-x86_64-linux-gnu.so)

OS version:
Linux bslws05 3.10.0-1062.1.2.el7.x86_64 #1 SMP Mon Sep 30 14:19:46 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux
LSB Version: :core-4.1-amd64:core-4.1-noarch
Distributor ID: CentOS
Description: CentOS Linux release 7.7.1908 (Core)
Release: 7.7.1908
Codename: Core

Kernel hyperparameters

How to check the GP hyperparameters (kernel lengthscale, kernel variance, likelihood variance)?

A suggestion: Opening the GitHub discussions tab would be useful for such queries.

problem installing stheno in windows 10 using anaconda

Hi,

I have installed gfortran on Windows 10 by installing mingw-64. In my Anaconda command line, I am able to compile code using gfortran. However, when I try pip install stheno, the following error occurs:
ERROR: Command errored out with exit status 1:
 command: 'D:\Anaconda\bin\envs\cosimo_pytorch\python.exe' 'D:\Anaconda\bin\envs\cosimo_pytorch\lib\site-packages\pip\_vendor\pep517\_in_process.py' get_requires_for_build_wheel 'C:\Users\XIAN_L~1\AppData\Local\Temp\tmp7uvkgbia'
     cwd: C:\Users\XIAN_L~1\AppData\Local\Temp\pip-install-l5cmtk4q\backends
Complete output (20 lines):
Der Befehl "which" ist entweder falsch geschrieben oder konnte nicht gefunden werden.
[Translation: The command "which" is either misspelled or could not be found.]
Traceback (most recent call last):
  File "D:\Anaconda\bin\envs\cosimo_pytorch\lib\site-packages\pip\_vendor\pep517\_in_process.py", line 280, in <module>
    main()
  File "D:\Anaconda\bin\envs\cosimo_pytorch\lib\site-packages\pip\_vendor\pep517\_in_process.py", line 263, in main
    json_out['return_val'] = hook(**hook_input['kwargs'])
  File "D:\Anaconda\bin\envs\cosimo_pytorch\lib\site-packages\pip\_vendor\pep517\_in_process.py", line 114, in get_requires_for_build_wheel
    return hook(config_settings)
  File "C:\Users\XIAN_L~1\AppData\Local\Temp\pip-build-env-c_j_teqt\overlay\Lib\site-packages\setuptools\build_meta.py", line 150, in get_requires_for_build_wheel
    config_settings, requirements=['wheel'])
  File "C:\Users\XIAN_L~1\AppData\Local\Temp\pip-build-env-c_j_teqt\overlay\Lib\site-packages\setuptools\build_meta.py", line 130, in _get_build_requires
    self.run_setup()
  File "C:\Users\XIAN_L~1\AppData\Local\Temp\pip-build-env-c_j_teqt\overlay\Lib\site-packages\setuptools\build_meta.py", line 254, in run_setup
    self).run_setup(setup_script=setup_script)
  File "C:\Users\XIAN_L~1\AppData\Local\Temp\pip-build-env-c_j_teqt\overlay\Lib\site-packages\setuptools\build_meta.py", line 145, in run_setup
    exec(compile(code, __file__, 'exec'), locals())
  File "setup.py", line 10, in <module>
    raise RuntimeError('gfortran cannot be found. Please install gfortran. '
RuntimeError: gfortran cannot be found. Please install gfortran. On OS X, this can be done with "brew install gcc". On Linux, "apt-get install gfortran" should suffice.

ERROR: Command errored out with exit status 1: 'D:\Anaconda\bin\envs\cosimo_pytorch\python.exe' 'D:\Anaconda\bin\envs\cosimo_pytorch\lib\site-packages\pip\_vendor\pep517\_in_process.py' get_requires_for_build_wheel 'C:\Users\XIAN_L~1\AppData\Local\Temp\tmp7uvkgbia' Check the logs for full command output.

I have already set up my Windows PATH after installing MinGW so that the program can locate the gfortran compiler.

Any help would be greatly appreciated.
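
One quick diagnostic, not a fix: the output above shows the build backend shelling out to the Unix `which` command, which a plain Windows command prompt does not provide, so it is worth checking from Python whether gfortran is actually visible on the PATH that pip's build environment sees:

import shutil

# Prints the full path to gfortran if it is on PATH, otherwise None.
print(shutil.which("gfortran"))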

Typing issue in `kernel.py`

Getting a typing issue when trying to call a periodicised kernel to create a covariance matrix:

dense(kern(tz))

The error is as follows:

...
~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/kernel.py in __call__(self, x)
    106     @_dispatch(object)
    107     def __call__(self, x):
--> 108         return self(x, x)
    109 
    110     @_dispatch(Input, Input)

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.WrappedMethod.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.Function.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/kernel.py in __call__(self, x, y)
    333     @_dispatch(object, object)
    334     def __call__(self, x, y):
--> 335         return B.add(self[0](x, y), self[1](x, y))
    336 
    337     @_dispatch(object, object)

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.WrappedMethod.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.Function.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/util.py in wrapped_f(*args)
     50     @wraps(f)
     51     def wrapped_f(*args):
---> 52         return f(*[uprank(x) for x in args])
     53 
     54     return wrapped_f

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/kernel.py in __call__(self, x, y)
    395     @uprank
    396     def __call__(self, x, y):
--> 397         return self[0](*self._compute(x, y))
    398 
    399     @_dispatch(B.Numeric, B.Numeric)

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.WrappedMethod.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.Function.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/util.py in wrapped_f(*args)
     50     @wraps(f)
     51     def wrapped_f(*args):
---> 52         return f(*[uprank(x) for x in args])
     53 
     54     return wrapped_f

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/kernel.py in __call__(self, x, y)
    596     @uprank
    597     def __call__(self, x, y):
--> 598         return self[0](*self._compute(x, y))
    599 
    600     @_dispatch(B.Numeric, B.Numeric)

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/kernel.py in _compute(self, x, y)
    608             return B.concat(B.sin(z), B.cos(z), axis=1)
    609 
--> 610         return feat_map(x), feat_map(y)
    611 
    612     @property

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/stheno/kernel.py in feat_map(z)
    605     def _compute(self, x, y):
    606         def feat_map(z):
--> 607             z = B.divide(B.multiply(B.multiply(z, 2), B.pi), self.period)
    608             return B.concat(B.sin(z), B.cos(z), axis=1)
    609 

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.Function.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/lab/util.py in wrapper(*args, **kw_args)
    133 
    134             # Retry call.
--> 135             return getattr(B, f.__name__)(*args, **kw_args)
    136 
    137         return wrapper

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/plum/function.cpython-37m-darwin.so in plum.function.Function.__call__()

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/lab/tensorflow/generic.py in multiply(a, b)
    111 @dispatch(TFNumeric, TFNumeric)
    112 def multiply(a, b):
--> 113     return tf.multiply(a, b)
    114 
    115 

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/tensorflow_core/python/util/dispatch.py in wrapper(*args, **kwargs)
    178     """Call target, and fall back on dispatchers if there is a TypeError."""
    179     try:
--> 180       return target(*args, **kwargs)
    181     except (TypeError, ValueError):
    182       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/tensorflow_core/python/ops/math_ops.py in multiply(x, y, name)
    329 @dispatch.add_dispatch_support
    330 def multiply(x, y, name=None):
--> 331   return gen_math_ops.mul(x, y, name)
    332 
    333 

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_math_ops.py in mul(x, y, name)
   6696       else:
   6697         message = e.message
-> 6698       _six.raise_from(_core._status_to_exception(e.code, message), None)
   6699   # Add nodes to the TensorFlow graph.
   6700   _, _, _op = _op_def_lib._apply_op_helper(

~/Documents/Coding/msc_proj/tf2rc_env/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a double tensor but is a int32 tensor [Op:Mul]

As a temporary solution, I fixed it locally in `kernel.py` as follows:

 605     def _compute(self, x, y):
 606         def feat_map(z):
 607             z = B.divide(B.multiply(B.multiply(z, np.float64(2.0)), np.float64(B.pi)), np.float64(self.period))
 608             return B.concat(B.sin(z), B.cos(z), axis=1)
 609 
 610         return feat_map(x), feat_map(y)
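
An alternative workaround that does not touch `kernel.py` is to make sure the inputs themselves are float-typed before the kernel is evaluated, since the `InvalidArgumentError` above is an int32/float64 mismatch inside `tf.multiply`. A sketch with made-up inputs and period:

import tensorflow as tf
from lab import B
from stheno.tensorflow import EQ

kern = EQ().periodic(7.0)        # periodicised kernel; the period here is arbitrary
tz = tf.range(5)                 # int32 inputs reproduce the error above
tz = tf.cast(tz, tf.float64)     # casting to float64 avoids the dtype clash
cov = B.dense(kern(tz))          # now evaluates to a dense covariance matrix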

Unable to jit logpdf

Description of the bug

Hi @wesselb,

I am trying to write some GP code in JAX and accelerate it with jax.jit, but it is failing due to a NumPy conversion happening in the process. A potential solution seems to be to comment out the code that checks for NaN values in the `logpdf` function (this works), but you may be able to suggest a better solution. Also, `chex` mentions that it allows testing code with and without jitting; it could be used in the test suite at some point in the future.

Code

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ

x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
grad_fn = jax.jit(jax.grad(loss_fn))
grad_fn(lengthscale)

Output

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <module>
     10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)

45 frames
[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in cache_miss(*args, **kwargs)
    530         device=device, backend=backend, name=flat_fun.__name__,
--> 531         donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
    532     out_pytree_def = out_tree()

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind(self, fun, *args, **params)
   1962   def bind(self, fun, *args, **params):
-> 1963     return call_bind(self, fun, *args, **params)
   1964 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in call_bind(primitive, fun, *args, **params)
   1978   fun_ = lu.annotate(fun_, fun.in_type)
-> 1979   outs = top_trace.process_call(primitive, fun_, tracers, params)
   1980   return map(full_lower, apply_todos(env_trace_todo(), outs))

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in process_call(self, primitive, f, tracers, params)
    688   def process_call(self, primitive, f, tracers, params):
--> 689     return primitive.impl(f, *tracers, **params)
    690   process_map = process_call

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_call_impl(***failed resolving arguments***)
    233   compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
--> 234                               keep_unused, *arg_specs)
    235   try:

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in memoized_fun(fun, *args)
    294     else:
--> 295       ans = call(fun, *args)
    296       cache[key] = (ans, fun.stores)

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
    324     return lower_xla_callable(fun, device, backend, name, donated_invars, False,
--> 325                               keep_unused, *arg_specs).compile().unsafe_call
    326 

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in lower_xla_callable(fun, device, backend, name, donated_invars, always_lower, keep_unused, *arg_specs)
    400     jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
--> 401         fun, pe.debug_info_final(fun, "jit"))
    402   out_avals, kept_outputs = util.unzip2(out_type)

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_jaxpr_final2(fun, debug_info)
   2024     with core.new_sublevel():
-> 2025       jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2026     del fun, main

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   1974     in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 1975     ans = fun.call_wrapped(*in_tracers_)
   1976     out_tracers = map(trace.full_raise, ans)

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in grad_f(*args, **kwargs)
   1002   def grad_f(*args, **kwargs):
-> 1003     _, g = value_and_grad_f(*args, **kwargs)
   1004     return g

[/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in value_and_grad_f(*args, **kwargs)
   1078     if not has_aux:
-> 1079       ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
   1080     else:

[/usr/local/lib/python3.7/dist-packages/jax/_src/api.py](https://localhost:8080/#) in _vjp(fun, has_aux, reduce_axes, *primals)
   2497     out_primal, out_vjp = ad.vjp(
-> 2498         flat_fun, primals_flat, reduce_axes=reduce_axes)
   2499     out_tree = out_tree()

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py](https://localhost:8080/#) in vjp(traceable, primals, has_aux, reduce_axes)
    132   if not has_aux:
--> 133     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    134   else:

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py](https://localhost:8080/#) in linearize(traceable, *primals, **kwargs)
    121   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 122   jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
    123   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_jaxpr_nounits(fun, pvals, instantiate)
    768     fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 769     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    770     assert not env

[/usr/local/lib/python3.7/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <lambda>(lengthscale)
      8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))

[/usr/local/lib/python3.7/dist-packages/stheno/random.py](https://localhost:8080/#) in logpdf(self, x)
    261         if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262             available = B.jit_to_numpy(~B.isnan(x[:, 0]))
    263             if not B.all(available):

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in jit_to_numpy(*args)
   1532     else:
-> 1533         res = B.to_numpy(*args)
   1534         if B.control_flow.caching:

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in to_numpy(a)
   1496     """
-> 1497     return convert(a, NPOrNum)
   1498 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in convert(obj, type_to)
     31     """
---> 32     return _convert.invoke(type_of(obj), type_to)(obj, type_to)
     33 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in wrapped_method(*args, **kw_args)
    606         def wrapped_method(*args, **kw_args):
--> 607             return _convert(method(*args, **kw_args), return_type)
    608 

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in perform_conversion(obj, _)
     60     def perform_conversion(obj: type_from, _: type_to):
---> 61         return f(obj)
     62 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in __array__(self, *args, **kw)
    535   def __array__(self, *args, **kw):
--> 536     raise TracerArrayConversionError(self)
    537 

UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
    from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <module>
      9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)

[<ipython-input-6-9ed89be9714a>](https://localhost:8080/#) in <lambda>(lengthscale)
      7 y = jnp.arange(10)
      8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
     11 grad_fn(lengthscale)

[/usr/local/lib/python3.7/dist-packages/stheno/random.py](https://localhost:8080/#) in logpdf(self, x)
    260         # Handle missing data. We don't handle missing data for batched computation.
    261         if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262             available = B.jit_to_numpy(~B.isnan(x[:, 0]))
    263             if not B.all(available):
    264                 # Take the elements of the mean, variance, and inputs corresponding to

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in jit_to_numpy(*args)
   1531         return B.control_flow.get_outcome("to_numpy")
   1532     else:
-> 1533         res = B.to_numpy(*args)
   1534         if B.control_flow.caching:
   1535             B.control_flow.set_outcome("to_numpy", res)

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

[/usr/local/lib/python3.7/dist-packages/lab/generic.py](https://localhost:8080/#) in to_numpy(a)
   1495         `np.ndarray`: `a` as NumPy.
   1496     """
-> 1497     return convert(a, NPOrNum)
   1498 
   1499 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in convert(obj, type_to)
     30         object: `obj` converted to type `type_to`.
     31     """
---> 32     return _convert.invoke(type_of(obj), type_to)(obj, type_to)
     33 
     34 

[/usr/local/lib/python3.7/dist-packages/plum/function.py](https://localhost:8080/#) in wrapped_method(*args, **kw_args)
    605         @wraps(self._f)
    606         def wrapped_method(*args, **kw_args):
--> 607             return _convert(method(*args, **kw_args), return_type)
    608 
    609         return wrapped_method

[/usr/local/lib/python3.7/dist-packages/plum/promotion.py](https://localhost:8080/#) in perform_conversion(obj, _)
     59     @_convert.dispatch
     60     def perform_conversion(obj: type_from, _: type_to):
---> 61         return f(obj)
     62 
     63 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
    from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Description of your environment

Tried this in Google colab.
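
For reference, the same loss differentiates fine once the `jax.jit` wrapper is dropped, which suggests the failure is specific to jit tracing of the NaN check rather than to differentiation itself. A quick sanity check using the definitions above:

grad_fn_nojit = jax.grad(loss_fn)   # no jit: the NaN check sees concrete arrays
print(grad_fn_nojit(lengthscale))   # evaluates without a TracerArrayConversionError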

Slow import with tensorflow_probability

Description of the bug

Hi @wesselb, when I try to import stheno with tensorflow_probability (reproduced in a fresh Google colab environment), it takes a long time, around 30 seconds.

Code and time taken:

from time import time
init = time()
import tensorflow_probability.substrates.jax as tfp
from stheno import GP, EQ, PseudoObs, PseudoObservationsFITC, PseudoObsDTC
print(time()-init)

Time taken: 28 seconds

from time import time
init = time()
from stheno import GP, EQ, PseudoObs, PseudoObservationsFITC, PseudoObsDTC
print(time()-init)

Time taken: 3.6 seconds

from time import time
init = time()
import tensorflow_probability.substrates.jax as tfp
print(time()-init)

Time taken: 3.79 seconds

Hyperparameters not getting optimized.

Hi, I was trying to create a minimal working example for hyperparameter optimization. I tried the following code.

import tensorflow as tf
from varz.tensorflow import Vars, minimise_l_bfgs_b
from stheno import GP, EQ
import lab.tensorflow as B

# Sample a true, underlying function and observations with known noise.
x_obs = B.linspace(0, 7, 100)
f_fixed = GP(3 * EQ().periodic(5))
noise_true = 0.01
f_true, y_obs = f_fixed.measure.sample(f_fixed(x_obs), f_fixed(x_obs, noise_true))

# Construct a model with learnable parameters.
def model(vs):
    kernel = vs.positive(name="variance") * EQ().periodic(vs.positive(name="period"))
    noise = vs.positive(name="noise")
    return GP(kernel), noise

# Define an objective function.
def objective(vs):
    f, noise = model(vs)
    return -f(x_obs, noise).logpdf(y_obs)

# Perform optimisation and print the learned parameters.
vs = Vars(tf.float64)
minimise_l_bfgs_b(objective, vs, trace=True, jit=True)
vs.print()

Output

Minimisation of "objective":
    Iteration 1/1000:
        Time elapsed: 0.3 s
        Time left:  281.6 s
        Objective value: 115.7
    Iteration 12/1000:
        Time elapsed: 0.3 s
        Time left:  3.5 s
        Objective value: 115.6
    Done!
Termination message:
    CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
variance:   0.4585
period:     0.4418
noise:      0.5115

The optimizer is not learning the hyperparameters correctly. What am I doing wrong here?
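
One thing worth checking is sensitivity to the random initial values that Vars draws: the run above terminates after only 12 iterations with an almost unchanged objective, which is consistent with L-BFGS-B stalling near its starting point. A quick multi-restart loop, reusing only the pieces already defined above, makes it easy to see whether different initialisations land on different optima (a diagnostic sketch, not a fix):

best_obj, best_vs = float("inf"), None
for restart in range(5):
    vs = Vars(tf.float64)                       # fresh random initialisation
    minimise_l_bfgs_b(objective, vs, trace=False, jit=True)
    final_obj = float(objective(vs))            # objective at the fitted parameters
    if final_obj < best_obj:
        best_obj, best_vs = final_obj, vs
best_vs.print()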

Multiplication of GPs

Currently multiplication between GPs works. What are the odds of renaming this functionality to prevent users from multiplying GPs and getting unexpected results?
