Two modules are quite slow: img2col.py in nn/_functions/ and dataloader.py in utils/data/. Parts of img2col use numpy instead of jax.numpy, and the dataloader is implemented in pure Python.
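Inside jax.jit, numpy functions cannot operate on traced arrays, and outside jit they pull data back to the host. The sketch below shows one way to write an im2col with jax.numpy only, so the whole function can be traced and compiled. The function name `im2col`, the NCHW input layout, and the single `stride`/`padding` value shared by both spatial axes are assumptions for illustration, not the actual signature in nn/_functions/img2col.py. The output layout (N, C*kh*kw, out_h*out_w) follows the convention of torch.nn.functional.unfold.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical jax.numpy-only im2col; kernel size, stride and padding must be
# static so that all shapes are known at trace time.
@partial(jax.jit, static_argnames=("kh", "kw", "stride", "padding"))
def im2col(x, kh, kw, stride=1, padding=0):
    """Unfold an NCHW batch into columns of shape (N, C*kh*kw, out_h*out_w)."""
    n, c, h, w = x.shape
    # Zero-pad only the two spatial dimensions.
    x = jnp.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    out_h = (h + 2 * padding - kh) // stride + 1
    out_w = (w + 2 * padding - kw) // stride + 1

    # For each kernel offset (i, j), one strided slice picks that pixel from
    # every output window at once: shape (N, C, out_h, out_w).
    # This Python loop runs only at trace time (kh * kw iterations).
    patches = [
        x[:, :, i : i + stride * out_h : stride, j : j + stride * out_w : stride]
        for i in range(kh)
        for j in range(kw)
    ]
    # (N, C, kh*kw, out_h, out_w): channel-major, then kernel position row-major.
    cols = jnp.stack(patches, axis=2)
    return cols.reshape(n, c * kh * kw, out_h * out_w)


x = jnp.ones((2, 3, 5, 5))
print(im2col(x, 3, 3, stride=2, padding=1).shape)  # (2, 27, 9)
```

Because the loop only covers the kh*kw kernel offsets, jit unrolls it into a handful of strided slices rather than looping over output pixels. For large kernels, jax.lax.conv_general_dilated_patches may be worth comparing against this approach.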