Although the power and successes of transformer architectures have been well documented by the machine learning research community in recent years, the literature still lacks rigorous theoretical analysis of transformer networks and interpretations of the functions they learn.
In the new paper Convexifying Transformers: Improving Optimization and Understanding of Transformer Networks, a Stanford University and Google Research team provides a solid theoretical analysis of transformers’ fundamental mechanisms and introduces a novel convex analytic training framework for improving their optimization.
The team summarizes their main contributions as follows:
- We propose an alternative formulation to the standard self-attention mechanism and study the regularized training problem of attention/transformer networks with it.
- We convexify the regularized training problem of attention/transformer networks with the proposed attention layer and therefore enable finding a globally optimal solution without requiring any nonconvex optimization heuristic, e.g., layer normalization and skip connections.
- We also apply our convex analytic framework to various architectures, e.g., networks with or without an FCN layer. Thus, we are able to explain the impact of each component on the models learned throughout training.
- We reveal an implicit regularization mechanism induced by our attention mechanism. We further characterize this regularization as a sparsity-inducing factor across tokens.
- We demonstrate the effectiveness of our convex reformulation via various experimental results. We also show that our reformulation significantly mitigates the grokking phenomenon studied in recent papers (Power et al., 2022; Thilak et al., 2022).
The team first proposes a convex alternative to transformers’ self-attention mechanism and reformulates model training as a convex optimization problem. The proposed convex reformulation offers several benefits: it enables researchers to globally optimize their network parameters without nonconvex optimization heuristics, makes the learned functions transparent and interpretable, and provides insights into the structures of the resulting functions and their generalization properties.
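To give a feel for why convexity matters here, the following is a minimal sketch of a generic convex training problem, not the paper's actual reformulation. It assumes a least-squares loss with a row-wise (group) norm penalty, a standard sparsity-inducing regularizer, and solves it by proximal gradient descent from two very different initializations. All names (`A`, `Y`, `beta`, `row_soft_threshold`, `proximal_gradient`) are hypothetical and chosen for illustration only.

```python
import numpy as np

def row_soft_threshold(W, tau):
    """Proximal operator of tau * sum_i ||W[i]||_2 (shrinks whole rows toward zero)."""
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    # Rows whose norm is at most tau are set exactly to zero.
    scale = np.maximum(0.0, 1.0 - tau / np.maximum(norms, 1e-12))
    return scale * W

def objective(W, A, Y, beta):
    """Convex objective: squared loss plus a row-wise group-norm penalty."""
    return 0.5 * np.sum((A @ W - Y) ** 2) + beta * np.sum(np.linalg.norm(W, axis=1))

def proximal_gradient(A, Y, beta, W0, n_iters=1000):
    """Proximal gradient descent with a fixed step of 1/L."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2  # L = squared spectral norm of A
    W = W0.copy()
    for _ in range(n_iters):
        grad = A.T @ (A @ W - Y)  # gradient of the smooth squared-loss term
        W = row_soft_threshold(W - step * grad, step * beta)
    return W

# Synthetic data (an assumption for this sketch): only 2 of 8 rows are active.
rng = np.random.default_rng(0)
A = rng.normal(size=(60, 8))
W_true = np.zeros((8, 3))
W_true[:2] = rng.normal(size=(2, 3))
Y = A @ W_true + 0.01 * rng.normal(size=(60, 3))
beta = 1.0

# Two unrelated starting points, one of them far from the origin.
W_a = proximal_gradient(A, Y, beta, rng.normal(size=(8, 3)))
W_b = proximal_gradient(A, Y, beta, 10.0 * rng.normal(size=(8, 3)))

print("objective from init A:", objective(W_a, A, Y, beta))
print("objective from init B:", objective(W_b, A, Y, beta))
print("rows set exactly to zero:", int(np.sum(np.linalg.norm(W_a, axis=1) == 0.0)))
```

Because the objective is convex (and strictly so when `A` has full column rank, as it does here with high probability), both runs should arrive at the same solution regardless of where they start, which is the property that removes the need for initialization tricks and similar heuristics. The row-wise penalty is one common way to obtain the kind of sparsity across groups that the paper attributes to its implicit regularization across tokens; the paper's own regularizer may differ.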
In their empirical studies, the team compares the proposed convex training approach with standard nonconvex training in a student-teacher setting with a pretrained BERT model, and with standard self-attention transformer networks on algorithmic datasets. The results show that convex training converges to perfect generalization accuracy 10x faster than standard nonconvex training, with significantly lower test losses.
Overall, this work provides a welcome peek into the hidden mechanisms of transformer networks, which the team hopes follow-up papers can build on to make further progress in this important research area.
The paper Convexifying Transformers: Improving Optimization and Understanding of Transformer Networks is on arXiv.
Author: Hecate He | Editor: Michael Sarazen