Heavy-Tailed Class Imbalance and Why Adam Outperforms Gradient Descent on Language Models

29 Feb 2024 · Frederik Kunstner, Robin Yadav, Alan Milligan, Mark Schmidt, Alberto Bietti

Adam has been shown empirically to outperform gradient descent when optimizing large language transformers, and by a larger margin than on other tasks, but it is unclear why this happens. We show that the heavy-tailed class imbalance found in language modeling tasks leads to difficulties in the optimization dynamics. When training with gradient descent, the loss associated with infrequent words decreases more slowly than the loss associated with frequent ones. As most samples come from relatively infrequent words, the average loss decreases slowly with gradient descent. On the other hand, Adam and sign-based methods do not suffer from this problem and improve predictions on all classes. To establish that this behavior is indeed caused by class imbalance, we show empirically that it persists across different architectures and data types: language transformers, vision CNNs, and linear models. We further study this phenomenon on a linear classification problem with cross-entropy loss, showing that heavy-tailed class imbalance leads to ill-conditioning, and that the normalization used by Adam can counteract it.
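The mechanism described in the abstract can be illustrated with a small toy problem. The sketch below is not the paper's experimental setup; it is a minimal illustration under stated assumptions: each class has a single one-hot input, class frequencies follow a 1/k (Zipf-like) law, the model is a linear softmax classifier, and plain sign descent stands in for Adam's normalization. The function names, step sizes, and step counts are hypothetical choices made for illustration.

```python
import numpy as np

def softmax_columns(W):
    # Column k of W holds the logits the linear model assigns to class k's input.
    Z = W - W.max(axis=0, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=0, keepdims=True)

def per_class_loss(W):
    # Cross-entropy on each class's own input (the diagonal of the softmax matrix).
    return -np.log(np.diag(softmax_columns(W)))

def gradient(W, freq):
    # Gradient of the frequency-weighted average loss.
    # Column k equals freq[k] * (softmax(W[:, k]) - e_k), so rare classes
    # receive proportionally smaller gradients.
    return (softmax_columns(W) - np.eye(len(freq))) * freq

def train(freq, steps, lr, use_sign):
    W = np.zeros((len(freq), len(freq)))
    for _ in range(steps):
        g = gradient(W, freq)
        # Sign descent discards the gradient magnitude, and with it the
        # dependence on class frequency (assumed stand-in for Adam).
        W -= lr * (np.sign(g) if use_sign else g)
    return per_class_loss(W)

# Assumed heavy-tailed class frequencies: freq[k] proportional to 1 / (k + 1).
num_classes = 50
freq = 1.0 / np.arange(1, num_classes + 1)
freq /= freq.sum()

gd_loss = train(freq, steps=100, lr=1.0, use_sign=False)
sign_loss = train(freq, steps=100, lr=0.05, use_sign=True)

print("GD   loss, most / least frequent class:", gd_loss[0], gd_loss[-1])
print("Sign loss, most / least frequent class:", sign_loss[0], sign_loss[-1])
print("Average loss, GD vs sign:", freq @ gd_loss, freq @ sign_loss)
```

In this toy setting, gradient descent moves each class's weights at a rate proportional to that class's frequency, so the rarest class has barely improved by the time the most frequent one is nearly solved. Sign descent applies the same step size to every class regardless of frequency. The step sizes here are not tuned against each other, so the comparison is qualitative only.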
