A Google Research team has proposed a novel method for dramatically reducing transformers’ (self-)attention memory requirements. This “trick,” which they believe had been simply overlooked by the machine learning community, addresses the concerning quadratic time and space complexity of transformer architectures’ self-attention mechanisms.
In the paper Self-attention Does Not Need O(n²) Memory, the Google team introduces simple algorithms for attention and self-attention that require only constant memory and logarithmic memory, respectively. At a sequence length of 16384, the approach can reduce the self-attention memory overhead by 59x for inference and by 32x for differentiation.
The team first presents an algorithm for the attention operation with a single query, then extends it to self-attention.
Attention-based transformer architectures contain an encoder and a decoder. The encoder first computes an annotation of each word in the input. A weighted average of the encoder’s annotations is then computed as the context vector, which is fed to the decoder to make word predictions.
Each input comprises three representations: key, query and value. Given a query and a set of key-value pairs, attention can be generalized to compute a weighted sum of the values dependent on the query and the corresponding keys. In this way, the query determines which values to focus on — i.e., which values to “attend” to.
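To make the weighted-sum view concrete, here is a minimal sketch of standard single-query attention in plain Python. The function name `attention` and the list-of-lists data layout are illustrative assumptions, not taken from the paper. Note that this conventional form stores one score and one weight per key, so its working memory grows with the sequence length.

```python
import math

def attention(query, keys, values):
    """Softmax-weighted sum of values for a single query (standard form)."""
    # One dot-product score per key.
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # Subtract the largest score before exponentiating, for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    # Softmax weights: non-negative and summing to one.
    weights = [e / total for e in exps]
    dim = len(values[0])
    # Weighted sum of the value vectors, one output component at a time.
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
```

With two identical keys, both values receive weight 0.5, so the output is simply the average of the two value vectors.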
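The proposed algorithm is very simple. The team observed that when the attention weights are calculated as a softmax, the division by the sum of the exponentiated scores can be moved to the very end of the attention operation using the distributive law. The weighted sum of values and the normalizing sum can then be accumulated one key-value pair at a time, so the memory requirement becomes constant when inputs are given in a particular order, while the running time still grows with the sequence length. If the inputs are provided in a different order, the proposed method will also store an index into the sequence, which takes O(log n) bits for a sequence of length n and therefore requires only O(log n) memory.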
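The idea can be sketched in a few lines of plain Python. This is an illustrative sketch under stated assumptions, not the team's JAX implementation: the name `streaming_attention` is hypothetical, the input is assumed to be a non-empty iterable of (key, value) pairs, and the running-maximum rescaling is included to keep the exponentials from overflowing.

```python
import math

def streaming_attention(query, key_value_pairs):
    """Single-query attention over a stream of (key, value) pairs.

    Keeps only a running maximum, a running denominator, and a running
    weighted sum of values; no list of scores or weights is stored.
    Assumes the stream contains at least one pair.
    """
    running_max = -math.inf
    denom = 0.0
    numer = None  # allocated on the first pair, once the value size is known
    for key, value in key_value_pairs:
        score = sum(q * k for q, k in zip(query, key))
        new_max = max(running_max, score)
        # Rescale the earlier accumulators to the new reference maximum.
        scale = math.exp(running_max - new_max)
        weight = math.exp(score - new_max)
        if numer is None:
            numer = [0.0] * len(value)
        numer = [n * scale + weight * v for n, v in zip(numer, value)]
        denom = denom * scale + weight
        running_max = new_max
    # The softmax division happens exactly once, at the very end.
    return [n / denom for n in numer]
```

Because the pairs are consumed one at a time, the function also accepts a generator, and its extra memory is independent of how many pairs the stream contains.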
To extend the algorithm to self-attention, the team computes the results for all queries sequentially, which requires just one additional index into the list of queries and so again uses only O(log n) memory.
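A rough sketch of this extension, again in plain Python rather than the team's JAX code, is shown below. The names `attend_one` and `self_attention` are hypothetical, and the helper repeats the running-sum idea so the example is self-contained. The outputs are yielded one row at a time, so beyond the inputs and the current output row, the only extra state is the query index `i`.

```python
import math

def attend_one(query, keys, values):
    """Single-query attention using running sums (no stored weight vector)."""
    running_max, denom = -math.inf, 0.0
    numer = [0.0] * len(values[0])
    for key, value in zip(keys, values):
        score = sum(q * k for q, k in zip(query, key))
        new_max = max(running_max, score)
        scale = math.exp(running_max - new_max)  # rescales earlier sums
        weight = math.exp(score - new_max)
        numer = [n * scale + weight * v for n, v in zip(numer, value)]
        denom = denom * scale + weight
        running_max = new_max
    return [n / denom for n in numer]

def self_attention(queries, keys, values):
    """Yield one output row per query, processing the queries in sequence."""
    for i in range(len(queries)):  # i is the one extra index into the queries
        yield attend_one(queries[i], keys, values)
```

Practical implementations typically process queries and keys in chunks rather than one at a time, trading a little memory for much better use of parallel hardware; the sketch above only illustrates the sequential principle.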
The team has presented the algorithms’ entire implementation in the JAX Python library, including support for multiple attention heads and memory-efficient differentiation.
The researchers compared their proposed algorithms’ memory and time requirements for inference and differentiation against those of a standard attention implementation. For a sequence length of 16384, the self-attention memory overhead was reduced by 59x for inference and by 32x for differentiation.
The Google team hopes their work can raise awareness of the fact that attention is not intrinsically memory-hungry and encourage researchers to revisit some of their design choices in popular neural architectures.
The paper Self-attention Does Not Need O(n²) Memory is on arXiv.
Author: Hecate He | Editor: Michael Sarazen