
interpax's People

Contributors

allen-adastra, f0uriest

interpax's Issues

Avoid recompilation after model surgery

Hi @f0uriest
Thank you so much for sharing this amazing library! It makes my life easier.
I have a question regarding JIT compilation of the interpolated function (generated by Interpolator1D, for example). In my problem, I want to update the coefficients of the interpolation function by model surgery using equinox:

interpolated_fun = eqx.tree_at(lambda m: m.f, interpolated_fun, new_coefficient)

The thing is, this 'interpolated_fun' is embedded in a large training step. If I update the interpolation function this way, the update is not reflected as expected, and I would have to redefine the whole training step, which is not an option. Do you have any suggestions for this type of thing? What I am thinking is to implement the interpolated function as a normal function and pass the coefficients to it as an argument.

Feature request for resampling of rectilinear grids

Hello, I was wondering what the timeline is for adding interpolation routines for N-D rectilinear grids. I am currently using jax.scipy.ndimage.map_coordinates for data on a 3D Cartesian grid, and I'm hoping to transition soon to this package for a routine that is similarly fast and supports cubic splines.

grad, jacfwd, jacrev

Nice job indeed.
It would be nice to provide an example of how to apply JAX transformations (grad, jacfwd, jacrev) to the interpolated function. Thanks

New release for check bugfix

The spline methods in PPoly are not JAX-transformable because #27 is not included in the current version of the pip package (0.3.1).

`interp1d` with periodic data not working as expected

When using interp1d with `period` given, the expected result is an interpolation that assumes the signal is periodic. However, in 0.3.2 this does not seem to be what is returned.

MWE:

from interpax import interp1d
import interpax
import jax.numpy as jnp
import matplotlib.pyplot as plt

xq = jnp.linspace(0, 2 * jnp.pi, endpoint=True)
x = jnp.linspace(0, 2 * jnp.pi, 4, endpoint=False)
f = jnp.array([12, 10, 8, 10])

fq = interp1d(xq, x, f, period=2 * jnp.pi, method="cubic", derivative=0)

print(fq)

plt.figure()
plt.plot(xq, fq)
plt.scatter(x, f)
plt.title(f"interpax version {interpax.__version__}")
plt.show()


The change in 0.3.2 seems to be that the _make_periodic function no longer pads the x points to enforce periodicity (i.e., it no longer makes the first and last points of the x and f arrays match). The above snippet works fine in 0.3.2 with linear as the interpolation method, but not with cubic or cubic2 (I have not tried the others).

Cubic Spline optimization

To start with, nice package! Somehow the Jax devs don't get how often the scientific computing community uses interpolation, and this seems to fill the gap quite well.

Context:
I needed a cubic spline a while ago, so I searched online and found some implementations. But their performance is not quite what I need, and the bigger issue is that they usually have a large memory footprint: at some point they call jnp.diag, which creates a dense diagonal matrix and allocates O(n^2) memory. This is problematic when my basis has 10^6 points. So I coded a lightweight version with equinox and lineax. Here is the code that builds the spline representation using lineax (https://github.com/kazewong/JaxNRSur/blob/aa24a2d5d4221f420c24d7b02391ee1c762ca9cc/jaxNRSur/Spline.py#L59). It only allocates O(3n) memory, since it knows the system being solved is tridiagonal.

I am looking through your implementation of the cubic spline, which seems to use a similar construction around:

def _approx_df(x, f, method, axis, **kwargs):

Do you have a benchmark of your implementation's performance? If you think this is relevant, I can open a PR to add it to this code. That makes more sense than shipping my own version separately.
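To illustrate the memory argument, here is a minimal sketch of solving for the second derivatives of a natural cubic spline while storing only the three diagonals. It is not interpax's implementation and not the linked lineax code; it swaps in `jax.lax.linalg.tridiagonal_solve(dl, d, du, b)`, which takes the sub-, main, and super-diagonals (with `dl[0]` and `du[-1]` required to be zero) and a 2D right-hand side. The function name is hypothetical.

```python
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

def natural_spline_second_derivs(x, f):
    """Second derivatives M_i of a natural cubic spline through (x_i, f_i).

    Only the three diagonals are materialized (O(n) memory), instead of a
    dense (n, n) matrix built with jnp.diag.
    """
    h = jnp.diff(x)          # interval widths, shape (n-1,)
    slope = jnp.diff(f) / h  # secant slopes, shape (n-1,)
    one, zero = jnp.ones(1, x.dtype), jnp.zeros(1, x.dtype)
    # Interior rows i = 1..n-2:
    #   h[i-1] M[i-1] + 2 (h[i-1] + h[i]) M[i] + h[i] M[i+1] = 6 (slope[i] - slope[i-1])
    # First and last rows enforce M[0] = M[n-1] = 0 (natural end conditions).
    d = jnp.concatenate([one, 2.0 * (h[:-1] + h[1:]), one])
    dl = jnp.concatenate([zero, h[:-1], zero])  # dl[0] must be 0
    du = jnp.concatenate([zero, h[1:], zero])   # du[-1] must be 0
    rhs = jnp.concatenate([zero, 6.0 * jnp.diff(slope), zero])
    return tridiagonal_solve(dl, d, du, rhs[:, None])[:, 0]

x = jnp.array([0.0, 0.4, 1.1, 1.5, 2.3, 3.0, 3.2])
f = jnp.sin(x)
M = natural_spline_second_derivs(x, f)
```

Each diagonal has length n, so the storage is about 3n values plus the right-hand side, matching the O(3n) figure above.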

Dependency conflicts

Hello,

When I install interpax and then install JAX for my NVIDIA GPU, I get the following error:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.
interpax 0.2.4 requires jax[cpu]<=0.4.20,>=0.3.2, but you have jax 0.4.24 which is incompatible.
Successfully installed jax-0.4.24 jaxlib-0.4.24+cuda11.cudnn86

Could you please comment? Thanks

jaxlib version

Great work that fills a clear need IMO!!

Any thoughts on bumping jax/jaxlib to 0.4.20?

Different results on different GPUs

Hi f0uriest,

I encountered an issue where interpolation results vary across machines.

I used a 1D interpolator with the monotonic method and extrap=True.

Test machines: CPU, RTX Titan, RTX 4090.
Reference machine: CPU with double precision (x64).

The table below presents the relative $L^1$ error, $\|a - b\|_1 / \|b\|_1$, computed as abs(a - b).sum() / abs(b).sum() with b the reference result:

| precision | CPU | RTX Titan | RTX 4090 |
|---|---|---|---|
| x32 | 5.87719e-08 | 5.89367e-08 | 1.78212e-04 |
| x64 | reference | 4.16375e-17 | 4.16375e-17 |

Since I used the same (xq, xp, yp) everywhere, I expected the errors within each row to coincide.

However, as you can see, interpolation on the RTX 4090 in single precision produced a far less accurate result (about 1.8e-4, versus about 5.9e-8 on the other machines).

Do you have any ideas on this?

`axis` attribute of the `Interpolator2D` object?

There is no class documentation for this attribute, and it doesn't seem to affect the output. What is its purpose? Is it simply there to give a unified API between the 1D and higher-dimensional interpolators?

JAX scipy

Hi, your library is quite nice!

I have a question: JAX has a scipy module that lacks a full equivalent of scipy.interpolate, and your library comes close to filling the gap. Do you plan to make a PR to JAX?

In jax.scipy the idea is to match the scipy API. I was able to contribute jax.scipy.signal.fftconvolve. Even though I had written some tests before starting on the PR, it was quite a "long" process to conform to the JAX coding style (at least for me).

Along these lines, I am not sure that using the equinox library is necessary, and it would likely be questioned, since it is a third-party library focused on neural networks.
What do you think?

Best.

CubicSpline not supporting JIT of grad

Jax: 0.4.28
Interpax: 0.3.1

Minimal reproduction:

import jax
import jax.numpy as jnp
from interpax import CubicSpline

x0, y0 = 0,0
x1, y1 = 1,-0.3

xm, ym = 0.5, 1

N = 6
p_x = jnp.linspace(x0, x1, N)
p_y = jnp.linspace(y0, y1, N) # spline control points, force them into a line at the start

def loss(p_y):
    f = CubicSpline(p_x, p_y,check=False) # My own guess that the check could have been at fault. Makes no difference.

    return (f(x0)-y0)**2 + (f(x1)-y1)**2 + (f(xm)-ym)**2

print('loss', loss(p_y))

dloss = jax.grad(loss)

print('dloss', dloss(p_y)) # Works fine


jdloss = jax.jit(jax.grad(loss))

print('jdloss', jdloss(p_y)) #Fails to get here. 

Fails on:
operation a:bool[6] = is_finite b
from line C:....\interpax_test.py:15 (loss)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

I have attached the full output.
interpax_test_output.txt

`DeprecationWarning` from `jax.core.pp_eqn_rules` with `jax` version `0.4.30` when using latest version of `interpax`

In code that uses interpax as a dependency for interpolation, we find that a DeprecationWarning is being emitted:

Warning Traceback:

tests/conftest.py:12: in <module>
    from desc.coils import (
desc/coils.py:10: in <module>
    from desc.compute import get_params, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
desc/compute/__init__.py:31: in <module>
    from . import (
desc/compute/_curve.py:1: in <module>
    from interpax import interp1d
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/interpax/__init__.py:6: in <module>
    from ._ppoly import (
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/interpax/_ppoly.py:41: in <module>
    import equinox as eqx
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/equinox/__init__.py:3: in <module>
    from . import debug as debug, internal as internal, nn as nn
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/equinox/internal/__init__.py:46: in <module>
    from ._finalise_jaxpr import (
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/equinox/internal/_finalise_jaxpr.py:186: in <module>
    from ._noinline import noinline_p
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/equinox/internal/_noinline.py:379: in <module>
    jax.core.pp_eqn_rules[noinline_p] = _noinline_pretty_print
/opt/hostedtoolcache/Python/3.9.19/x64/lib/python3.9/site-packages/jax/_src/deprecations.py:53: in getattr
    warnings.warn(message, DeprecationWarning, stacklevel=2)
E   DeprecationWarning: jax.core.pp_eqn_rules is deprecated.

extrap does not take float

The API documentation for Interpolator2D says that extrap can be a float, but the code raises an error because it calls len() on a float in this line:

elif len(extrap) == 2 and jnp.isscalar(extrap[0]): # same l,h for all dimensions

Here is code that reproduces the error:

from interpax import Interpolator2D
import jax.numpy as jnp

x = jnp.linspace(0,10, 10)
y = jnp.linspace(0,8, 8)
z = jnp.zeros((10,8))+1.
interpol = Interpolator2D(x, y, z, extrap=0.)
interpol( 4.5, 5.3)

which raises:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/jax/_src/core.py:1605, in ShapedArray._len(self, ignored_tracer)
   1604 try:
-> 1605   return self.shape[0]
   1606 except IndexError as err:

IndexError: tuple index out of range

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[13], line 8
      6 z = jnp.zeros((10,8))+1.
      7 interpol = Interpolator2D(x, y, z, extrap=0.)
----> 8 interpol( 4.5, 5.3)

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/interpax/_spline.py:237, in Interpolator2D.__call__(self, xq, yq, dx, dy)
    222 def __call__(self, xq: jax.Array, yq: jax.Array, dx: int = 0, dy: int = 0):
    223     """Evaluate the interpolated function or its derivatives.
    224 
    225     Parameters
   (...)
    235         Interpolated values.
    236     """
--> 237     return interp2d(
    238         xq,
    239         yq,
    240         self.x,
    241         self.y,
    242         self.f,
    243         self.method,
    244         (dx, dy),
    245         self.extrap,
    246         self.period,
    247         **self.derivs,
    248     )

    [... skipping hidden 12 frame]

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/interpax/_spline.py:648, in interp2d(xq, yq, x, y, f, method, derivative, extrap, period, **kwargs)
    646 periodx, periody = _parse_ndarg(period, 2)
    647 derivative_x, derivative_y = _parse_ndarg(derivative, 2)
--> 648 lowx, highx, lowy, highy = _parse_extrap(extrap, 2)
    650 if periodx is not None:
    651     xq, x, f, fx, fy, fxy = _make_periodic(xq, x, periodx, 0, f, fx, fy, fxy)

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/interpax/_spline.py:1107, in _parse_extrap(extrap, n)
   1105 if isbool(extrap):  # same for lower,upper in all dimensions
   1106     return tuple(extrap for _ in range(2 * n))
-> 1107 elif len(extrap) == 2 and jnp.isscalar(extrap[0]):  # same l,h for all dimensions
   1108     return tuple(e for _ in range(n) for e in extrap)
   1109 elif len(extrap) == n and all(len(extrap[i]) == 2 for i in range(n)):

    [... skipping hidden 1 frame]

File ~/miniforge3/envs/jsalt/lib/python3.11/site-packages/jax/_src/core.py:1607, in ShapedArray._len(self, ignored_tracer)
   1605   return self.shape[0]
   1606 except IndexError as err:
-> 1607   raise TypeError("len() of unsized object") from err

TypeError: len() of unsized object
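A sketch of how the parsing could accept scalar fill values. `parse_extrap` is hypothetical, modeled on the `_parse_extrap` branches visible in the traceback; it is not interpax's code. Checking `np.ndim(extrap) == 0` routes bools and scalar floats to the "same value for every bound" branch before `len()` is ever called. Judging from the second branch shown above, passing a `(low, high)` pair such as `extrap=(0.0, 0.0)` may work as a workaround in the meantime.

```python
import numpy as np

def parse_extrap(extrap, n):
    """Hypothetical fix: expand `extrap` to a 2*n tuple of (low, high) per dimension.

    Accepts a bool or scalar fill value (used for every bound), a single
    (low, high) pair shared by all dimensions, or n (low, high) pairs.
    """
    if np.ndim(extrap) == 0:
        # bool or scalar float: same value for lower and upper bounds in all dimensions
        return tuple(extrap for _ in range(2 * n))
    if len(extrap) == 2 and np.ndim(extrap[0]) == 0:
        # one (low, high) pair reused for every dimension
        return tuple(e for _ in range(n) for e in extrap)
    if len(extrap) == n and all(len(extrap[i]) == 2 for i in range(n)):
        # a separate (low, high) pair for each dimension
        return tuple(e for pair in extrap for e in pair)
    raise ValueError(f"Cannot interpret extrap={extrap!r} for {n} dimensions")
```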
