Large language models (LLMs) such as GPT-3, PaLM, and OPT have dazzled the AI world with their exceptional performance and in-context learning ability. Their significant drawback, however, is the high cost of inference. Existing approaches that reduce this cost through sparsity either necessitate expensive retraining, compromise the LLM's in-context learning ability, or fail to deliver the desired wall-clock speedup on modern hardware.
To address these challenges, in a new paper Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time, a research team from Rice University, Zhejiang University, Stanford University, the University of California San Diego, ETH Zurich, Adobe Research, Meta AI (FAIR), and Carnegie Mellon University presents DEJAVU, a system that employs a cost-effective algorithm to predict contextual sparsity dynamically for each layer, combined with an asynchronous and hardware-aware implementation to accelerate LLM inference.
The research team sets out to define the ideal sparsity for LLMs, which should meet three crucial criteria: (i) no need for model retraining, (ii) preservation of quality and in-context learning ability, and (iii) enhancement of wall-clock time speed on modern hardware. To fulfill these demanding prerequisites, they move beyond conventional static sparsity and introduce the concept of contextual sparsity, which comprises small, input-dependent subsets of attention heads and MLP parameters that produce nearly identical results as the full model for a given input.
Their hypothesis is that contextual sparsity exists for pre-trained LLMs with any input. This concept guides them in dynamically pruning specific attention heads and MLP parameters during inference, without altering the pre-trained models. DEJAVU leverages contextual sparsity to make LLMs more efficient for applications with strict latency constraints.
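The intuition behind contextual sparsity is easy to demonstrate in a toy setting. The sketch below (a minimal NumPy illustration, not the paper's implementation, with random stand-in weights) shows that for a ReLU MLP block, the neurons that are inactive for a particular input can be skipped entirely without changing the output; which neurons those are depends on the input, which is exactly what makes the sparsity "contextual".

```python
import numpy as np

rng = np.random.default_rng(0)
d, h = 16, 512                           # hidden size, MLP intermediate size

# Hypothetical pre-trained MLP-block weights (random stand-ins).
W1 = rng.normal(0, 1, (d, h))
W2 = rng.normal(0, 1, (h, d))

def mlp_full(x):
    """Full MLP block: relu(x W1) W2, touching all h neurons."""
    return np.maximum(x @ W1, 0) @ W2

def mlp_contextual(x, idx):
    """Compute only the neurons listed in idx; all others are skipped."""
    a = np.maximum(x @ W1[:, idx], 0)
    return a @ W2[idx]

x = rng.normal(0, 1, d)
# Neurons that actually fire for THIS input (with ReLU, roughly half do not).
active = np.flatnonzero(np.maximum(x @ W1, 0))
# mlp_contextual(x, active) reproduces mlp_full(x) exactly; the active set
# changes from input to input, and DEJAVU's contribution is predicting it
# cheaply instead of computing the full layer first, as done here.
```

Note that computing the full pre-activations to find the active set, as this sketch does, would defeat the purpose; it only serves to show that the skipped parameters are genuinely redundant for the given input.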
Specifically, the researchers present a low-cost, learning-based algorithm to predict sparsity on the fly. Given the input to a particular layer, this algorithm anticipates a relevant subset of attention heads or MLP parameters in the subsequent layer and only loads them for computation. They also introduce an asynchronous predictor, similar to a classic branch predictor, to mitigate sequential overhead.
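One way such a low-cost predictor can work is to score all neurons with a cheap low-rank approximation of the layer's weights and keep only the top scorers. The sketch below is a hedged illustration of that idea, not the paper's learned predictor: the factors A and B would be trained in practice, and are taken here from a truncated SVD of a random stand-in weight matrix purely for demonstration.

```python
import numpy as np

rng = np.random.default_rng(1)
d, h, r = 16, 512, 4                     # hidden size, MLP width, predictor rank

W1 = rng.normal(0, 1, (d, h))            # stand-in for a pre-trained weight

# Hypothetical low-rank predictor (A, B): (x A) B approximates x W1 at a
# fraction of the cost. DEJAVU learns its predictor; a truncated SVD of W1
# stands in here for illustration only.
U, S, Vt = np.linalg.svd(W1, full_matrices=False)
A = U[:, :r] * S[:r]                     # (d, r)
B = Vt[:r]                               # (r, h)

def predict_active(x, k):
    """Cheaply guess the top-k neurons before the layer runs: O(dr + rh)
    instead of O(dh). In DEJAVU this runs asynchronously, like a branch
    predictor: the guess for layer L+1 is made while layer L computes,
    hiding the predictor's latency."""
    scores = (x @ A) @ B
    return np.argsort(scores)[-k:]

x = rng.normal(0, 1, d)
guess = predict_active(x, k=64)          # only these 64 neurons get loaded
```

Only the predicted subset of weights is then loaded and multiplied, which is where the savings come from.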
By incorporating a hardware-aware implementation of sparse matrix multiplication, DEJAVU achieves a remarkable reduction in latency for open-source LLMs such as OPT-175B. It outperforms the state-of-the-art library FasterTransformer from Nvidia by over 2× in end-to-end latency without compromising quality and surpasses the widely used Hugging Face implementation at small batch sizes.
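The hardware-aware part can be pictured as a gather followed by a small dense matmul: instead of multiplying by the full weight matrix, only the contextually active rows are read from memory. The NumPy sketch below illustrates the memory-traffic argument under simplified assumptions; on a GPU the gather would be fused into the kernel so the selected rows are read coalesced, which this toy version does not model.

```python
import numpy as np

rng = np.random.default_rng(2)
h, d = 4096, 1024
W = rng.normal(0, 1, (h, d))             # stand-in down-projection weight
x_act = np.maximum(rng.normal(0, 1, h), 0)   # post-ReLU activations

idx = np.flatnonzero(x_act)              # contextually active rows (~half here)

# Dense path: touches all h * d weight entries.
y_dense = x_act @ W

# Sparse path: gather only the active rows, then a smaller dense matmul.
# The zero rows contribute nothing, so the result is identical while the
# weight traffic shrinks in proportion to the sparsity.
y_sparse = x_act[idx] @ W[idx]
```

Because inference at small batch sizes is memory-bandwidth bound, reading fewer weight rows translates fairly directly into lower latency, which is consistent with the speedups the team reports.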
This research demonstrates that DEJAVU effectively utilizes asynchronous lookahead predictors and hardware-efficient sparsity to enhance LLM inference in wall-clock time. The promising empirical results underscore the potential of contextual sparsity in significantly reducing inference latency compared to state-of-the-art models. The research team envisions their work as a step toward making LLMs more accessible to the wider AI community, potentially unlocking exciting new AI applications.
Author: Hecate He | Editor: Chain Zhang