JAX ResNet - Implementations and Checkpoints for ResNet Variants
A Flax (Linen) implementation of ResNet (He et al. 2015), Wide ResNet (Zagoruyko & Komodakis 2016), ResNeXt (Xie et al. 2017), ResNet-D (He et al. 2019), and ResNeSt (Zhang et al. 2020). The code is modular so you can mix and match the various stem, residual, and bottleneck implementations.
You can install this package from PyPI:
pip install jax-resnet
Or directly from GitHub:
pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git
See the bottom of jax-resnet/resnet.py for the available aliases/options for the ResNet variants (all models are in Flax).
Pretrained checkpoints from torch.hub are available for the following networks:
- ResNet [18, 34, 50, 101, 152]
- WideResNet [50, 101]
- ResNeXt [50, 101]
- ResNeSt [50-Fast, 50, 101, 200, 269]
The models are tested to have the same intermediate activations and outputs as the torch.hub implementations, except ResNeSt-50 Fast, whose activations don't match exactly but whose final accuracy does.
A pretrained checkpoint for ResNetD-50 is available from fast.ai. The activations do not match exactly, but the final accuracy matches.
import jax.numpy as jnp
from jax_resnet import pretrained_resnest

ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
out = model.apply(variables,
                  jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                  mutable=False)  # Ensure `batch_stats` aren't updated.
You must install PyTorch yourself (see the official PyTorch installation instructions) to use these functions.
When you extract a subset of the model, the slice_variables function lets you extract the corresponding subset of the variables dict. Check out its docstring for more information.
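To make the idea of slicing a variables dict concrete, here is a minimal pure-Python sketch. It is not the library's slice_variables implementation. The helper name `slice_layers`, the `layers_<i>` key naming, and the `params`/`batch_stats` collections are assumptions chosen for illustration; consult the slice_variables docstring for the actual behaviour.

```python
import re


def slice_layers(variables, start, end):
    """Hypothetical helper: keep entries named layers_<i> with
    start <= i < end in every collection, renumbered from zero."""
    pattern = re.compile(r"^layers_(\d+)$")
    sliced = {}
    for collection, tree in variables.items():
        kept = {}
        for name, value in tree.items():
            match = pattern.match(name)
            if match is None:
                continue  # Ignore entries that are not per-layer.
            index = int(match.group(1))
            if start <= index < end:
                # Renumber so the sliced model's first layer is layers_0.
                kept[f"layers_{index - start}"] = value
        sliced[collection] = kept
    return sliced


# Toy variables dict: five layers with params, three of them with batch stats.
variables = {
    "params": {f"layers_{i}": {"kernel": f"w{i}"} for i in range(5)},
    "batch_stats": {f"layers_{i}": {"mean": f"m{i}"} for i in (0, 2, 4)},
}
backbone = slice_layers(variables, 1, 4)
```

In this toy example, `backbone["params"]` holds the entries for original layers 1 to 3 under the new names `layers_0` to `layers_2`, and `backbone["batch_stats"]` keeps only original layer 2, renamed `layers_1`. The renumbering is there so that the sliced dict lines up with a model built from only those layers.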
The top 1 and top 5 accuracies reported below are on the ImageNet2012 validation split. The data was preprocessed as in the official PyTorch example.
| Model | Size | Top 1 | Top 5 |
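For readers unfamiliar with the metric, the sketch below shows how top-k accuracy is commonly computed: a prediction counts as correct when the true label is among the k highest-scoring classes. The function name `top_k_accuracy` and the toy scores are illustrative assumptions, not part of this package, and the example uses three classes with k = 2 rather than ImageNet's 1000 classes with k = 5.

```python
def top_k_accuracy(logits, labels, k):
    """Fraction of examples whose true label is among the k top-scoring classes."""
    hits = 0
    for scores, label in zip(logits, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


# Toy scores for four examples over three classes.
logits = [
    [0.1, 0.7, 0.2],  # true label 1 is ranked first
    [0.5, 0.3, 0.2],  # true label 1 is ranked second
    [0.6, 0.3, 0.1],  # true label 2 is ranked third
    [0.2, 0.1, 0.7],  # true label 2 is ranked first
]
labels = [1, 1, 2, 2]

top1 = top_k_accuracy(logits, labels, k=1)
top2 = top_k_accuracy(logits, labels, k=2)
```

Here two of the four examples are correct at k = 1 and three of the four at k = 2, so `top1` is 0.5 and `top2` is 0.75.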
The ResNeSt validation data was preprocessed as in zhanghang1989/ResNeSt.
| Model | Size | Crop Size | Top 1 | Top 5 |
- Deep Residual Learning for Image Recognition. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. arXiv 2015.
- Wide Residual Networks. Sergey Zagoruyko, Nikos Komodakis. BMVC 2016.
- Aggregated Residual Transformations for Deep Neural Networks. Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He. CVPR 2017.
- Bag of Tricks for Image Classification with Convolutional Neural Networks. Tong He, Zhi Zhang, Hang Zhang, Zhongyue Zhang, Junyuan Xie, Mu Li. CVPR 2019.
- ResNeSt: Split-Attention Networks. Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Mueller, R. Manmatha, Mu Li, Alexander Smola. arXiv 2020.