Attention-based transformer architectures have enabled countless breakthroughs on language and vision tasks since their introduction in 2017, but their application remains limited to short contexts because self-attention's compute and memory costs grow quadratically with input length. Various techniques have been proposed to make attention more efficient and reduce its complexity to linear in order to speed up transformers, but these typically come with drawbacks such as inferior quality, high overhead in practice, or inefficient auto-regressive training.
In the new paper Transformer Quality in Linear Time, a research team from Cornell University and Google Brain proposes FLASH (Fast Linear Attention with a Single Head), which it says is the first model family to achieve quality on par with fully augmented transformers while maintaining linear scalability over the context size on modern accelerators.
The researchers first propose a new layer design that is more amenable to approximation, introducing a Gated Attention Unit (GAU) whose layers are cheaper than transformer layers and whose quality relies less on the precision of attention. The team notes that GAU with a small, single-head, softmax-free attention performs on par with transformers. Although GAU still suffers from transformers' quadratic complexity, it weakens the role of attention, which later enables the team to approximate it with minimal loss of quality. An efficient token-grouping method is then employed to approximate the quadratic attention in GAU, yielding a layer variant with linear complexity over the context size. The resulting accelerator-efficient implementation achieves linear scalability in practice with only a few lines of code changed.
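The single-head, softmax-free attention inside a GAU layer can be sketched roughly as follows. This is a minimal illustrative sketch, not the authors' implementation: the shapes, weight names, and the per-dimension query/key transforms are assumptions based on the description above (an expanded gated representation, a small shared attention dimension, and a ReLU-squared attention matrix in place of softmax).

```python
import numpy as np

def silu(x):
    """SiLU activation, a common choice for the gating branches (assumed here)."""
    return x / (1.0 + np.exp(-x))

def gau(x, Wu, Wv, Wz, Wo, gq, bq, gk, bk):
    """Sketch of a Gated Attention Unit on a length-n sequence x of width d.

    Wu, Wv: (d, e) projections to an expanded dim e (gate and value).
    Wz:     (d, s) projection to a small shared attention dim s.
    gq/bq, gk/bk: (s,) per-dim scale/offset giving cheap queries and keys.
    Wo:     (e, d) output projection.
    """
    n = x.shape[0]
    u = silu(x @ Wu)                          # (n, e) gate
    v = silu(x @ Wv)                          # (n, e) value
    z = silu(x @ Wz)                          # (n, s) shared small representation
    q = z * gq + bq                           # cheap element-wise transforms of z
    k = z * gk + bk
    a = np.maximum(q @ k.T / n, 0.0) ** 2     # single-head relu^2 attention, no softmax
    return (u * (a @ v)) @ Wo                 # attention output gated by u, then projected
```

Note that the attention matrix `a` is still n-by-n here; this quadratic term is what the chunked approximation described below replaces.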
The proposed method was also informed by previous studies on modelling long sequences with attention: the team combines the benefits of partial (local) attention and linear attention in a novel mixed chunk attention mechanism. Local quadratic attention is applied independently within each chunk to produce part of the pre-gating state, while a global linear attention captures long-range interactions across chunks. Together, the accelerator-efficient approximation strategy and the mixed chunk attention mechanism enable FLASH to achieve transformer-level quality in linear time on long sequences.
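The local-plus-global split can be sketched as below. Again this is an assumed simplification (non-causal case, single head, softmax-free local attention, no gating or normalization), intended only to show why the cost is linear: the quadratic term is confined to fixed-size chunks, while cross-chunk interaction goes through a single d-by-d summary state.

```python
import numpy as np

def mixed_chunk_attention(q, k, v, chunk=4):
    """Sketch of mixed chunk attention for (n, d) queries/keys/values,
    with n divisible by the chunk size.

    local:  quadratic attention inside each chunk -> O(n * chunk) total.
    global: linear attention across chunks via one (d, d) summary -> O(n * d).
    """
    n, d = q.shape
    qc = q.reshape(-1, chunk, d)   # (g, c, d): g chunks of c tokens
    kc = k.reshape(-1, chunk, d)
    vc = v.reshape(-1, chunk, d)

    # Local quadratic attention within each chunk (relu^2, softmax-free).
    local = np.maximum(qc @ kc.transpose(0, 2, 1) / chunk, 0.0) ** 2 @ vc

    # Global linear attention: summarize all keys/values into one (d, d)
    # state, which every query then reads; no n-by-n matrix is formed.
    state = np.einsum('gcd,gce->de', kc, vc)   # sum_i k_i v_i^T
    glob = qc @ state / n

    return (local + glob).reshape(n, d)
```

In the causal (auto-regressive) setting, the global summary would instead be a cumulative sum over preceding chunks, which preserves the linear cost.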
The team compared their FLASH models with two popular linear-complexity transformer variants — Performer (Choromanski et al., 2020) and Combiner (Ren et al., 2021) — on long sequences.
In the evaluations, FLASH achieved training speedups of up to 4.9× on Wiki-40B and 12.1× on PG-19 for auto-regressive language modelling and 4.8× on C4 for masked language modelling. FLASH also achieved a lower perplexity than the full-attention transformer variants, validating the effectiveness of its novel efficient attention design.
Overall, this work demonstrates that FLASH can achieve quality (perplexity) comparable with fully augmented transformers while being significantly faster to train than state-of-the-art systems, validating it as a practical method for addressing the drawbacks of existing efficient transformer variants.
The paper Transformer Quality in Linear Time is on arXiv.
Author: Hecate He | Editor: Michael Sarazen