Transformers learn to implement preconditioned gradient descent for in-context learning
Motivated by the striking ability of transformers for in-context learning, several works demonstrate that transformers can implement algorithms like gradient descent. By a careful construction of weights, these works show that multiple layers of transformers are expressive enough to simulate gradient descent iterations. Going beyond the question of expressivity, we ask: \emph{Can transformers learn to implement such algorithms by training over random problem instances?} To our knowledge, we make the first theoretical progress toward this question via an analysis of the loss landscape for linear transformers trained over random instances of linear regression. For a single attention layer, we prove that the global minimum of the training objective implements a single iteration of preconditioned gradient descent. Notably, the preconditioning matrix adapts not only to the input distribution but also to the variance induced by data inadequacy. For a transformer with $k$ attention layers, we prove that certain critical points of the training objective implement $k$ iterations of preconditioned gradient descent. Our results call for future theoretical studies on learning algorithms by training transformers.
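To make the single-layer statement concrete, the following is a minimal sketch, not the paper's construction. It assumes noiseless in-context linear regression, a zero initial iterate, and a softmax-free (linear) attention readout. The function names `preconditioned_gd_step` and `linear_attention_prediction`, and the choice of preconditioner `A` as the inverse input covariance, are illustrative assumptions; the preconditioner found at the global minimum may differ. The sketch only checks the algebraic identity that one preconditioned gradient step from zero, evaluated at a query point, equals a linear attention readout over the context.

```python
import numpy as np


def preconditioned_gd_step(X, y, A, w0=None):
    """One preconditioned GD step on L(w) = 1/(2n) * ||X w - y||^2."""
    n, d = X.shape
    w = np.zeros(d) if w0 is None else w0
    grad = X.T @ (X @ w - y) / n   # gradient of the least-squares loss at w
    return w - A @ grad            # A plays the role of the preconditioner


def linear_attention_prediction(X, y, x_query, A):
    """Softmax-free attention readout: (1/n) * sum_i y_i * <x_query, A x_i>."""
    scores = X @ (A.T @ x_query)   # entry i equals x_query^T A x_i
    return scores @ y / X.shape[0]


# Hypothetical problem instance (all sizes and distributions are assumptions).
rng = np.random.default_rng(0)
n, d = 20, 5
Sigma = np.diag(np.linspace(0.5, 2.0, d))               # assumed input covariance
X = rng.normal(size=(n, d)) @ np.sqrt(Sigma)             # context inputs
w_star = rng.normal(size=d)                              # ground-truth weights
y = X @ w_star                                           # noiseless labels
x_query = rng.normal(size=d) @ np.sqrt(Sigma)            # query input

A = np.linalg.inv(Sigma)                                 # illustrative preconditioner
w1 = preconditioned_gd_step(X, y, A)                     # one step from w0 = 0
pred_gd = x_query @ w1
pred_attn = linear_attention_prediction(X, y, x_query, A)

# The two predictions agree up to floating-point error.
print(np.isclose(pred_gd, pred_attn))
```

Starting from zero, the step reduces to $w_1 = \frac{1}{n} A X^\top y$, so the prediction $x_q^\top w_1$ is a sum over context pairs of $y_i \, x_q^\top A x_i$ scaled by $1/n$, which is the form a single linear attention layer can compute.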