WaveMix: Multi-Resolution Token Mixing for Images

29 Sep 2021 · Pranav Jeevan P, Amit Sethi

Even though vision transformers (ViTs) have provided state-of-the-art results on image classification, their requirements for large amounts of training data, large model sizes, and heavy GPU usage have put them out of reach of most computer vision practitioners. We present WaveMix as an alternative to the self-attention mechanism of ViTs and to convolutional neural networks (CNNs) that significantly reduces computational cost and memory footprint without compromising image classification accuracy. WaveMix uses a multi-level two-dimensional discrete wavelet transform (2-D DWT) to mix tokens and aggregate multi-resolution pixel information over long distances, which gives it the following advantages. First, unlike the self-attention mechanism of ViT, WaveMix does not unroll the image into a sequence. It therefore has the right inductive bias to exploit the 2-D structure of an image, which reduces the demand for large training data. It also eliminates the quadratic complexity of self-attention with respect to sequence length. Second, due to its multi-resolution token mixing, WaveMix requires far fewer layers than a CNN for comparable accuracy. Preliminary results from our supervised-learning experiments on the CIFAR-10 dataset show that a four-layer WaveMix model can be 37% more accurate than a ViT with a comparable number of parameters, while consuming only 3% of the latter's GPU RAM and memory. This model also outperforms efficient transformers and models not based on attention, such as FNet and MLP-Mixer. A scaled-up WaveMix model reaching a top-1 accuracy of over 85% on CIFAR-10 could be trained on a 16 GB GPU while consuming only 6% of the GPU RAM used by the largest ViT that could fit in that GPU. Our work suggests that research on model structures that exploit the right inductive bias is far from over, and that such models can enable the training of computer vision models in settings with limited GPU resources.
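The abstract names only a "multi-level two-dimensional discrete wavelet transform", so the following is a minimal NumPy sketch that assumes an orthonormal Haar wavelet. `haar_dwt2` and `multilevel_haar` (hypothetical helper names) show how each level halves the height and width while keeping the image in 2-D, splitting every 2x2 patch into one average and three detail values; after J levels each approximation value summarizes a 2^J x 2^J region, which is how information is aggregated over long distances without unrolling the image. `wavelet_mix_sketch` is a hypothetical token-mixing step, not the authors' layer: the channel reduction, single linear channel mix, ReLU, nearest-neighbor upsampling, and residual connection are all assumptions made for illustration.

```python
import numpy as np


def haar_dwt2(x):
    """One level of an orthonormal 2-D Haar DWT on a (C, H, W) array.

    H and W must be even. Returns four subbands, each of shape
    (C, H/2, W/2): the approximation LL and three detail bands.
    The Haar wavelet is an assumption; the abstract does not name one.
    """
    a = x[:, 0::2, 0::2]  # top-left pixel of every 2x2 block
    b = x[:, 0::2, 1::2]  # top-right
    c = x[:, 1::2, 0::2]  # bottom-left
    d = x[:, 1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # scaled local average
    lh = (a + b - c - d) / 2  # top row minus bottom row
    hl = (a - b + c - d) / 2  # left column minus right column
    hh = (a - b - c + d) / 2  # diagonal difference
    return ll, lh, hl, hh


def multilevel_haar(x, levels):
    """Re-apply haar_dwt2 to the LL band `levels` times.

    Returns the final LL band and a list of (LH, HL, HH) tuples, finest
    level first. After j levels, each LL value summarizes a 2**j x 2**j patch.
    """
    ll, details = x, []
    for _ in range(levels):
        ll, lh, hl, hh = haar_dwt2(ll)
        details.append((lh, hl, hh))
    return ll, details


def wavelet_mix_sketch(x, w_reduce, w_mix):
    """Hypothetical single-level wavelet token-mixing step (not the paper's layer).

    x: (C, H, W); w_reduce: (C // 4, C); w_mix: (C, C).
    """
    # Assumed 1x1 channel reduction so the 4 subbands concatenate back to C channels.
    y = np.einsum("oc,chw->ohw", w_reduce, x)
    # Spatial mixing: each output position now depends on a 2x2 input patch.
    bands = np.concatenate(haar_dwt2(y), axis=0)  # (C, H/2, W/2)
    # Assumed channel mix + ReLU, standing in for an MLP over channels.
    z = np.maximum(np.einsum("oc,chw->ohw", w_mix, bands), 0.0)
    # Nearest-neighbor upsampling back to (C, H, W); an assumed simplification.
    up = z.repeat(2, axis=1).repeat(2, axis=2)
    return x + up  # residual connection (assumed)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = rng.standard_normal((8, 32, 32))  # C=8 channels, 32x32 pixels
    ll, details = multilevel_haar(img, levels=3)
    print(ll.shape)  # (8, 4, 4): each value covers an 8x8 pixel patch
    out = wavelet_mix_sketch(img, rng.standard_normal((2, 8)), rng.standard_normal((8, 8)))
    print(out.shape)  # (8, 32, 32)
```

Because each level visits every pixel a constant number of times (and each subsequent level works on a quarter as many pixels), the cost of the transform grows linearly with H·W, whereas self-attention over N = H·W unrolled tokens computes N² pairwise scores.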


Results from the Paper


Task                  Dataset   Model    Metric Name         Metric Value  Global Rank
Image Classification  CIFAR-10  WaveMix  Percentage correct  85.21         #203

Methods