
Implementation strategy about kernex (open)

asem000 commented on May 19, 2024
Implementation strategy


Comments (2)

ASEM000 commented on May 19, 2024

Hello Stephan,

I would love to understand at a high level how this package works

  • First, I generate the view indices (the array indices covered by each kernel window) using jax.vmap:

    kernex/kernex/_src/utils.py

    Lines 133 to 163 in f5dd7f2

    import jax.numpy as jnp
    from functools import cached_property
    from jax import vmap

    def general_product(*args):
        """Equivalent to `tuple(zip(*itertools.product(*args)))` for arrays.

        Example:
        >>> general_product(
        ...     jnp.array([[1,2],[3,4]]),
        ...     jnp.array([[5,6],[7,8]]))
        (
        DeviceArray([[[1, 2],[1, 2]],[[3, 4],[3, 4]]], dtype=int32),
        DeviceArray([[[5, 6],[7, 8]],[[5, 6],[7, 8]]], dtype=int32)
        )
        >>> tuple(zip(*(itertools.product([[1,2],[3,4]],[[5,6],[7,8]]))))
        (
        ([1, 2], [1, 2], [3, 4], [3, 4]),
        ([5, 6], [7, 8], [5, 6], [7, 8])
        )
        """
        def nvmap(n):
            # level n maps over args[-n]; the innermost level returns the inputs
            in_axes = [None] * len(args)
            in_axes[-n] = 0
            return (
                vmap(lambda *x: x, in_axes=in_axes)
                if n == 1
                else vmap(nvmap(n - 1), in_axes=in_axes)
            )
        return nvmap(len(args))(*args)

    # method of the kernel class; general_arange is defined elsewhere in kernex
    @cached_property
    def views(self) -> tuple[jnp.ndarray, ...]:
        """Generate absolute sampling matrix"""
        # this function is cached because it is called multiple times
        # and it is expensive to calculate
        # the view is the indices of the array that is used to calculate
        # the output value
        dim_range = tuple(
            general_arange(di, ki, si, x0, xf)
            for (di, ki, si, (x0, xf)) in zip(
                self.shape, self.kernel_size, self.strides, self.border
            )
        )
        matrix = general_product(*dim_range)
        return tuple(map(lambda xi, wi: xi.reshape(-1, wi), matrix, self.kernel_size))
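To make the first step concrete, here is a small self-contained sketch of how the nested-vmap product turns per-axis window indices into view indices. The helper `window_indices` is a hypothetical stand-in for kernex's `general_arange` (stride only, no padding or border handling), and `views` mirrors the method above as a free function; neither is the library's actual code.

```python
import itertools

import jax.numpy as jnp
from jax import vmap


def general_product(*args):
    # nested vmap: the outermost level maps over args[0], the innermost over args[-1]
    def nvmap(n):
        in_axes = [None] * len(args)
        in_axes[-n] = 0
        if n == 1:
            return vmap(lambda *x: x, in_axes=tuple(in_axes))
        return vmap(nvmap(n - 1), in_axes=tuple(in_axes))

    return nvmap(len(args))(*args)


def window_indices(size, kernel, stride):
    # hypothetical stand-in for general_arange (no padding):
    # row w holds the indices covered by the w-th window along one axis
    starts = jnp.arange(0, size - kernel + 1, stride)
    return starts[:, None] + jnp.arange(kernel)[None, :]


def views(shape, kernel_size, strides):
    dim_range = [window_indices(d, k, s) for d, k, s in zip(shape, kernel_size, strides)]
    matrix = general_product(*dim_range)
    # flatten every window combination: views[i] has shape (num_views, kernel_size[i])
    return tuple(m.reshape(-1, k) for m, k in zip(matrix, kernel_size))


array = jnp.arange(9).reshape(3, 3)
rows, cols = views(array.shape, (2, 2), (1, 1))
print(rows.tolist())  # [[0, 1], [0, 1], [1, 2], [1, 2]]
print(cols.tolist())  # [[0, 1], [1, 2], [0, 1], [1, 2]]

# the second view selects the top-right 2x2 patch
print(array[jnp.ix_(rows[1], cols[1])].tolist())  # [[1, 2], [4, 5]]

# cross-check against the itertools definition quoted in the docstring
w = window_indices(3, 2, 1).tolist()
ref = tuple(zip(*itertools.product(w, w)))
print([list(t) for t in ref] == [rows.tolist(), cols.tolist()])  # True
```

Each of the four views is one pair of row and column index lists, so `jnp.ix_` can turn it into a 2x2 patch in the second step.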

  • Second, I create a new function that applies the user function, taking the view indices as its first input.
    Given a set of view indices, this function first retrieves the corresponding array portion using jnp.ix_, then applies the user function to it. When relative=True (that is, the indexing is relative, with the kernel center at index 0, as in numba.stencil), I roll the array portion before applying the function.

For kmap:

def reduce_map_func(self, func, *args, **kwargs) -> Callable:
    if self.relative:
        # relative indexing: roll the view so the kernel center sits at index 0
        return lambda view, array: func(
            roll_view(array[ix_(*view)]), *args, **kwargs
        )
    else:
        # absolute indexing: pass the array portion as-is
        return lambda view, array: func(array[ix_(*view)], *args, **kwargs)

For kscan:

def reduce_scan_func(self, func, *args, **kwargs) -> Callable:
    if self.relative:
        # relative indexing: roll the view so the kernel center sits at index 0;
        # returns the whole array with the result written at this view's index
        return lambda view, array: array.at[self.index_from_view(view)].set(
            func(roll_view(array[ix_(*view)]), *args, **kwargs)
        )
    else:
        # returns the whole array with the result written at this view's index
        return lambda view, array: array.at[self.index_from_view(view)].set(
            func(array[ix_(*view)], *args, **kwargs)
        )
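The relative mode can be illustrated with a short sketch. I'm assuming that `roll_view` shifts each axis so the kernel center lands at index 0 (which is what "center is 0, like numba.stencil" implies); `roll_to_center` and `make_map_func` are hypothetical names, and kernex's real `roll_view` may treat details such as even kernel sizes differently.

```python
import jax.numpy as jnp


def roll_to_center(patch):
    # hypothetical stand-in for kernex's roll_view: shift every axis so the
    # kernel center lands at index 0 and x[-1] / x[1] are its neighbors
    shifts = tuple(-(k // 2) for k in patch.shape)
    return jnp.roll(patch, shifts, axis=tuple(range(patch.ndim)))


def make_map_func(func, relative):
    # same shape as reduce_map_func: the result takes (view, array)
    def apply(view, array):
        patch = array[jnp.ix_(*view)]
        return func(roll_to_center(patch) if relative else patch)

    return apply


array = jnp.array([1.0, 4.0, 9.0, 16.0, 25.0])
view = (jnp.array([1, 2, 3]),)  # one 1D window covering array[1:4]

print(roll_to_center(array[jnp.ix_(*view)]).tolist())  # [9.0, 16.0, 4.0]

# a second difference written both ways gives the same value
relative = make_map_func(lambda x: x[-1] - 2 * x[0] + x[1], relative=True)
absolute = make_map_func(lambda x: x[0] - 2 * x[1] + x[2], relative=False)
print(float(relative(view, array)), float(absolute(view, array)))  # 2.0 2.0
```

The payoff of rolling is that the user function can be written with offsets from the center (`x[-1]`, `x[0]`, `x[1]`), which matches how stencils are usually written in PDE code.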


  • Third, I apply the new function to every view.

For kmap, I use jax.vmap to vectorize the new view-index-accepting function over the array of all possible view indices.

def __single_call__(self, array: jnp.ndarray, *args, **kwargs):
    padded_array = jnp.pad(array, self.pad_width)
    # convert the function to a callable that takes a view and an array
    # and returns the result of the function applied to the view
    reduced_func = self.reduce_map_func(self.funcs[0], *args, **kwargs)
    # apply the function to each view using vmap;
    # the leading axis of the result has one entry per view
    result = vmap(lambda view: reduced_func(view, padded_array))(self.views)
    # reshape the result to the output shape: for example, an input of shape
    # (3, 3) with kernel shape (2, 2), stride 1, and no padding gives (2, 2)
    return result.reshape(*self.output_shape, *result.shape[1:])
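A stripped-down 1D version of the kmap path shows that, once the view indices exist, the whole computation is a single vmap over them. This is a sketch assuming stride 1 and zero padding; `kmap_1d` is a hypothetical name, and for brevity the views are built directly with broadcasting rather than through `general_product`.

```python
import jax.numpy as jnp
from jax import vmap


def kmap_1d(func, array, kernel, pad):
    # hypothetical 1D sketch of kmap: stride 1, zero padding of `pad` on each side
    padded = jnp.pad(array, pad)
    n_out = padded.shape[0] - kernel + 1
    # view indices: row w lists the padded-array indices seen by window w
    views = jnp.arange(n_out)[:, None] + jnp.arange(kernel)[None, :]
    # vectorize "(view, array) -> value" over every view at once
    return vmap(lambda view: func(padded[view]))(views)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(kmap_1d(jnp.sum, x, kernel=3, pad=1).tolist())  # [3.0, 6.0, 9.0, 7.0]
```

Every window reads the same untouched `padded` array, so the windows are independent and vmap is free to evaluate them in parallel.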

For kscan (my prime motivation), I use jax.lax.scan to scan over the view indices, carrying the array so that each view sees the values written by the previous ones.

def __single_call__(self, array, *args, **kwargs):
    padded_array = jnp.pad(array, self.pad_width)
    reduced_func = self.reduce_scan_func(self.funcs[0], *args, **kwargs)

    def scan_body(padded_array, view):
        # the carry is the array as updated by all previous views
        result = reduced_func(view, padded_array).reshape(padded_array.shape)
        # emit the value just written at this view's index
        return result, result[self.index_from_view(view)]

    return lax.scan(scan_body, padded_array, self.views)[1].reshape(
        self.output_shape
    )
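The difference from kmap is that the carry holds the array as already updated by earlier windows. The 1D sketch below is hypothetical (`kscan_1d` is not a kernex function; it uses stride 1, left padding only, and writes each result at the window's last index instead of going through `index_from_view`). With a 2-wide window sum it produces a running sum, whereas the same windows evaluated kmap-style only add adjacent pairs.

```python
import jax.numpy as jnp
from jax import lax, vmap


def kscan_1d(func, array, kernel):
    # hypothetical 1D sketch of kscan: stride 1, left zero padding so that
    # window i ends at input element i
    padded = jnp.pad(array, (kernel - 1, 0))
    views = jnp.arange(array.shape[0])[:, None] + jnp.arange(kernel)[None, :]

    def body(carry, view):
        value = func(carry[view])
        # write the result back before the next window reads the array
        carry = carry.at[view[-1]].set(value)
        return carry, value

    _, out = lax.scan(body, padded, views)
    return out


x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(kscan_1d(jnp.sum, x, kernel=2).tolist())  # [1.0, 3.0, 6.0, 10.0]

# the same windows without write-back (kmap-style) only sum adjacent pairs
padded = jnp.pad(x, (1, 0))
pairs = jnp.arange(4)[:, None] + jnp.arange(2)[None, :]
print(vmap(lambda v: jnp.sum(padded[v]))(pairs).tolist())  # [1.0, 3.0, 5.0, 7.0]
```

This sequential dependence is what makes kscan suitable for explicit in-place update schemes, at the cost of losing the parallelism of vmap.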

Does it support auto-diff?

Yes, definitely. The library relies on jax.numpy, jax.vmap, jax.lax.scan, and jax.lax.switch for its internals, all of which support jax.grad.
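Because everything bottoms out in differentiable JAX primitives, jax.grad can be applied directly. The sketch below uses hypothetical helpers (`window_sum` and `running_sum`, not kernex functions) to differentiate through a vmap over view indices and through a lax.scan with an in-place `.at[]` update; the expected gradients follow from counting how many outputs each input feeds.

```python
import jax
import jax.numpy as jnp
from jax import lax, vmap

x = jnp.array([1.0, 2.0, 3.0, 4.0])


def window_sum(x):
    # kmap-style: vmap over 3-wide windows of the zero-padded input
    padded = jnp.pad(x, 1)
    views = jnp.arange(x.shape[0])[:, None] + jnp.arange(3)[None, :]
    return vmap(lambda v: jnp.sum(padded[v]))(views)


def running_sum(x):
    # kscan-style: lax.scan with the carry updated in place
    def body(carry, i):
        carry = carry.at[i].add(carry[i - 1])
        return carry, carry[i]

    padded = jnp.pad(x, (1, 0))
    return lax.scan(body, padded, jnp.arange(1, x.shape[0] + 1))[1]


# each input appears in 2 or 3 windows, so those are the gradient entries
print(jax.grad(lambda x: jnp.sum(window_sum(x)))(x).tolist())  # [2.0, 3.0, 3.0, 2.0]
# x[i] contributes to the running sums at positions i..3
print(jax.grad(lambda x: jnp.sum(running_sum(x)))(x).tolist())  # [4.0, 3.0, 2.0, 1.0]
```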

How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?

I benchmarked jax.lax.conv_general_dilated_patches and jax.lax.conv_general_dilated against kmap-based equivalents.
The code is under tests_and_benchmarks. In general, kmap seems faster in many scenarios, especially on CPU*; however, it needs more rigorous benchmarking, especially on TPU.
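One hedged way to reproduce this kind of comparison is to time a single stencil written both ways under jax.jit. The sketch below compares a kmap-style vmap over window indices with lax.conv_general_dilated on a 1D window sum. The function names, array size, and repeat count are arbitrary choices; this is not the code in tests_and_benchmarks, and it makes no claim about which approach wins on a given backend.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax, vmap


def window_sum_vmap(x, k):
    # kmap-style: gather every length-k window, then reduce each one
    views = jnp.arange(x.shape[0] - k + 1)[:, None] + jnp.arange(k)[None, :]
    return vmap(lambda v: jnp.sum(x[v]))(views)


def window_sum_conv(x, k):
    # XLA convolution with an all-ones kernel (cross-correlation, no flip);
    # default layout: lhs (N, C, W), rhs (O, I, W)
    out = lax.conv_general_dilated(
        x[None, None, :],
        jnp.ones((1, 1, k), x.dtype),
        window_strides=(1,),
        padding="VALID",
        precision=lax.Precision.HIGHEST,  # avoid reduced precision on TPU
    )
    return out[0, 0]


def bench(fn, *args, repeats=20):
    fn(*args).block_until_ready()  # compile and warm up
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
k = 5
f_vmap = jax.jit(window_sum_vmap, static_argnums=1)
f_conv = jax.jit(window_sum_conv, static_argnums=1)
print(jnp.allclose(f_vmap(x, k), f_conv(x, k), atol=1e-4))  # True
print(f"vmap: {bench(f_vmap, x, k):.2e} s   conv: {bench(f_conv, x, k):.2e} s")
```

`block_until_ready` matters because JAX dispatches asynchronously; without it the timer would mostly measure dispatch rather than computation.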

In general, my prime motivation is to solve PDEs using a stencil definition, which might require applying different functions at different locations of the array (e.g., at the boundary).
This is why kernex offers the ability to combine kmap and kscan with jax.lax.switch to apply different functions to different portions of the array. The following example introduces the function mesh concept, where different stencils are assigned to different array locations by indexing. The backbone of this feature is jax.lax.switch.

Function mesh:

import jax
import jax.numpy as jnp
import kernex as kex

F = kex.kmap(kernel_size=(1,))
F[0] = lambda x: x[0]**2
F[1:] = lambda x: x[0]**3

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125., 216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75., 108., 147., 192., 243., 300.]

Array equivalent:

import jax
import jax.numpy as jnp

def F(x):
    f1 = lambda x: x**2
    f2 = lambda x: x**3
    x = x.at[0].set(f1(x[0]))
    x = x.at[1:].set(f2(x[1:]))
    return x

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125., 216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75., 108., 147., 192., 243., 300.]
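The function-mesh idea can be approximated directly with jax.lax.switch: store one branch index per location and let switch pick the function for each element. The sketch below reproduces the example above without kernex; `branch_index` and `mesh_apply` are hypothetical names, and this is not necessarily how kernex builds its mesh internally.

```python
import jax
import jax.numpy as jnp
from jax import lax, vmap

# one function per "region" of the mesh
branches = (lambda x: x**2, lambda x: x**3)

array = jnp.arange(1, 11).astype("float32")
# hypothetical mesh: branch 0 at position 0, branch 1 at positions 1..9
branch_index = jnp.ones(array.shape, dtype=jnp.int32).at[0].set(0)


def mesh_apply(x):
    # under vmap, lax.switch with a batched index evaluates all branches
    # and selects the right one per element
    return vmap(lambda i, xi: lax.switch(i, branches, xi))(branch_index, x)


print(mesh_apply(array).tolist())
# [1.0, 8.0, 27.0, 64.0, 125.0, 216.0, 343.0, 512.0, 729.0, 1000.0]
print(jax.grad(lambda x: jnp.sum(mesh_apply(x)))(array).tolist())
# [2.0, 12.0, 27.0, 48.0, 75.0, 108.0, 147.0, 192.0, 243.0, 300.0]
```

Because a vmapped switch computes every branch and then selects, the cost of this formulation grows with the number of distinct functions in the mesh; that is a general property of batched lax.switch, not a measured kernex result.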


shoyer commented on May 19, 2024

OK great, thank you for sharing!

I agree that this is a very promising approach for implementing PDE kernels, and in general this is similar to the way I've implemented PDE solvers in JAX by hand (e.g., the wave equation solver).

conv_general_dilated and conv_general_dilated_patches use XLA's Convolution operation, which is really optimized for convolutional neural networks with large numbers of channels. I wouldn't expect them to work well for PDE kernels, except perhaps on TPUs.
