Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding

26 May 2021  ·  Zizhao Zhang, Han Zhang, Long Zhao, Ting Chen, Sercan O. Arik, Tomas Pfister

Hierarchical structures are popular in recent vision transformers; however, they require sophisticated designs and massive datasets to work well. In this paper, we explore the idea of nesting basic local transformers on non-overlapping image blocks and aggregating them in a hierarchical way. We find that the block aggregation function plays a critical role in enabling cross-block non-local information communication. This observation leads us to design a simplified architecture that requires minor code changes upon the original vision transformer. The benefits of the proposed judiciously-selected design are threefold: (1) NesT converges faster and requires much less training data to achieve good generalization on both ImageNet and small datasets like CIFAR; (2) when extending our key ideas to image generation, NesT leads to a strong decoder that is 8$\times$ faster than previous transformer-based generators; and (3) we show that decoupling the feature learning and abstraction processes via this nested hierarchy in our design enables constructing a novel method (named GradCAT) for visually interpreting the learned model. Source code is available at https://github.com/google-research/nested-transformer.
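To make the abstract's core idea concrete, below is a minimal sketch of the two ingredients it highlights: local self-attention applied independently inside non-overlapping image blocks, and a block aggregation step that merges neighbouring blocks so information can cross block boundaries at the next hierarchy level. The sketch is written in JAX; the single-head attention, the 2x2 max-pooling used for aggregation, and all shapes are illustrative assumptions rather than the authors' exact implementation, which is in the linked repository.

```python
# Hedged sketch of nested local attention + block aggregation.
# Not the paper's implementation; hyper-parameters are illustrative.
import jax
import jax.numpy as jnp


def local_self_attention(x):
    """Single-head self-attention applied independently to each block.

    x: (num_blocks, tokens_per_block, dim) -- attention never crosses blocks.
    """
    d = x.shape[-1]
    attn = jax.nn.softmax(jnp.einsum('bnd,bmd->bnm', x, x) / jnp.sqrt(d), axis=-1)
    return jnp.einsum('bnm,bmd->bnd', attn, x)


def blockify(feat, block):
    """Split a (H, W, D) feature map into non-overlapping block x block windows."""
    H, W, D = feat.shape
    feat = feat.reshape(H // block, block, W // block, block, D)
    feat = feat.transpose(0, 2, 1, 3, 4)           # (H/b, W/b, b, b, D)
    return feat.reshape(-1, block * block, D)      # (num_blocks, tokens, D)


def unblockify(blocks, H, W, block):
    """Inverse of blockify: back to a (H, W, D) feature map."""
    D = blocks.shape[-1]
    feat = blocks.reshape(H // block, W // block, block, block, D)
    feat = feat.transpose(0, 2, 1, 3, 4)
    return feat.reshape(H, W, D)


def aggregate(feat):
    """Block aggregation: 2x2 max-pool so each new block summarizes four old ones.

    This is the cross-block communication step; max-pooling here is a stand-in
    for the learned aggregation used in the paper.
    """
    H, W, D = feat.shape
    return feat.reshape(H // 2, 2, W // 2, 2, D).max(axis=(1, 3))


def nested_hierarchy(feat, block=4, levels=2):
    """Alternate per-block attention and block aggregation over a few levels."""
    for _ in range(levels):
        H, W, _ = feat.shape
        x = blockify(feat, block)
        x = local_self_attention(x)
        feat = aggregate(unblockify(x, H, W, block))
    return feat


if __name__ == '__main__':
    img_feat = jnp.ones((16, 16, 32))   # toy (H, W, D) feature map
    print(nested_hierarchy(img_feat).shape)   # (4, 4, 32)
```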


Results from the Paper


Task                  Dataset    Model                                  Metric               Value   Global Rank
Image Classification  CIFAR-10   Transformer local-attention (NesT-B)   Percentage correct   97.2    #84
Image Classification  CIFAR-10   Transformer local-attention (NesT-B)   Params               90.1M   #236
Image Classification  CIFAR-10   Transformer local-attention (NesT-B)   Top-1 Accuracy       97.2    #19
Image Classification  CIFAR-10   Transformer local-attention (NesT-B)   Parameters           90.1M   #2
Image Classification  CIFAR-100  Transformer local-attention (NesT-B)   Percentage correct   82.56   #101
Image Classification  ImageNet   Transformer local-attention (NesT-B)   Top-1 Accuracy       83.8%   #358
Image Classification  ImageNet   Transformer local-attention (NesT-B)   Number of params     68M     #785
Image Classification  ImageNet   Transformer local-attention (NesT-B)   GFLOPs               17.9    #357
Image Classification  ImageNet   Transformer local-attention (NesT-S)   Top-1 Accuracy       83.3%   #403
Image Classification  ImageNet   Transformer local-attention (NesT-S)   Number of params     38M     #663
Image Classification  ImageNet   Transformer local-attention (NesT-S)   GFLOPs               10.4    #301
Image Classification  ImageNet   Transformer local-attention (NesT-T)   Top-1 Accuracy       81.5%   #577
Image Classification  ImageNet   Transformer local-attention (NesT-T)   Number of params     17M     #521
Image Classification  ImageNet   Transformer local-attention (NesT-T)   GFLOPs               5.8     #239