Drawing inspiration from a formal equivalence between today’s linear transformers and 1990s fast weight programmers (FWPs), a team from the Swiss AI Lab IDSIA has proposed recurrent FWPs (RFWPs), a novel approach that can outperform linear and regular transformers on code execution and sequential ListOps tasks.
Although transformer architectures have achieved impressive results across many sequence-processing tasks, they are limited in the input lengths they can handle, because the time and space complexity of self-attention is quadratic in the input sequence length. Moreover, in auto-regressive transformers the state size grows linearly with sequence length, which makes them infeasible for auto-regressive settings with very long or potentially infinite sequences.
Recent studies have attempted to scale transformers to longer sequences by linearising softmax attention, yielding linear transformers whose memory size is constant and whose time complexity is linear in sequence length.
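To make the linearisation concrete, here is a minimal Python sketch, not the authors' code. It assumes the elu(x) + 1 feature map used by Katharopoulos et al. as the kernel, uses toy list-based matrices, and the function names are hypothetical. Both functions compute the same kernelised attention, but the second regroups the products so that keys and values are first summarised in a d_k × d_v matrix whose size does not depend on the sequence length N, reducing the cost from O(N²·d) to O(N·d_k·d_v).

```python
import math

def phi(x):
    # Assumed kernel feature map: elu(x) + 1, which keeps every entry positive
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def quadratic_attention(Q, K, V):
    """Kernelised attention computed pairwise: O(N^2 * d) time, all keys kept."""
    out = []
    for q in Q:
        fq = phi(q)
        scores = [dot(fq, phi(k)) for k in K]  # one score per key: N per query
        z = sum(scores)
        out.append([sum(s * v[j] for s, v in zip(scores, V)) / z
                    for j in range(len(V[0]))])
    return out

def linear_attention(Q, K, V):
    """Same result by associativity: phi(q)^T (sum_i phi(k_i) v_i^T) / phi(q)^T z."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # d_k x d_v summary, size independent of N
    z = [0.0] * d_k                        # running sum of phi(k_i), for normalisation
    for k, v in zip(K, V):
        fk = phi(k)
        for a in range(d_k):
            z[a] += fk[a]
            for b in range(d_v):
                S[a][b] += fk[a] * v[b]
    out = []
    for q in Q:
        fq = phi(q)
        norm = dot(fq, z)
        out.append([sum(fq[a] * S[a][b] for a in range(d_k)) / norm
                    for b in range(d_v)])
    return out
```

In the causal, auto-regressive case, the same summary matrix is updated one token at a time, which is what gives linear transformers their constant-size state.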
In the paper Going Beyond Linear Transformers With Recurrent Fast Weight Programmers, the Swiss AI Lab researchers leverage fast weight programmers (FWPs) to advance the capabilities of linear transformers. Building on the connection between linearised transformers and outer product-based FWPs, they explore improved FWP variants to realise the full potential of this framework.
The team summarizes their study’s key points:
- From the perspective of FWPs, we study novel powerful FWPs for sequence processing, demonstrating that Neural Networks (NNs) can easily learn to control NNs that are more complex than a single feedforward layer.
- From the perspective of Transformer models, our RFWPs augment linear Transformers with recurrence, addressing general limitations of existing auto-regressive Transformer models.
In standard neural networks, the weights remain fixed after training. Fast weights, on the other hand, make a network’s weights variable and input-dependent. Such context-dependent FWPs were introduced by Jürgen Schmidhuber in the early 1990s as two-network systems comprising a slow net and a fast net, each of which can have an arbitrary architecture. In this setup, the slow network, trained by backpropagation, learns to generate rapid, context-dependent weight changes for the fast network. Simply put, a slow network learns to program a corresponding fast network.
Linear transformers are a class of transformers in which the softmax in self-attention is replaced by a kernel function. Previous studies have shown that self-attention can then be rewritten as a basic outer product-based FWP, so linearised transformers can be regarded as essentially equivalent to outer product-based FWPs.
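The equivalence can be checked numerically with a small sketch. It again assumes an elu(x) + 1 feature map and omits the normalising denominator for simplicity; the function names are hypothetical. The first function computes causal linear attention position by position. The second treats the keys, values and queries as outputs of a slow net that programs a fast weight matrix through outer-product updates, then queries that matrix with a single linear layer. Both produce the same outputs.

```python
import math

def phi(x):
    # Assumed feature map: elu(x) + 1, keeps every entry positive
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def causal_linear_attention(Q, K, V):
    """Attention view: output t attends to all positions i <= t (unnormalised)."""
    out = []
    for t, q in enumerate(Q):
        fq = phi(q)
        y = [0.0] * len(V[0])
        for k, v in zip(K[: t + 1], V[: t + 1]):
            score = sum(a * b for a, b in zip(fq, phi(k)))
            y = [yj + score * vj for yj, vj in zip(y, v)]
        out.append(y)
    return out

def fast_weight_programmer(Q, K, V):
    """FWP view: slow-net outputs (k, v, q) program a fast weight matrix W."""
    d_v, d_k = len(V[0]), len(K[0])
    W = [[0.0] * d_k for _ in range(d_v)]  # fast weights, d_v x d_k, fixed size
    out = []
    for q, k, v in zip(Q, K, V):
        fk = phi(k)
        # Outer-product write: W_t = W_{t-1} + v_t phi(k_t)^T
        W = [[W[r][c] + v[r] * fk[c] for c in range(d_k)] for r in range(d_v)]
        fq = phi(q)
        # Fast net is a single linear layer: y_t = W_t phi(q_t)
        out.append([sum(W[r][c] * fq[c] for c in range(d_k)) for r in range(d_v)])
    return out
```

The attention view must keep every past key and value, while the FWP view keeps only W, whose size depends on the key and value dimensions, not on sequence length.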
The Swiss AI Lab researchers focus on outer product-based weight generation and present FWPs with recurrent fast nets as well as FWPs with recurrent slow nets. They obtain a fast weight RNN, which they term the Delta RNN, by adding an additional recurrent term to the feedforward fast network of the linear transformer. They also make the slow net of the Delta Net (a linear transformer variant whose fast weights are updated with the delta rule) depend on the previous output of the fast network, yielding what they call the Recurrent Delta Net (RDN).
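The Delta Net’s fast weights are updated with the delta rule rather than a purely additive outer product: the value currently retrieved for a key is moved toward the new value at a learned rate. The sketch below illustrates that update together with an RDN-style step in which the slow net also sees the fast net’s previous output. It is a simplified illustration, not the paper’s implementation: the single linear slow net over the concatenation [x_t; tanh(y_{t-1})], the tanh nonlinearity, the elu(x) + 1 feature map and all function names are assumptions. The Delta RNN, which instead adds a recurrent term to the fast net’s output, is not shown.

```python
import math

def phi(x):
    # Assumed feature map: elu(x) + 1
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def delta_update(W, k, v, beta):
    """Delta rule: move the value W retrieves for key k toward v at rate sigmoid(beta)."""
    fk = phi(k)
    v_old = matvec(W, fk)   # value currently stored under this key
    rate = sigmoid(beta)    # learned write strength in (0, 1)
    return [[W[r][c] + rate * (v[r] - v_old[r]) * fk[c] for c in range(len(fk))]
            for r in range(len(W))]

def rdn_step(slow, W, x, y_prev, d):
    """One RDN-style step (hypothetical simplification).

    slow: (3d + 1) x (len(x) + d) matrix, a single linear slow net that reads
    [x_t ; tanh(y_{t-1})] and emits key, value, query (size d each) and a scalar beta.
    W: d x d fast weight matrix.
    """
    slow_in = x + [math.tanh(yi) for yi in y_prev]  # slow net also sees the last output
    out = matvec(slow, slow_in)
    k, v, q, beta = out[:d], out[d:2 * d], out[2 * d:3 * d], out[3 * d]
    W = delta_update(W, k, v, beta)  # slow net programs the fast net
    y = matvec(W, phi(q))            # feedforward fast net produces the output
    return W, y
```

One property worth checking: if v already equals what W retrieves for k, the delta-rule update is zero, so repeating the same association does not keep accumulating in W the way the purely additive outer-product update does.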
The researchers evaluated the models on generic language modelling to obtain a performance overview; on the synthetic algorithmic tasks of code execution and sequential ListOps to compare their elementary sequence-processing abilities; and on reinforcement learning in 2D game environments, where the models replace LSTMs.
In experiments on the WikiText-103 dataset, the Delta LSTM variant, whose fast net is an LSTM, outperformed the baselines (Katharopoulos et al.’s Linear Transformer and Schlag et al.’s Delta Net), demonstrating that the slow network can successfully control the more complex fast LSTM network.
In a code execution experiment, the Delta RNN remained stable at both difficulty levels and achieved the best performance (85.1 percent) among the transformer variants, showing the benefits of recurrence and, in particular, of the regular RNN architecture.
In a Sequential ListOps experiment, at depth 10, the mutable memory transformer variants (Delta Net, Delta RNN, and RDN) outperformed the regular and linear transformers. At depth 15, while the LSTM’s performance dropped drastically, the performance of the transformer variants remained stable.
The team also evaluated the models in reinforcement learning settings on 20 Atari 2600 environments, where the results showed that a larger version of the proposed RDN can effectively replace LSTM models.
Overall, the study confirms that 1990s FWP frameworks have a strong connection to modern architectures, opening promising avenues for further research into new classes of recurrent transformers.
The paper Going Beyond Linear Transformers With Recurrent Fast Weight Programmers is on arXiv.
Author: Hecate He | Editor: Michael Sarazen, Chain Zhang