Comments (2)
Hello Stephan,
I would love to understand at a high level how this package works
- First, I generate the view indices using jax.vmap:
Lines 133 to 163 in f5dd7f2
Lines 63 to 77 in f5dd7f2
- Second, I create a new function that applies the user function, taking view indices as its first input.
Given a set of view indices, this function first retrieves the array portion using jnp.ix_, then applies the user function to it. When relative=True, in other words when the indexing is relative (the center is 0, like numba.stencil), I roll the array portion before applying the function.
For kmap:
Lines 27 to 35 in f5dd7f2
For kscan:
Lines 23 to 34 in f5dd7f2
- Third, for kmap, I use jax.vmap to vectorize the new view-indices-accepting function over the array of all possible view indices.
Lines 37 to 51 in f5dd7f2
For kscan (my prime motivation), I use jax.lax.scan to scan over the array of view indices.
Lines 36 to 47 in f5dd7f2
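The three steps above can be sketched in a few lines of plain JAX. This is a minimal 1D illustration, not the kernex source: the helper names (view_indices, apply_on_view, sketch_kmap, sketch_kscan) are hypothetical, it assumes stride-1 windows with no padding, and it skips the relative=True roll. The kscan-like variant additionally assumes that each result is written back into the carried array at the window center so that later windows see it.

```python
import jax
import jax.numpy as jnp


def view_indices(array_size, kernel_size):
    """Step 1: one row of indices per window (stride 1, no padding)."""
    starts = jnp.arange(array_size - kernel_size + 1)
    offsets = jnp.arange(kernel_size)
    # vmap over window starts: each start s yields s + [0, 1, ..., k-1]
    return jax.vmap(lambda s: s + offsets)(starts)


def apply_on_view(func, array):
    """Step 2: wrap `func` so that it takes view indices as its input."""
    def wrapped(idx):
        portion = array[jnp.ix_(idx)]  # gather the array portion
        return func(portion)
    return wrapped


def sketch_kmap(func, array, kernel_size):
    """Step 3 (kmap-like): vectorize over all views with jax.vmap."""
    views = view_indices(array.shape[0], kernel_size)
    return jax.vmap(apply_on_view(func, array))(views)


def sketch_kscan(func, array, kernel_size):
    """Step 3 (kscan-like): visit the views sequentially with jax.lax.scan."""
    views = view_indices(array.shape[0], kernel_size)

    def body(carry, idx):
        result = func(carry[jnp.ix_(idx)])
        # ASSUMPTION: write the result back at the window center
        carry = carry.at[idx[kernel_size // 2]].set(result)
        return carry, result

    _, results = jax.lax.scan(body, array, views)
    return results


x = jnp.arange(1.0, 6.0)             # [1, 2, 3, 4, 5]
print(sketch_kmap(jnp.sum, x, 3))    # [ 6.  9. 12.]
print(sketch_kscan(jnp.sum, x, 3))   # [ 6. 13. 22.]
```

The two outputs differ because the scan version reads values that earlier windows have already overwritten, which is the behavior a sequential stencil needs.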
Does it support auto-diff?
Yes, definitely. The library relies on jax.numpy, jax.vmap, jax.lax.scan, and jax.lax.switch for its internals.
How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?
I benchmarked jax.lax.conv_general_dilated_patches and jax.lax.conv_general_dilated against equivalents based on kmap. The code is under tests_and_benchmarks. In general, kmap seems faster in many scenarios, especially on CPU. However, it needs more rigorous benchmarking, especially on TPU.
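For context, the patch-extraction baseline can be exercised as in the sketch below. It only shows what such a comparison computes and is not the code in tests_and_benchmarks: the moving-sum stencil, the array size, and the helper names are illustrative assumptions, and no timings are implied.

```python
import jax
import jax.numpy as jnp


def moving_sum_patches(x, k):
    """Moving sum via jax.lax.conv_general_dilated_patches."""
    lhs = x[None, None, :]  # (batch=1, channels=1, width)
    patches = jax.lax.conv_general_dilated_patches(
        lhs, filter_shape=(k,), window_strides=(1,), padding="VALID"
    )  # shape (1, k, width - k + 1)
    return patches.sum(axis=1)[0]


def moving_sum_vmap(x, k):
    """The same stencil written as a vmap over window start positions."""
    starts = jnp.arange(x.shape[0] - k + 1)
    window = lambda s: jax.lax.dynamic_slice(x, (s,), (k,)).sum()
    return jax.vmap(window)(starts)


x = jnp.arange(1.0, 11.0)
a = moving_sum_patches(x, 3)
b = moving_sum_vmap(x, 3)
print(jnp.allclose(a, b))  # both routes compute the same stencil
```

To time the two routes fairly, each function would be wrapped in jax.jit and the result followed by block_until_ready() before stopping the clock.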
My prime motivation is to solve PDEs using a stencil definition, which might require applying different functions at different locations of the array (e.g., at the boundary). This is why kernex offers the ability to use kmap and kscan along with jax.lax.switch to apply different functions to different portions of the array. The following example introduces the function mesh concept, where different stencils can be applied using indexing. The backbone of this feature is jax.lax.switch.
Function mesh:

import jax
import jax.numpy as jnp
import kernex as kex

F = kex.kmap(kernel_size=(1,))
F[0] = lambda x: x[0]**2
F[1:] = lambda x: x[0]**3
array = jnp.arange(1, 11).astype('float32')
print(F(array))
>>> [1., 8., 27., 64., 125.,
... 216., 343., 512., 729., 1000.]
print(jax.grad(lambda x: jnp.sum(F(x)))(array))
>>> [2., 12., 27., 48., 75.,
... 108., 147., 192., 243., 300.]

Array equivalent:

def F(x):
    f1 = lambda x: x**2
    f2 = lambda x: x**3
    x = x.at[0].set(f1(x[0]))
    x = x.at[1:].set(f2(x[1:]))
    return x
array = jnp.arange(1, 11).astype('float32')
print(F(array))
>>> [1., 8., 27., 64., 125.,
... 216., 343., 512., 729., 1000.]
print(jax.grad(lambda x: jnp.sum(F(x)))(array))
>>> [2., 12., 27., 48., 75.,
... 108., 147., 192., 243., 300.]
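One way to picture how jax.lax.switch can drive a function mesh is sketched below. It is an illustration under assumptions, not the kernex implementation: branch_ids and sketch_function_mesh are hypothetical names, and each array location is assumed to carry an integer that selects its function.

```python
import jax
import jax.numpy as jnp

# Candidate functions; the integer branch id picks one of them.
branches = (lambda v: v ** 2, lambda v: v ** 3)


def sketch_function_mesh(x, branch_ids):
    # vmap over (branch id, value) pairs; switch selects the function per element
    pick = lambda i, v: jax.lax.switch(i, branches, v)
    return jax.vmap(pick)(branch_ids, x)


array = jnp.arange(1, 11).astype("float32")
branch_ids = jnp.array([0] + [1] * 9)  # index 0 -> square, the rest -> cube
out = sketch_function_mesh(array, branch_ids)
grad = jax.grad(lambda x: jnp.sum(sketch_function_mesh(x, branch_ids)))(array)
print(out)
print(grad)
```

Note that under jax.vmap a switch with a batched index is lowered to a select, so every branch is evaluated for every element and the branch id only chooses which result is kept.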
OK great, thank you for sharing!
I agree that this is a very promising approach for implementing PDE kernels, and in general this is similar to the way I've implemented PDE solvers in JAX by hand (e.g., the wave equation solver).
conv_general_dilated and conv_general_dilated_patches use XLA's Convolution operation, which is really optimized for convolutional neural networks with large numbers of channels. I wouldn't expect them to work well for PDE kernels, except perhaps on TPUs.
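As a point of comparison, a hand-written PDE kernel in JAX typically looks something like the sketch below. It is not the wave equation solver referred to above: laplacian_1d and wave_step are hypothetical names, and the scheme (a leapfrog step of u_tt = c^2 u_xx on a 1D grid with fixed ends) is an assumed, deliberately simple choice.

```python
import jax
import jax.numpy as jnp


def laplacian_1d(u, dx):
    """Second-order central difference; boundary entries are left at zero."""
    interior = (u[:-2] - 2.0 * u[1:-1] + u[2:]) / dx ** 2
    return jnp.zeros_like(u).at[1:-1].set(interior)


@jax.jit
def wave_step(u_prev, u_curr, c, dx, dt):
    """One leapfrog step of u_tt = c^2 u_xx with fixed (Dirichlet) ends."""
    u_next = 2.0 * u_curr - u_prev + (c * dt) ** 2 * laplacian_1d(u_curr, dx)
    # pin the boundary values to their current state
    return u_next.at[0].set(u_curr[0]).at[-1].set(u_curr[-1])


u = jnp.arange(5.0) ** 2          # [0, 1, 4, 9, 16]
print(laplacian_1d(u, 1.0))       # [0. 2. 2. 2. 0.]
print(wave_step(u, u, 1.0, 1.0, 0.5))
```

The interior stencil and the boundary handling are written out separately here, which is the bookkeeping that a function mesh is meant to express through indexing.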