The 2017 paper Attention is All You Need introduced transformer architectures based on attention mechanisms, marking one of the biggest machine learning (ML) breakthroughs ever. A recent study offers a new way to analyze self-attention, its inductive biases, and the problem of rank collapse.
Attention-based architectures have proven effective for improving ML applications in natural language processing (NLP), speech recognition, and most recently in computer vision. Research aimed at understanding the inner workings of transformers and attention in general, however, has been limited.
In the paper Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth, a research team from Google and EPFL (École polytechnique fédérale de Lausanne) proposes a novel approach that sheds light on the operation and inductive biases of self-attention networks (SANs) and finds that pure attention decays in rank doubly exponentially with respect to depth.
The researchers summarize their work as follows:
- Present a systematic study of the transformer's building blocks, revealing the opposing forces at play: self-attention drives rank collapse, while skip connections and MLPs counteract it.
- Propose a new method for analyzing SANs via a path decomposition, revealing SANs as an ensemble of shallow networks.
- Verify the theory with experiments on common transformer architectures.
The team began by studying the structure of SANs with skip connections and multi-layer perceptrons (MLPs) disabled. They considered the SAN as a directed acyclic graph, with every node corresponding to a self-attention head and directed edges connecting heads of consecutive layers. Based on this, they built a path decomposition that describes the action of a multi-head SAN as a combination of simpler single-head networks. Examining these paths, they observed that individual paths carry little meaningful bias on their own, and that each path converges rapidly to a rank-1 matrix with identical rows. The interesting part is the balance of forces: although the number of paths grows exponentially with depth, each path degenerates doubly exponentially, so the network's output still collapses to rank 1.
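The path decomposition can be checked numerically. The sketch below (a simplified NumPy illustration, not the paper's implementation: heads are summed rather than concatenated and projected, weights are random, and shapes are arbitrary) runs a pure multi-head SAN forward, records the attention matrices actually used, and then verifies that summing all H**L single-head paths reproduces the output exactly.

```python
import numpy as np
from itertools import product

def softmax(z):
    # Row-wise softmax for attention matrices.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d, H, L = 6, 8, 2, 3  # tokens, width, heads, layers (illustrative sizes)

# Random per-layer, per-head projections (not the paper's parameterization).
Wq = rng.standard_normal((L, H, d, d)) / np.sqrt(d)
Wk = rng.standard_normal((L, H, d, d)) / np.sqrt(d)
Wv = rng.standard_normal((L, H, d, d)) / np.sqrt(d)

X = rng.standard_normal((n, d))

# Forward pass of a pure SAN (no skips, no MLP), recording the attention
# matrices computed at each layer.
Z, A = X, np.empty((L, H, n, n))
for l in range(L):
    out = np.zeros_like(Z)
    for h in range(H):
        A[l, h] = softmax((Z @ Wq[l, h]) @ (Z @ Wk[l, h]).T / np.sqrt(d))
        out += A[l, h] @ Z @ Wv[l, h]
    Z = out

# Path decomposition: with the attention matrices fixed, the output is the
# sum over all H**L head sequences (paths) of shallow single-head maps.
paths = np.zeros_like(Z)
for heads in product(range(H), repeat=L):
    P = X
    for l, h in enumerate(heads):
        P = A[l, h] @ P @ Wv[l, h]
    paths += P

print(np.allclose(Z, paths))  # True: the shallow paths reproduce the SAN output
```

The equality holds because, once the attention matrices are fixed by the forward pass, each layer is linear in its input, so the nested sums over heads expand into a sum over paths.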
The researchers considered the behaviour of each path separately, examining how the residual changes during the forward pass. They discovered that the residual norm converges to zero surprisingly quickly (doubly exponentially, with a cubic rate). As the rank of attention matrices also depends on the rank of the input, the identified cubic rate of convergence is significantly faster than expected. In other words, depth compounds the problem: the deeper the SAN, the stronger this cascading effect becomes.
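This collapse is easy to observe empirically. The sketch below (a minimal single-head illustration with random weights; the Frobenius-norm residual and the layer sizes are simplifying assumptions, not the paper's exact setup) tracks the relative distance of the representation from the nearest matrix with identical rows as pure attention layers are stacked.

```python
import numpy as np

def softmax(z):
    # Row-wise softmax for the attention matrix.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def residual_norm(X):
    # Distance from X to the nearest rank-1 matrix with identical rows;
    # the Frobenius norm stands in for the paper's composite norm.
    return np.linalg.norm(X - X.mean(axis=0, keepdims=True))

rng = np.random.default_rng(0)
n, d, L = 8, 16, 8  # tokens, width, depth (illustrative sizes)
X = rng.standard_normal((n, d))

rel = []
for _ in range(L):
    rel.append(residual_norm(X) / np.linalg.norm(X))
    Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
    A = softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(d))
    X = A @ X @ Wv  # pure attention: no skip connection, no MLP

print(rel)  # the relative residual shrinks as depth grows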
In an effort to obtain a deeper understanding of the structure of SANs, the team expanded their analysis by incorporating the three key transformer components that SANs lack: skip connections, MLPs, and layer normalization. This examination revealed that SANs with skip connections enabled rely heavily on short paths, behaving like ensembles of shallow single-head self-attention networks. The team also discovered that MLPs counteract convergence, such that as MLPs become more powerful, convergence becomes slower; and that layer normalization does not mitigate the rank collapse.
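The stabilizing role of skip connections can also be sketched with the same toy setup (again a simplified single-head model with random weights and a Frobenius-norm residual, purely illustrative): running identical stacks with and without the residual branch shows that the skip connection keeps the representation away from the rank-1 regime.

```python
import numpy as np

def softmax(z):
    # Row-wise softmax for the attention matrix.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def rel_residual(X):
    # Relative distance to the nearest matrix with identical rows.
    return np.linalg.norm(X - X.mean(axis=0, keepdims=True)) / np.linalg.norm(X)

def run(skip, L=8, n=8, d=16, seed=0):
    # Stack L pure attention layers, optionally with a skip connection.
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, d))
    for _ in range(L):
        Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
        A = softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(d))
        update = A @ X @ Wv
        X = X + update if skip else update  # skip preserves the input's rank content
    return rel_residual(X)

print(run(skip=False), run(skip=True))  # skips keep the residual from collapsing
```

The identity branch continually re-injects the (generally full-rank) input, which is one way to read the paper's finding that short paths, including the pure skip path, dominate the network's behaviour.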
The team conducted the following experiments:
- Rank collapse in real architectures, examining the residual of the popular transformer architectures BERT, ALBERT, and XLNet.
- Visualizing the biases of different architectures, studying the behaviour of a single-layer transformer when applied recurrently to predict a simple 2D circular sequence.
- Testing path effectiveness with respect to length through three tasks: Sequence memorization, Learning to sort, and Convex hull prediction.
The first experiment confirmed that when skip connections are removed, all networks exhibit a rapid rank collapse, while the second showed that adding MLP or skip connections either stops or drastically slows down rank collapse. The last experiment supported the researchers’ hypothesis that short paths are responsible for the majority of SANs’ expressive power.
The paper Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth is on arXiv.
Author: Hecate He | Editor: Michael Sarazen