jaxpole

A differentiable implementation of an all-pole filter in JAX


Keywords
nbdev, jupyter, notebook, python
License
Apache-2.0
Install
pip install jaxpole==0.0.3

Documentation

jaxpole

This is an implementation of a differentiable time-varying all-pole filter in JAX based on torchlpc.

Install

pip install jaxpole

or locally from source

pip install -e '.[dev]'

How to use

import jax.numpy as jnp
import jax

pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])

x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000)) # (B, T)
A = jnp.tile(coeffs, (1, x.shape[-1], 1)) # (B, T, P)
zi = jnp.zeros((1, 2)) # (B, P)

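# filter the signal (allpole comes from the jaxpole package and must be
# imported first; the import line is not shown in this snippet)
y = allpole(x, A, zi)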
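
To make the shapes and the recursion concrete, here is a minimal reference sketch of a time-varying all-pole filter written with jax.lax.scan. It is not jaxpole's implementation. The function name allpole_reference is hypothetical, and the sketch assumes the torchlpc convention y[t] = x[t] - sum_i A[t, i] * y[t - i], with zi[:, 0] holding y[-1] and zi[:, 1] holding y[-2]. The state ordering in jaxpole itself may differ.

```python
import jax
import jax.numpy as jnp


def allpole_reference(x, A, zi):
    """Naive time-varying all-pole filter (hypothetical helper, not part of jaxpole).

    x: (B, T) input, A: (B, T, P) coefficients, zi: (B, P) initial state.
    Assumed state ordering: zi[:, 0] = y[-1], zi[:, 1] = y[-2], ...
    """

    def step(state, inputs):
        x_t, a_t = inputs  # shapes (B,) and (B, P)
        # y[t] = x[t] - sum_i a[t, i] * y[t - i]
        y_t = x_t - jnp.sum(a_t * state, axis=-1)
        # shift the state so the newest output sits at index 0
        new_state = jnp.concatenate([y_t[:, None], state[:, :-1]], axis=-1)
        return new_state, y_t

    # lax.scan iterates over the leading axis, so move time to the front
    xs = (jnp.swapaxes(x, 0, 1), jnp.swapaxes(A, 0, 1))  # (T, B), (T, B, P)
    _, y = jax.lax.scan(step, zi, xs)
    return jnp.swapaxes(y, 0, 1)  # back to (B, T)


# same setup as the usage example above
pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])

x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))  # (B, T)
A = jnp.tile(coeffs, (1, x.shape[-1], 1))  # (B, T, P)
zi = jnp.zeros((1, 2))  # (B, P)

y_ref = allpole_reference(x, A, zi)


# the recursion is differentiable end to end, e.g. with respect to the coefficients
def loss(c):
    A_c = jnp.tile(c, (1, x.shape[-1], 1))
    return jnp.mean(allpole_reference(x, A_c, zi) ** 2)


grad_coeffs = jax.grad(loss)(coeffs)  # shape (P,)
```

A sequential scan like this is easy to check by hand but slow for long signals, so it is best used as a correctness baseline rather than as a replacement for the library's filter.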