This is a JAX implementation of a differentiable, time-varying all-pole filter, based on torchlpc.
pip install jaxpole
Or, to install locally from source (with the development extras):
pip install -e '.[dev]'
import jax
import jax.numpy as jnp

# `allpole` is not defined anywhere in this example, so it must be imported
# from the package installed above; without this line the snippet raises
# NameError on the final call.
from jaxpole import allpole

# A resonant pole pair: radius 0.99 (inside the unit circle) at angle pi/4 rad.
pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
# Second-order coefficients [-2 Re(p), |p|^2], i.e. the expansion of
# (1 - p z^-1)(1 - conj(p) z^-1) = 1 - 2 Re(p) z^-1 + |p|^2 z^-2.
# NOTE(review): this places poles at `pole` and its conjugate assuming
# `allpole` uses the torchlpc convention y[t] = x[t] - sum_i a_i * y[t-i];
# confirm against the `allpole` docstring.
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])

x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))  # (B, T)
# jnp.tile promotes `coeffs` (shape (2,)) to (1, 1, 2) and repeats it T times
# along the middle axis -> (1, T, 2): the same coefficients at every step.
# A genuinely time-varying filter would supply different values per step.
A = jnp.tile(coeffs, (1, x.shape[-1], 1))  # (B, T, P)
# All-zero initial state; presumably the filter's past outputs — confirm.
zi = jnp.zeros((1, 2))  # (B, P)

# filter the signal
y = allpole(x, A, zi)