
patrick-kidger / torchcubicspline


Interpolating natural cubic splines. Includes batching, GPU support, support for missing values, evaluating derivatives of the spline, and backpropagation.

License: Apache License 2.0

Python 100.00%
pytorch interpolation spline

torchcubicspline's Introduction

torchcubicspline

Interpolating natural cubic splines using PyTorch. Includes support for:

  • Batching
  • GPU support and backpropagation via PyTorch
  • Support for missing values (represent them as NaN)
  • Evaluating the first derivative of the spline

Installation

pip install git+https://github.com/patrick-kidger/torchcubicspline.git

Example

Simple example:

import torch
from torchcubicspline import (natural_cubic_spline_coeffs,
                              NaturalCubicSpline)

length, channels = 7, 3
t = torch.linspace(0, 1, length)
x = torch.rand(length, channels)
coeffs = natural_cubic_spline_coeffs(t, x)
spline = NaturalCubicSpline(coeffs)
point = torch.tensor(0.4)
out = spline.evaluate(point)

With multiple batch and evaluation dimensions:

import torch
from torchcubicspline import (natural_cubic_spline_coeffs,
                              NaturalCubicSpline)

t = torch.linspace(0, 1, 7)
# (2, 1) are batch dimensions. 7 is the time dimension
# (of the same length as t). 3 is the channel dimension.
x = torch.rand(2, 1, 7, 3)
coeffs = natural_cubic_spline_coeffs(t, x)
# coeffs is a tuple of tensors

# ...at this point you can save the coeffs, put them
# through PyTorch's Datasets and DataLoaders, etc...

spline = NaturalCubicSpline(coeffs)

point = torch.tensor(0.4)
# will be a tensor of shape (2, 1, 3), corresponding to
# batch, batch, and channel dimensions
out = spline.derivative(point)

point = torch.tensor([[0.4, 0.5]])
# will be a tensor of shape (2, 1, 1, 2, 3), corresponding to
# batch, batch, time, time and channel dimensions
out = spline.derivative(point)
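The shapes in the comments above follow one rule: the batch dimensions of x come first, then the dimensions of the evaluation point, then the channel dimension. The helper below is a hypothetical pure-Python restatement of that rule, not part of torchcubicspline. It can be handy for checking expected shapes before running anything.

```python
def expected_output_shape(x_shape, point_shape):
    """Shape of spline.evaluate(point) or spline.derivative(point).

    Hypothetical helper: assumes x has shape (*batch, length, channels),
    as described in the Functionality section.
    """
    if len(x_shape) < 2:
        raise ValueError("x must have at least (length, channels) dimensions")
    *batch, _length, channels = x_shape
    return tuple(batch) + tuple(point_shape) + (channels,)


# The two cases from the example above.
print(expected_output_shape((2, 1, 7, 3), ()))      # (2, 1, 3)
print(expected_output_shape((2, 1, 7, 3), (1, 2)))  # (2, 1, 1, 2, 3)
```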

Functionality

Functionality is provided via the natural_cubic_spline_coeffs function and NaturalCubicSpline class.

natural_cubic_spline_coeffs takes an increasing sequence of times represented by a tensor t of shape (length,) and some corresponding observations x of shape (..., length, channels), where ... are batch dimensions, and each (length, channels) slice represents a sequence of length points, each point with channels many values.

Then calling

coeffs = natural_cubic_spline_coeffs(t, x)
spline = NaturalCubicSpline(coeffs)

produces an instance spline such that

spline.evaluate(t[i]) == x[..., i, :]

for all i (up to floating-point rounding).
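To make the interpolation property concrete, here is a minimal pure-Python sketch of a natural cubic spline for a single channel. It has no batching, GPU support or missing values, and the function names are hypothetical. It illustrates the standard construction, not torchcubicspline's own code. That construction solves a tridiagonal system for the second derivatives, with the natural boundary condition that they vanish at both ends.

```python
import bisect


def natural_spline_coeffs(t, y):
    """Precompute a natural cubic spline through (t[i], y[i]).

    Hypothetical single-channel sketch: t must be strictly increasing with
    len(t) >= 2. Returns (t, y, m), where m[i] is the spline's second
    derivative at t[i]; the natural condition fixes m[0] = m[-1] = 0.
    """
    n = len(t)
    h = [t[i + 1] - t[i] for i in range(n - 1)]
    m = [0.0] * n
    # Forward sweep of the Thomas algorithm over the interior equations
    #   h[i-1]*m[i-1] + 2*(h[i-1] + h[i])*m[i] + h[i]*m[i+1] = rhs[i].
    c_prime, d_prime = [0.0] * n, [0.0] * n
    for i in range(1, n - 1):
        rhs = 6.0 * ((y[i + 1] - y[i]) / h[i] - (y[i] - y[i - 1]) / h[i - 1])
        denom = 2.0 * (h[i - 1] + h[i]) - h[i - 1] * c_prime[i - 1]
        c_prime[i] = h[i] / denom
        d_prime[i] = (rhs - h[i - 1] * d_prime[i - 1]) / denom
    # Back substitution; m[0] and m[-1] are never touched and stay 0.
    for i in range(n - 2, 0, -1):
        m[i] = d_prime[i] - c_prime[i] * m[i + 1]
    return t, y, m


def evaluate_natural_spline(coeffs, x):
    """Evaluate at a scalar x (outside [t[0], t[-1]] the end pieces extrapolate)."""
    t, y, m = coeffs
    k = min(max(bisect.bisect_right(t, x) - 1, 0), len(t) - 2)
    h = t[k + 1] - t[k]
    left, right = t[k + 1] - x, x - t[k]
    return ((m[k] * left ** 3 + m[k + 1] * right ** 3) / (6.0 * h)
            + (y[k] / h - m[k] * h / 6.0) * left
            + (y[k + 1] / h - m[k + 1] * h / 6.0) * right)


t = [0.0, 0.25, 0.5, 0.8, 1.0]
y = [1.0, -0.5, 2.0, 0.3, 0.7]
coeffs = natural_spline_coeffs(t, y)
# The interpolation property: the spline passes through every data point.
print(all(abs(evaluate_natural_spline(coeffs, ti) - yi) < 1e-12
          for ti, yi in zip(t, y)))
```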

Why is there a function and a class?

The slow bit is done during natural_cubic_spline_coeffs. The fast bit is NaturalCubicSpline. The returned coeffs are a tuple of PyTorch tensors, so you can take this opportunity to save or load them, push them through torch.utils.data.Dataset or torch.utils.data.DataLoader, etc.
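The split pays off because all the linear algebra happens once, when the coefficients are computed. Evaluation then only needs a binary search for the right interval plus one cubic polynomial evaluation.

The pure-Python sketch below illustrates that pattern with a hypothetical (t, a, b, c, d) coefficient tuple. This layout is an assumption for illustration and is not necessarily what natural_cubic_spline_coeffs returns. Because the coefficients are plain data, they can be serialized (here with pickle) and evaluated later without recomputing anything.

```python
import bisect
import pickle


class PiecewiseCubic:
    """Hypothetical evaluator for precomputed piecewise-cubic coefficients.

    coeffs = (t, a, b, c, d): on [t[i], t[i+1]] the value is
    a[i] + b[i]*s + c[i]*s**2 + d[i]*s**3 with s = x - t[i].
    """

    def __init__(self, coeffs):
        self.t, self.a, self.b, self.c, self.d = coeffs

    def evaluate(self, x):
        # Binary search for the interval: O(log n), no linear solve needed.
        i = bisect.bisect_right(self.t, x) - 1
        i = min(max(i, 0), len(self.t) - 2)  # clamp so the end pieces extrapolate
        s = x - self.t[i]
        # Horner's rule for a + b*s + c*s**2 + d*s**3.
        return self.a[i] + s * (self.b[i] + s * (self.c[i] + s * self.d[i]))


# Exact coefficients of f(x) = x**3 on the knots 0, 1, 2.
coeffs = ([0.0, 1.0, 2.0], [0.0, 1.0], [0.0, 3.0], [0.0, 3.0], [1.0, 1.0])
# Coefficients are plain data, so they can be saved and reloaded unchanged.
restored = PiecewiseCubic(pickle.loads(pickle.dumps(coeffs)))
print(restored.evaluate(1.5))  # 3.375
```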

Derivatives

The derivative of the spline at a point may be calculated via spline.derivative. (Not to be confused with backpropagation, which is also supported through both spline.evaluate and spline.derivative.)
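For intuition, on each interval the spline is a cubic in s = time - t_i, so its time derivative is the polynomial b + 2*c*s + 3*d*s**2. The pure-Python sketch below checks that closed form against a central finite difference. It uses a hypothetical (a, b, c, d) coefficient layout for a single interval. Note that this derivative is taken with respect to time; gradients with respect to the data come from backpropagation instead.

```python
def cubic_value(coeffs, s):
    """a + b*s + c*s**2 + d*s**3, evaluated with Horner's rule."""
    a, b, c, d = coeffs
    return a + s * (b + s * (c + s * d))


def cubic_time_derivative(coeffs, s):
    """d/ds of the cubic above: b + 2*c*s + 3*d*s**2."""
    _, b, c, d = coeffs
    return b + s * (2.0 * c + 3.0 * d * s)


segment = (1.0, 3.0, 3.0, 1.0)  # (1 + s)**3 expanded
s, step = 0.5, 1e-5
finite_difference = (cubic_value(segment, s + step)
                     - cubic_value(segment, s - step)) / (2 * step)
print(cubic_time_derivative(segment, s))  # 6.75, i.e. 3 * 1.5**2
print(abs(finite_difference - 6.75) < 1e-6)
```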

Missing values

Support for missing values is done by setting the corresponding element of x to NaN. In particular this allows for batching elements with different observation times: take t to be the union of the observation times of all elements in the batch, and set each element's value to NaN at those times where only other batch elements were observed.
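The batching recipe above can be sketched in pure Python. The input format here (one {time: value} dict per batch element) is a hypothetical choice for illustration. With PyTorch, the outputs would presumably become t = torch.tensor(times) and x = torch.tensor(rows).unsqueeze(-1), giving x the shape (batch, length, 1).

```python
import math


def align_to_union(observations):
    """Put batch elements with different observation times on one time grid.

    observations: one {time: value} dict per batch element (hypothetical
    input format). Returns (times, rows): times is the sorted union of all
    observation times, and rows[k][j] is element k's value at times[j], or
    NaN where element k was not observed at that time.
    """
    times = sorted({time for obs in observations for time in obs})
    rows = [[obs.get(time, math.nan) for time in times] for obs in observations]
    return times, rows


times, rows = align_to_union([
    {0.0: 1.0, 0.5: 2.0, 1.0: 3.0},
    {0.0: 4.0, 0.3: 5.0, 1.0: 6.0},
])
print(times)  # [0.0, 0.3, 0.5, 1.0]
print(rows)   # [[1.0, nan, 2.0, 3.0], [4.0, 5.0, nan, 6.0]]
```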

Limitations

If possible, you should cache the coefficients returned by natural_cubic_spline_coeffs. In particular if there are missing values then the computation can be quite slow.

Any issues?

Any issues or questions - open an issue to let me know. :)

torchcubicspline's People

Contributors

alisterburt, jhrmnn, mossjacob, patrick-kidger

torchcubicspline's Issues

inefficient _validate_input and mistake

The function _validate_input seems to be incorrect. The code tries to reject non-monotonic time values, but since prev_t_i is never updated, the check doesn't seem to work. In addition, this part of the code is very inefficient in a CUDA environment.
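A minimal pure-Python sketch of the check as the issue describes its intent; it does not reproduce the library's actual _validate_input, and the function name is hypothetical. For a PyTorch tensor t, the same test fits in one vectorized comparison, (t[1:] > t[:-1]).all(). That avoids the per-element device-to-host synchronization a Python loop over a CUDA tensor incurs whenever a comparison result is used in an if.

```python
def check_strictly_increasing(t):
    """Raise ValueError unless the sequence t is strictly increasing."""
    if len(t) < 2:
        raise ValueError("need at least two time points")
    prev = t[0]
    for i in range(1, len(t)):
        if t[i] <= prev:
            raise ValueError(f"t[{i}] = {t[i]!r} is not greater than t[{i - 1}] = {prev!r}")
        # Without this update every element would be compared with t[0] only,
        # so a sequence such as [0.0, 2.0, 1.0] would wrongly pass.
        prev = t[i]


check_strictly_increasing([0.0, 0.5, 1.0])  # passes silently
try:
    check_strictly_increasing([0.0, 2.0, 1.0])
except ValueError as err:
    print(err)
```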

Backwards Pass through spline.derivative gives nan gradients

Hi there,
I'm using torchcubicspline as part of a larger optimization workflow where I'm finding the optimal parametric sampling points along a spline (1D vector that spans 0-1) to minimize a loss function.
One part of my algorithm finds the tangent at each sampling point, and for some reason the backward pass sometimes returns NaN gradients at this point, during the spline.derivative call.

Using torch.autograd.detect_anomaly, I found that the problem is specifically in the following line:
[screenshot not preserved]
where the error is:
[screenshot not preserved]
If you require any other information, please let me know! The optimization code is rather large; otherwise I would have included a minimal example that reproduces the problem.
Thanks!

Complex value compatibility

Hi,

Thanks for putting this together, it's a really useful bit of code. I wonder what it would take to make it compatible with complex dtypes? I'd be very grateful for any advice you could provide. Hope this isn't a bother.

Regards,

Stephen

evaluating at different time points per batch

Hi,

I've got a tensor of the shape batch_size x T x nr_channels and I created a cubic spline accordingly.

Let's say I have created this tensor with shape (2, 10, 1)

Now I want to query [[0], [1]] (i.e. the value at t=0 for batch 0 and the value at t=1 for batch 1).

I know I can call spline.evaluate(torch.tensor(0)) to get the value at t=0 for both batch elements, but how can I make the query above return a tensor of shape (2, 1, 1) or (2, 1)? At the moment it returns (2, 2, 1), or even (2, 2, 1, 1).
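One workaround relies on the shape rule from the README: batch dimensions first, then evaluation dimensions, then channels. Pass a 1D tensor times of shape (batch,), so that out = spline.evaluate(times) has shape (batch, batch, channels). Then keep only the entries where the batch index equals the time index, out[torch.arange(batch), torch.arange(batch)], which has shape (batch, channels). This performs batch times more evaluations than needed, which is usually acceptable for small batches.

The pure-Python sketch below mirrors that selection. Hypothetical per-element functions stand in for the spline.

```python
def evaluate_all(splines, times):
    """Mimics the broadcasting rule: result[b][j] = splines[b](times[j])."""
    return [[spline(time) for time in times] for spline in splines]


def evaluate_per_element(splines, times):
    """Keep the diagonal: element b is evaluated only at its own times[b]."""
    if len(splines) != len(times):
        raise ValueError("need exactly one query time per batch element")
    full = evaluate_all(splines, times)
    return [full[b][b] for b in range(len(splines))]


# Two hypothetical batch elements; query element 0 at t=0 and element 1 at t=1.
splines = [lambda s: 2.0 * s, lambda s: 10.0 + s]
print(evaluate_per_element(splines, [0.0, 1.0]))  # [0.0, 11.0]
```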

clamped cubic spline supported?

Hello,

Thanks for this very useful package. I have a question: does this package support the clamped cubic spline, i.e., s'(t_0)=a, s'(t_T)=b?

Qunxi
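The README describes natural splines only, which have a zero second derivative at both ends. A clamped spline replaces those two boundary conditions with prescribed end slopes s'(t_0) = a and s'(t_T) = b. That changes only the first and last rows of the tridiagonal system for the second derivatives.

The pure-Python, single-channel sketch below shows that change. It is an illustration, not a feature of torchcubicspline, and all names are hypothetical.

```python
import bisect


def clamped_second_derivatives(t, y, slope_start, slope_end):
    """Second derivatives m[i] of the clamped cubic spline through (t[i], y[i]).

    Interior rows match the natural spline; the first and last rows impose
    s'(t[0]) = slope_start and s'(t[-1]) = slope_end instead of m = 0.
    """
    n = len(t)
    h = [t[i + 1] - t[i] for i in range(n - 1)]
    sub, diag, sup, rhs = [0.0] * n, [0.0] * n, [0.0] * n, [0.0] * n
    diag[0], sup[0] = 2.0 * h[0], h[0]
    rhs[0] = 6.0 * ((y[1] - y[0]) / h[0] - slope_start)
    for i in range(1, n - 1):
        sub[i], diag[i], sup[i] = h[i - 1], 2.0 * (h[i - 1] + h[i]), h[i]
        rhs[i] = 6.0 * ((y[i + 1] - y[i]) / h[i] - (y[i] - y[i - 1]) / h[i - 1])
    sub[-1], diag[-1] = h[-1], 2.0 * h[-1]
    rhs[-1] = 6.0 * (slope_end - (y[-1] - y[-2]) / h[-1])
    # Thomas algorithm: eliminate the sub-diagonal, then back-substitute.
    for i in range(1, n):
        w = sub[i] / diag[i - 1]
        diag[i] -= w * sup[i - 1]
        rhs[i] -= w * rhs[i - 1]
    m = [0.0] * n
    m[-1] = rhs[-1] / diag[-1]
    for i in range(n - 2, -1, -1):
        m[i] = (rhs[i] - sup[i] * m[i + 1]) / diag[i]
    return m


def evaluate_spline(t, y, m, x):
    """Evaluate the cubic spline with knot values y and second derivatives m."""
    k = min(max(bisect.bisect_right(t, x) - 1, 0), len(t) - 2)
    h = t[k + 1] - t[k]
    left, right = t[k + 1] - x, x - t[k]
    return ((m[k] * left ** 3 + m[k + 1] * right ** 3) / (6.0 * h)
            + (y[k] / h - m[k] * h / 6.0) * left
            + (y[k + 1] / h - m[k + 1] * h / 6.0) * right)


def f(x):
    return x ** 3 - 2.0 * x  # f'(0) = -2, f'(2) = 10


# Given the true end slopes, a clamped spline reproduces any cubic exactly.
t = [0.0, 0.5, 1.2, 2.0]
y = [f(ti) for ti in t]
m = clamped_second_derivatives(t, y, slope_start=-2.0, slope_end=10.0)
print(abs(evaluate_spline(t, y, m, 0.8) - f(0.8)) < 1e-9)
```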

Applying on typical PyTorch data format

Hi,

Is it possible to apply the cubic spline interpolation on one dimension of a data with the data format as in PyTorch [BatchSize, Channel, Height, Width]?

Thanks
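The Functionality section describes a (..., length, channels) layout. Going by that, one approach is to move the dimension to interpolate over into the length position and treat every other dimension as a batch dimension, then append a channel dimension of size 1. The hypothetical helper below only computes the permutation and the resulting shape.

For a PyTorch tensor x, the rearrangement itself would be x.permute(*perm).unsqueeze(-1), with t a 1D tensor whose length matches that dimension. Evaluating at a scalar point should then give shape (BatchSize, Channel, Width, 1) when interpolating along Height. Treating Channel as the spline's channel dimension instead, e.g. x.permute(0, 3, 2, 1) for Height, would fit the layout too.

```python
def spline_layout(shape, axis):
    """Permutation moving `axis` last, plus the shape after adding a size-1 channel dim.

    Hypothetical helper: the result has the (*batch, length, channels) form,
    with every other dimension treated as a batch dimension.
    """
    ndim = len(shape)
    axis %= ndim  # allow negative axes such as -2
    perm = [d for d in range(ndim) if d != axis] + [axis]
    new_shape = tuple(shape[d] for d in perm) + (1,)
    return perm, new_shape


# [BatchSize, Channel, Height, Width], interpolating along Height (axis 2).
perm, new_shape = spline_layout((8, 3, 16, 32), axis=2)
print(perm, new_shape)  # [0, 1, 3, 2] (8, 3, 32, 16, 1)
```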

1 Dimension Only?

Hi, I wonder whether this code works for 2- or 3-dimensional data.

Thanks.

Interpolating over 2D-grid.

Hi Patrick,

first of all, thank you very much for creating this great library.

I've got a very basic question:

I want to interpolate over a 2D-grid. In SciPy, I run the following code:

import numpy as np
from scipy import interpolate
# Data Generation
X = np.linspace(0.0, 1.0, num=12)
Y = np.linspace(0.0, 1.0, num=4)
z_list = [
    [0.8, 1.6, 2.0, 3.0],
    [0.6, 1.4, 2.5, 2.9],
    [0.2, 0.9, 0.6, 2.8],
    [0.5, 1.0, 1.2, 2.7],
    [0.5, 1.5, 1.6, 2.6],
    [0.5, 1.4, 1.5, 2.5],
    [0.3, 1.4, 1.2, 2.4],
    [0.5, 1.2, 1.5, 2.3],
    [0.4, 1.2, 1.5, 2.0],
    [0.35, 0.9, 1.4, 1.5],
    [0.35, 0.9, 1.4, 1.5],
    [0.35, 0.9, 1.4, 1.5],
]
Z = np.array(z_list)
Z = Z.transpose()
interpol_function = interpolate.interp2d(X, Y, Z, kind="cubic")

Then, I can evaluate interpol_function(X, Y) at arbitrary X and Y.

I tried to match the dimensions to your "batch-length-channel" format, resulting in

import torch
from torchcubicspline import natural_cubic_spline_coeffs

num_X, num_Y = 12, 4
X = torch.linspace(0.0, 1.0, steps=num_X)
Y = torch.linspace(0.0, 1.0, steps=num_Y)
t = torch.meshgrid([X, Y])
x = torch.Tensor(
    [
        [0.8, 1.6, 2.0, 3.0],
        [0.6, 1.4, 2.5, 2.9],
        [0.2, 0.9, 0.6, 2.8],
        [0.5, 1.0, 1.2, 2.7],
        [0.5, 1.5, 1.6, 2.6],
        [0.5, 1.4, 1.5, 2.5],
        [0.3, 1.4, 1.2, 2.4],
        [0.5, 1.2, 1.5, 2.3],
        [0.4, 1.2, 1.5, 2.0],
        [0.35, 0.9, 1.4, 1.5],
        [0.35, 0.9, 1.4, 1.5],
        [0.35, 0.9, 1.4, 1.5],
    ]
)
x = torch.reshape(x, (num_X, num_Y, 1))
coeffs = natural_cubic_spline_coeffs(t, x)

However, this results in an error, because t (the output of torch.meshgrid, which is a tuple of 2D tensors) is not a one-dimensional tensor of floating-point values.
How would you recommend reshaping the dimensions here?
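torchcubicspline interpolates along a single time axis, so t has to stay 1D. One standard workaround for gridded data is tensor-product (separable) interpolation. First interpolate along X separately for every Y value, then run one more 1D spline along Y through those intermediate values.

With this library that would presumably mean two stages. The first spline uses t = X and data of shape (num_Y, num_X, 1), with the Y index as a batch dimension; evaluating it at x0 gives shape (num_Y, 1). The second spline uses t = Y through those values. The second spline is rebuilt for every x0.

The pure-Python sketch below shows the idea with hypothetical names. Results will not in general match SciPy's interp2d, which uses a different bivariate spline construction.

```python
import bisect


def natural_spline_value(t, y, x):
    """Evaluate the natural cubic spline through (t[i], y[i]) at scalar x."""
    n = len(t)
    h = [t[i + 1] - t[i] for i in range(n - 1)]
    m = [0.0] * n  # second derivatives; natural ends m[0] = m[-1] = 0
    c_prime, d_prime = [0.0] * n, [0.0] * n
    for i in range(1, n - 1):  # Thomas algorithm, forward sweep
        rhs = 6.0 * ((y[i + 1] - y[i]) / h[i] - (y[i] - y[i - 1]) / h[i - 1])
        denom = 2.0 * (h[i - 1] + h[i]) - h[i - 1] * c_prime[i - 1]
        c_prime[i] = h[i] / denom
        d_prime[i] = (rhs - h[i - 1] * d_prime[i - 1]) / denom
    for i in range(n - 2, 0, -1):  # back substitution
        m[i] = d_prime[i] - c_prime[i] * m[i + 1]
    k = min(max(bisect.bisect_right(t, x) - 1, 0), n - 2)
    left, right = t[k + 1] - x, x - t[k]
    return ((m[k] * left ** 3 + m[k + 1] * right ** 3) / (6.0 * h[k])
            + (y[k] / h[k] - m[k] * h[k] / 6.0) * left
            + (y[k + 1] / h[k] - m[k + 1] * h[k] / 6.0) * right)


def separable_spline_2d(xs, ys, z, x0, y0):
    """Tensor-product interpolation on a grid with z[i][j] = f(xs[i], ys[j])."""
    # Pass 1: one 1D spline along x for every fixed ys[j].
    along_x = [natural_spline_value(xs, [row[j] for row in z], x0)
               for j in range(len(ys))]
    # Pass 2: one 1D spline along y through the intermediate values.
    return natural_spline_value(ys, along_x, y0)


xs, ys = [0.0, 0.3, 0.7, 1.0], [0.0, 0.5, 1.0]
z = [[1.0 + 2.0 * x - 3.0 * y for y in ys] for x in xs]
# A natural spline reproduces linear data exactly, so this recovers 1 + 2x - 3y.
print(abs(separable_spline_2d(xs, ys, z, 0.45, 0.8) - (-0.5)) < 1e-9)
```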

Availability on PyPI

Hi Patrick,

Would it be possible to have torchcubicspline available on PyPI? I have a package that uses it (https://github.com/rorymaizels/velvet) and I would like to make this package available for download via pip, but this requires all dependencies to be hosted on PyPI as well. I'm sure there are workarounds if this isn't possible - let me know!

Cheers and thanks for the package,
Rory

Differentiation with respect to coefficients

Hi!

I would like to ask whether it's possible and what would be the best approach to use your implementation to differentiate with respect to coefficients and/or coordinates of interpolated points.

import torch
from torchcubicspline import (natural_cubic_spline_coeffs,
                              NaturalCubicSpline)

length, channels = 7, 3
t = torch.linspace(0, 1, length)
x = torch.rand(length, channels)
coeffs = natural_cubic_spline_coeffs(t, x)
spline = NaturalCubicSpline(coeffs)
point = torch.tensor(0.4)
out = spline.evaluate(point)

Looking at the provided simple example I would like to backpropagate from out to x - is it possible?
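The README states that backpropagation is supported. The direct route should therefore be to call x.requires_grad_(True) before natural_cubic_spline_coeffs(t, x), then out.sum().backward(), after which x.grad holds the gradient.

Here is a useful sanity check for such gradients. With the times t held fixed, a natural cubic spline is linear in the data. So d out / d x[j] is simply the value, at the evaluation point, of the spline through the j-th unit vector.

The pure-Python, single-channel sketch below verifies that property with finite differences. It is independent of torchcubicspline, and its names are hypothetical.

```python
import bisect


def natural_spline_value(t, y, x):
    """Evaluate the natural cubic spline through (t[i], y[i]) at scalar x."""
    n = len(t)
    h = [t[i + 1] - t[i] for i in range(n - 1)]
    m = [0.0] * n  # second derivatives; natural ends m[0] = m[-1] = 0
    c_prime, d_prime = [0.0] * n, [0.0] * n
    for i in range(1, n - 1):  # Thomas algorithm, forward sweep
        rhs = 6.0 * ((y[i + 1] - y[i]) / h[i] - (y[i] - y[i - 1]) / h[i - 1])
        denom = 2.0 * (h[i - 1] + h[i]) - h[i - 1] * c_prime[i - 1]
        c_prime[i] = h[i] / denom
        d_prime[i] = (rhs - h[i - 1] * d_prime[i - 1]) / denom
    for i in range(n - 2, 0, -1):  # back substitution
        m[i] = d_prime[i] - c_prime[i] * m[i + 1]
    k = min(max(bisect.bisect_right(t, x) - 1, 0), n - 2)
    left, right = t[k + 1] - x, x - t[k]
    return ((m[k] * left ** 3 + m[k + 1] * right ** 3) / (6.0 * h[k])
            + (y[k] / h[k] - m[k] * h[k] / 6.0) * left
            + (y[k + 1] / h[k] - m[k + 1] * h[k] / 6.0) * right)


def cardinal_weights(t, x0):
    """d spline(x0) / d y[j] for every j: the spline through the j-th unit vector."""
    n = len(t)
    return [natural_spline_value(t, [1.0 if k == j else 0.0 for k in range(n)], x0)
            for j in range(n)]


t = [0.0, 0.2, 0.5, 0.9, 1.0]
y = [0.3, -1.0, 0.8, 0.1, 0.5]
x0, eps = 0.4, 1e-6
weights = cardinal_weights(t, x0)
base = natural_spline_value(t, y, x0)
finite_diff = []
for j in range(len(y)):
    bumped = list(y)
    bumped[j] += eps  # perturb one data point at a time
    finite_diff.append((natural_spline_value(t, bumped, x0) - base) / eps)
print(max(abs(w - g) for w, g in zip(weights, finite_diff)) < 1e-6)
```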

Evaluating Each Spline at Unique Time Values in the Output Tensor

Hi Patrick,
Great repo!
I have a short question regarding the evaluation function. It returns a tensor with dimensions (batch, batch, time, time, channel).
Is it possible to evaluate each spline at a different t and return only that? Let's say I have 3 splines, and I want to evaluate each at a different t-value. At the moment, I always get an output tensor in which every spline is evaluated at every t-value. However, I would like the first spline evaluated at the first t-value, and so on.

Thank you

Speed of tridiagonal_solve()

Hi. As you warn in the code, tridiagonal_solve() is quite slow. I've compared to plain torch.solve(), which is much faster, so I'll be using that in my application, but I was wondering if you are interested in a patch, or perhaps you had other reasons to use the Thomas algorithm.
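For context on the trade-off: the Thomas algorithm needs only O(n) arithmetic for an n × n tridiagonal system, against O(n³) for a general dense solve. However, both of its sweeps are sequential loops over n. Written as a Python loop over PyTorch tensors, each step launches its own small kernels. That is a plausible reason a single batched dense call can be faster for moderate lengths. (torch.solve has since been deprecated in favor of torch.linalg.solve.)

The pure-Python sketch below implements both solvers and checks that they agree. It is not the library's code.

```python
def thomas_solve(sub, diag, sup, rhs):
    """O(n) solve of a tridiagonal system; sub[0] and sup[-1] are unused."""
    n = len(rhs)
    c_prime, d_prime = [0.0] * n, [0.0] * n
    for i in range(n):  # forward sweep (sequential in i)
        denom = diag[i] - (sub[i] * c_prime[i - 1] if i > 0 else 0.0)
        c_prime[i] = sup[i] / denom if i < n - 1 else 0.0
        d_prime[i] = (rhs[i] - (sub[i] * d_prime[i - 1] if i > 0 else 0.0)) / denom
    x = [0.0] * n
    for i in range(n - 1, -1, -1):  # back substitution (also sequential)
        x[i] = d_prime[i] - (c_prime[i] * x[i + 1] if i < n - 1 else 0.0)
    return x


def dense_solve(a, b):
    """Gaussian elimination with partial pivoting, O(n^3), for comparison."""
    n = len(b)
    m = [row[:] + [b_i] for row, b_i in zip(a, b)]  # augmented matrix copy
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for k in range(col, n + 1):
                m[r][k] -= factor * m[col][k]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][k] * x[k] for k in range(r + 1, n))) / m[r][r]
    return x


n = 6
diag = [4.0] * n
sub = [0.0] + [1.0] * (n - 1)
sup = [1.0] * (n - 1) + [0.0]
rhs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
dense = [[diag[i] if i == j else (1.0 if abs(i - j) == 1 else 0.0)
          for j in range(n)] for i in range(n)]
x_thomas = thomas_solve(sub, diag, sup, rhs)
x_dense = dense_solve(dense, rhs)
print(max(abs(p - q) for p, q in zip(x_thomas, x_dense)) < 1e-12)
```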