Differentiable Stencil computations in JAX
Installation |Description |Quick example |Function mesh |More Examples |Benchmarking
π οΈ Installation
pip install kernex
π Description
Kernex extends jax.vmap
/jax.lax.map
/jax.pmap
with kmap
and jax.lax.scan
with kscan
for general stencil computations.
The prime motivation for this package is to blend the solution process of PDEs into a NN setting.
β© Quick Example
kmap | kscan |
import kernex as kex
import jax.numpy as jnp
@kex.kmap(kernel_size=(3,))
def sum_all(x):
return jnp.sum(x)
>>> x = jnp.array([1,2,3,4,5])
>>> print(sum_all(x))
[ 6 9 12] |
import kernex as kex
import jax.numpy as jnp
@kex.kscan(kernel_size=(3,))
def sum_all(x):
return jnp.sum(x)
> > > x = jnp.array([1,2,3,4,5])
> > > print(sum_all(x))
> > > [ 6 13 22] |
πΈοΈ Function mesh concept
The objective is to apply f(x) = x^2 at index=0 and f(x) = x^3 at index=(1,10)
To achieve the following operation with jax.lax.switch
, we need a list of 10 functions correspoing to each cell of the example array.
For this reason , kernex adopts a modified version of jax.lax.switch
to reduce the number of branches required.
# function applies x^2 at boundaries, and applies x^3 to to the interior
βββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ
f = β x^2 β x^3 β x^3 β x^3 β x^3 β x^3 β x^3 β x^3 β x^3 β x^3 β
βββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ
βββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ
f( β 1 β 2 β 3 β 4 β 5 β 6 β 7 β 8 β 9 β 10 β ) =
βββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ
βββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ
β 1 β 8 β 27 β 64 β 125 β 216 β 343 β 512 β 729 β1000 β
βββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ
# Gradient of this function
βββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ
df/dx = β 2x β3x^2 β3x^2 β3x^2 β3x^2 β3x^2 β3x^2 β3x^2 β3x^2 β3x^2 β
βββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ
βββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ
df/dx( β 1 β 2 β 3 β 4 β 5 β 6 β 7 β 8 β 9 β 10 β ) =
βββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ
βββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ¬ββββββ
β 2 β 12 β 27 β 48 β 75 β 108 β 147 β 192 β 243 β 300 β
βββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ΄ββββββ
Function mesh | Array equivalent |
F = kex.kmap(kernel_size=(1,))
F[0] = lambda x:x[0]**2
F[1:] = lambda x:x[0]**3
array = jnp.arange(1,11).astype('float32')
print(F(array))
>>> [1., 8., 27., 64., 125.,
... 216., 343., 512., 729., 1000.]
print(jax.grad(lambda x:jnp.sum(F(x)))(array))
>>> [2.,12.,27.,48.,75.,
... 108.,147.,192.,243.,300.] |
def F(x):
f1 = lambda x:x**2
f2 = lambda x:x**3
x = x.at[0].set(f1(x[0]))
x = x.at[1:].set(f2(x[1:]))
return x
array = jnp.arange(1,11).astype('float32')
print(F(array))
>>> [1., 8., 27., 64., 125.,
... 216., 343., 512., 729., 1000.]
print(jax.grad(lambda x: jnp.sum(F(x)))(array))
>>> [2.,12.,27.,48.,75.,
... 108.,147.,192.,243.,300.] |
Additionally , we can combine the function mesh concept with stencil computation for scientific computing. See Linear convection in More examples section
π’ More examples
1οΈβ£ Convolution operation
import jax
import jax.numpy as jnp
import kernex as kex
@jax.jit
@kex.kmap(
kernel_size= (3,3,3),
padding = ('valid','same','same'))
def kernex_conv2d(x,w):
# JAX channel first conv2d with 3x3x3 kernel_size
return jnp.sum(x*w)
2οΈβ£ Laplacian operation
# see also
# https://numba.pydata.org/numba-doc/latest/user/stencil.html#basic-usage
import jax
import jax.numpy as jnp
import kernex as kex
@kex.kmap(
kernel_size=(3,3),
padding= 'valid',
relative=True) # `relative`= True enables relative indexing
def laplacian(x):
return ( 0*x[1,-1] + 1*x[1,0] + 0*x[1,1] +
1*x[0,-1] +-4*x[0,0] + 1*x[0,1] +
0*x[-1,-1] + 1*x[-1,0] + 0*x[-1,1] )
# apply laplacian
>>> print(laplacian(jnp.ones([10,10])))
DeviceArray(
[[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)
3οΈβ£ Get Patches of an array
import jax
import jax.numpy as jnp
import kernex as kex
@kex.kmap(kernel_size=(3,3),relative=True)
def identity(x):
# similar to numba.stencil
# this function returns the top left cell in the padded/unpadded kernel view
# or center cell if `relative`=True
return x[0,0]
# unlike numba.stencil , vector output is allowed in kernex
# this function is similar to
# `jax.lax.conv_general_dilated_patches(x,(3,),(1,),padding='same')`
@jax.jit
@kex.kmap(kernel_size=(3,3),padding='same')
def get_3x3_patches(x):
# returns 5x5x3x3 array
return x
mat = jnp.arange(1,26).reshape(5,5)
>>> print(mat)
[[ 1 2 3 4 5]
[ 6 7 8 9 10]
[11 12 13 14 15]
[16 17 18 19 20]
[21 22 23 24 25]]
# get the view at array index = (0,0)
>>> print(get_3x3_patches(mat)[0,0])
[[0 0 0]
[0 1 2]
[0 6 7]]
4οΈβ£ Linear convection
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt
# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb
tmax,xmax = 0.5,2.0
nt,nx = 151,51
dt,dx = tmax/(nt-1) , xmax/(nx-1)
u = jnp.ones([nt,nx])
c = 0.5
# kscan moves sequentially in row-major order and updates in-place using lax.scan.
F = kernex.kscan(
kernel_size = (3,3),
padding = ((1,1),(1,1)),
named_axis={0:'n',1:'i'}, # n for time axis , i for spatial axis (optional naming)
relative=True
)
# boundary condtion as a function
def bc(u):
return 1
# initial condtion as a function
def ic1(u):
return 1
def ic2(u):
return 2
def linear_convection(u):
return ( u['i','n-1'] - (c*dt/dx) * (u['i','n-1'] - u['i-1','n-1']) )
F[:,0] = F[:,-1] = bc # assign 1 for left and right boundary for all t
# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2
# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0 [1:]
F[1:,1:-1] = linear_convection
kx_solution = F(jnp.array(u))
plt.figure(figsize=(20,7))
for line in kx_solution[::20]:
plt.plot(jnp.linspace(0,xmax,nx),line)
5οΈβ£ Gaussian blur
import jax
import jax.numpy as jnp
import kernex as kex
def gaussian_blur(image, sigma, kernel_size):
x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size- 1) / 2.0, kernel_size)
w = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(sigma))
w = jnp.outer(w, w)
w = w / w.sum()
@kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
def conv(x):
return jnp.sum(x * w)
return conv(image)
6οΈβ£ Depthwise convolution
import jax
import jax.numpy as jnp
import kernex as kex
@jax.jit
@jax.vmap
@kex.kmap(
kernel_size= (3,3),
padding = ('same','same'))
def kernex_depthwise_conv2d(x,w): # Channel-first depthwise convolution # jax.debug.print("x=\n{a}\nw=\n{b} \n\n",a=x, b=w)
return jnp.sum(x\*w)
h,w,c = 5,5,2
k=3
x = jnp.arange(1,h*w*c+1).reshape(c,h,w)
w = jnp.arange(1,k*k*c+1).reshape(c,k,k)
print(kernex_depthwise_conv2d(x,w))</summary>
7οΈβ£ Maxpooling2D and Averagepooling2D
@jax.vmap # vectorize over the channel dimension
@kex.kmap(kernel_size=(3,3), strides=(2,2))
def maxpool_2d(x):
# define the kernel for the Maxpool operation over the spatial dimensions
return jnp.max(x)
@jax.vmap # vectorize over the channel dimension
@kex.kmap(kernel_size=(3,3), strides=(2,2))
def avgpool_2d(x):
# define the kernel for the Average pool operation over the spatial dimensions
return jnp.mean(x)
8οΈβ£ Runge-Kutta integration
# lets solve dydt = y, where y0 = 1 and y(t)=e^t
# using Runge-Kutta 4th order method
# f(t,y) = y
import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex as kex
t = jnp.linspace(0, 1, 5)
y = jnp.zeros(5)
x = jnp.stack([y, t], axis=0)
dt = t[1] - t[0] # 0.1
f = lambda tn, yn: yn
def ic(x):
""" initial condition y0 = 1 """
return 1.
def rk4(x):
""" runge kutta 4th order integration step """
# ββββββ¬βββββ¬βββββ ββββββββ¬βββββββ¬βββββββ
# β y0 β*y1*β y2 β β[0,-1]β[0, 0]β[0, 1]β
# ββββββΌβββββΌβββββ€ ==> ββββββββΌβββββββΌβββββββ€
# β t0 β t1 β t2 β β[1,-1]β[1, 0]β[1, 1]β
# ββββββ΄βββββ΄βββββ ββββββββ΄βββββββ΄βββββββ
t0 = x[1, -1]
y0 = x[0, -1]
k1 = dt * f(t0, y0)
k2 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k1)
k3 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k2)
k4 = dt * f(t0 + dt, y0 + k3)
yn_1 = y0 + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return yn_1
F = kex.kscan(kernel_size=(2, 3), relative=True, padding=((0, 1))) # kernel size = 3
F[0:1, 1:] = rk4
F[0, 0] = ic
# compile the solver
solver = jax.jit(F.__call__)
y = solver(x)[0, :]
plt.plot(t, y, '-o', label='rk4')
plt.plot(t, jnp.exp(t), '-o', label='analytical')
plt.legend()