asem000 / kernex
Stencil computations in JAX
License: MIT License
How does Kernex handle borders and out-of-bounds indices, especially with padding="same"? I skimmed the source code but did not find out. Superficially, it seems to perform zero padding. Is this the case?
If so, it would be useful to support different schemes such as "mirror" and "edge", or ideally dropping the out-of-bounds indices.
Repro:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex

@kernex.kmap(kernel_size=(11, 11), padding="same")
def kernel(patch):
    return jnp.mean(patch)

data = jnp.ones((100, 100))
out = kernel(data)

# Row 50 of the filtered image: a constant input should stay at 1 everywhere.
plt.figure()
plt.plot(out[50], "+-", label="Kernex")
plt.plot(jnp.ones_like(out[50]), "+-", label="Ideal")
plt.legend()
plt.show()
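For reference, the behaviours asked about here can be sketched in plain JAX, outside kernex: either pad with a jnp.pad mode ("edge", "reflect", "symmetric", "wrap", ...) and then take valid-size windows, or drop out-of-bounds pixels by dividing each window sum by the number of in-bounds pixels. The name mean_filter and its boundary argument are hypothetical, not part of kernex, and the sketch assumes an odd kernel size.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def mean_filter(image, size, boundary="drop"):
    """Hypothetical moving-window mean with a selectable border scheme.

    boundary="drop": ignore out-of-bounds pixels (normalize by the in-bounds count).
    Any other value is passed to jnp.pad as its `mode` ("constant" = zero padding).
    Assumes `size` is odd so the window is centered.
    """
    half = size // 2
    kernel = jnp.ones((size, size), dtype=image.dtype)
    if boundary == "drop":
        # Zero-padded window sums divided by how many real pixels each window covers.
        window_sum = convolve2d(image, kernel, mode="same")
        count = convolve2d(jnp.ones_like(image), kernel, mode="same")
        return window_sum / count
    # Pad explicitly, then keep only windows that fit entirely inside the padded array.
    padded = jnp.pad(image, half, mode=boundary)
    return convolve2d(padded, kernel, mode="valid") / (size * size)


data = jnp.ones((100, 100))
dropped = mean_filter(data, 11, boundary="drop")      # 1.0 everywhere
zero_pad = mean_filter(data, 11, boundary="constant")  # corner value is 36/121
```

With zero padding, the corner window of an 11x11 mean sees only 6x6 = 36 real pixels out of 121, which is the kind of falloff toward the borders described above; the "drop" and "edge"/"reflect" variants keep a constant image constant.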
Does kernex support pytrees? I did not find an example. Supporting them would be very useful for moving-window filters with "global" weights, or simply with multiple inputs, such as the cross-channel bilateral filter in my case.
Repro:
import jax.numpy as jnp
import kernex

@kernex.kmap(kernel_size=(3, 3))
def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(20 * 30).reshape((20, 30))
out = kernel((data, data))
This raises:
Traceback (most recent call last):
File "/home/clemisch/kernex_tree.py", line 52, in <module>
out = kernel((data, data))
^^^^^^^^^^^^^^^^^^^^
File "/home/clemisch/venvs/11/lib64/python3.11/site-packages/kernex/interface/kernel_interface.py", line 131, in __call__
self.shape = array.shape
^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'
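Until something like this is supported, one possible workaround is to slide the same window over every leaf of a pytree in plain JAX. This is not kernex's API: tree_window_map is a hypothetical helper, it assumes all leaves share one 2D shape, and it only produces "valid" windows (no padding).

```python
import jax
import jax.numpy as jnp


def tree_window_map(fn, tree, kernel_size):
    """Hypothetical helper: call fn on matching windows of every leaf in `tree`.

    All leaves must have the same 2D shape; output shape is the "valid" window grid.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    rows, cols = leaves[0].shape
    kr, kc = kernel_size
    out_rows, out_cols = rows - kr + 1, cols - kc + 1

    def window_at(i, j):
        # Cut the same (kr, kc) window out of every leaf, keeping the tree structure.
        window = jax.tree_util.tree_map(
            lambda leaf: jax.lax.dynamic_slice(leaf, (i, j), (kr, kc)), tree
        )
        return fn(window)

    ii, jj = jnp.meshgrid(jnp.arange(out_rows), jnp.arange(out_cols), indexing="ij")
    # Outer vmap over rows, inner vmap over columns.
    return jax.vmap(jax.vmap(window_at))(ii, jj)


def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))


data = jnp.arange(20 * 30).reshape((20, 30))
out = tree_window_map(kernel, (data, data), (3, 3))  # shape (18, 28)
```

Because the window is rebuilt with tree_map, the kernel receives the same tuple structure it was given, so "global" weights or extra channels can travel alongside the data.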
This project looks really cool!
I would love to understand at a high level how this package works: how do you actually implement stencil computations in JAX? Do you reuse jax.lax.scan or something else? Does it support auto-diff? How does performance compare on CPU/GPU/TPU (or whichever configurations you've tried)?
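One common way to express a stencil in JAX, which may or may not match what kernex does internally, is to vmap a function over the window start indices and cut each window with jax.lax.dynamic_slice. No lax.scan is required, and jax.grad differentiates through it like any other JAX code. A 1D sketch with a hypothetical stencil_map_1d helper:

```python
import jax
import jax.numpy as jnp


def stencil_map_1d(fn, x, size):
    """Hypothetical helper: apply fn to every length-`size` window of x (no padding)."""
    starts = jnp.arange(x.shape[0] - size + 1)
    window_at = lambda i: fn(jax.lax.dynamic_slice(x, (i,), (size,)))
    return jax.vmap(window_at)(starts)


# Three-point moving average of [0.0, 0.2, 0.4, 0.6, 0.8, 1.0].
x = jnp.linspace(0.0, 1.0, 6)
avg = stencil_map_1d(jnp.mean, x, 3)  # [0.2, 0.4, 0.6, 0.8]

# Autodiff: each input's gradient is (number of windows containing it) / 3.
loss = lambda v: jnp.sum(stencil_map_1d(jnp.mean, v, 3))
g = jax.grad(loss)(x)  # [1, 2, 3, 3, 2, 1] / 3
```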
Allow choosing the kernel map backend.
The default is currently vmap; it would be useful to let the user choose among other variants such as pmap and lax.map.
Add a docstring and an example in the README.
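A minimal sketch of what a backend switch could look like, using a hypothetical backend parameter (not an existing kernex option). jax.vmap evaluates all windows in one vectorized call, while jax.lax.map loops over them sequentially, which typically trades speed for lower peak memory. pmap is left out because it needs the window axis split across multiple devices.

```python
import jax
import jax.numpy as jnp


def stencil_map_1d(fn, x, size, backend="vmap"):
    """Hypothetical helper: apply fn to every length-`size` window of x."""
    starts = jnp.arange(x.shape[0] - size + 1)
    window_at = lambda i: fn(jax.lax.dynamic_slice(x, (i,), (size,)))
    if backend == "vmap":
        # All windows at once: one batched computation.
        return jax.vmap(window_at)(starts)
    if backend == "lax.map":
        # One window per loop iteration: lower memory, usually slower.
        return jax.lax.map(window_at, starts)
    raise ValueError(f"unknown backend: {backend!r}")


x = jnp.arange(10.0)
fast = stencil_map_1d(jnp.sum, x, 3, backend="vmap")
lean = stencil_map_1d(jnp.sum, x, 3, backend="lax.map")
```

Both calls compute the same 3-point window sums; only the execution strategy differs.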
kernex operates on views of the function's first array argument. It would be useful to add an argnum/argname option to select which argument to operate on, similar to JAX transformations.
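A rough sketch of the idea, with a hypothetical window_map decorator whose argnum picks the argument to slide over while all other arguments are passed through unchanged. The names are illustrative only and not kernex's API.

```python
import functools

import jax
import jax.numpy as jnp


def window_map(size, argnum=0):
    """Hypothetical decorator: slide a 1D window over args[argnum] only."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapped(*args):
            x = args[argnum]
            starts = jnp.arange(x.shape[0] - size + 1)

            def at(i):
                window = jax.lax.dynamic_slice(x, (i,), (size,))
                # Swap in the window; every other argument is left as-is.
                new_args = args[:argnum] + (window,) + args[argnum + 1:]
                return fn(*new_args)

            return jax.vmap(at)(starts)
        return wrapped
    return decorator


@window_map(3, argnum=1)
def weighted_sum(weights, window):
    # `weights` is the same full array for every window.
    return jnp.dot(weights, window)


out = weighted_sum(jnp.array([1.0, 0.0, -1.0]), jnp.arange(6.0))  # [-2, -2, -2, -2]
```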
I think mode="reflect" for padding_kwargs is incorrect:
import jax.numpy as jnp
import kernex

@kernex.kmap(
    kernel_size=(3,),
    padding="same",
    relative=False,
    padding_kwargs=dict(mode="reflect"),
)
def f(x):
    return x

x = jnp.array([1, 2, 3, 4, 5])
y = f(x)
z = jnp.pad(x, 1, mode="reflect")
print("x: ", x)
print("y: ", y)
print("z: ", z)
gives
x: [1 2 3 4 5]
y: [[3 1 2] # <-- the `3` is incorrect, should be `2`
[1 2 3]
[2 3 4]
[3 4 5]
[4 5 4]]
z: [2 1 2 3 4 5 4] # <-- here, the first element is `2`
The Kernex output reflects incorrectly: the first element is 3 instead of 2.
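For comparison, the windows that reflect padding should produce can be built directly from jnp.pad. This is a plain-JAX reference, not kernex code; reference_windows is a hypothetical helper and assumes an odd window size.

```python
import jax.numpy as jnp


def reference_windows(x, size, mode):
    """Hypothetical reference: 'same'-padded windows of a 1D array using jnp.pad(mode=...)."""
    half = size // 2
    padded = jnp.pad(x, half, mode=mode)
    # Row k holds the indices k, k+1, ..., k+size-1 into the padded array.
    idx = jnp.arange(x.shape[0])[:, None] + jnp.arange(size)[None, :]
    return padded[idx]


x = jnp.array([1, 2, 3, 4, 5])
windows = reference_windows(x, 3, "reflect")
# [[2 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]
#  [4 5 4]]
```

Only the first row differs from the Kernex output above, which matches the report that the left border is reflected incorrectly while the right border is correct.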