In the new paper ALX: Large Scale Matrix Factorization on TPUs, a Google Research team presents ALX, an open-source library written in JAX that leverages Tensor Processing Unit (TPU) hardware accelerators to enable efficient distributed matrix factorization using Alternating Least Squares. The team has also released WebGraph, a large-scale link prediction dataset designed to encourage additional research on techniques for handling very large sparse matrices.
Matrix factorization is a core technique widely used in today’s recommender systems, so a high-performance implementation for large-scale matrices could significantly accelerate work in this growing field.
The proposed approach is motivated by several appealing properties of TPUs, which the team summarizes as follows:
- A TPU pod has enough distributed memory to store very large sharded embedding tables.
- TPUs are designed for workloads that benefit from data parallelism, which is useful for solving a large batch of systems of linear equations, a core operation in Alternating Least Squares.
- TPU chips are connected directly to one another through dedicated high-bandwidth, low-latency interconnects. This makes gather and scatter operations over a large distributed embedding table stored in TPU memory feasible.
- Because any node failure can halt the training process, traditional ML workloads require a highly reliable distributed setup, a requirement that a cluster of TPUs can fulfill.
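The batched linear solve mentioned in the list above can be illustrated with a minimal NumPy sketch. It assumes a plain ALS objective that fits only the observed entries with L2 regularization: to update the user embeddings, the item embeddings are held fixed and one small regularized least-squares problem is solved per user, with every problem in the batch handled by a single vectorized call. The function name `solve_user_embeddings`, the padded `item_ids`/`mask` batch layout and the direct solver are hypothetical choices for illustration, not ALX's API and not necessarily the objective or solver ALX uses.

```python
import numpy as np


def solve_user_embeddings(item_emb, item_ids, ratings, mask, reg):
    """Solve one batch of ALS least-squares problems at once (illustrative sketch).

    item_emb: (num_items, d) item embedding table, held fixed during this step
    item_ids: (B, L) padded indices of the items each of B users interacted with
    ratings:  (B, L) observed values, e.g. 1.0 for an observed link
    mask:     (B, L) 1.0 for real entries, 0.0 for padding
    reg:      L2 regularization strength (lambda)
    """
    d = item_emb.shape[1]
    # Gather each user's item embeddings and zero out padded slots: (B, L, d).
    v = item_emb[item_ids] * mask[..., None]
    # One regularized normal-equation matrix V^T V + reg * I per user: (B, d, d).
    gram = np.einsum("bld,ble->bde", v, v) + reg * np.eye(d)
    # Right-hand sides V^T r, ignoring padded entries: (B, d).
    rhs = np.einsum("bld,bl->bd", v, ratings * mask)
    # Batched solve; the extra trailing axis makes each rhs an explicit column vector.
    return np.linalg.solve(gram, rhs[..., None])[..., 0]


# Usage: two users, a 10-item table with embedding dimension 4.
rng = np.random.default_rng(0)
item_emb = rng.normal(size=(10, 4))
item_ids = np.array([[1, 3, 5, 0], [2, 4, 7, 9]])
ratings = np.ones((2, 4))
mask = np.array([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
user_emb = solve_user_embeddings(item_emb, item_ids, ratings, mask, reg=0.1)
print(user_emb.shape)  # (2, 4)
```

Because every user contributes a fixed-size d x d system regardless of how many items it touches, the whole batch becomes one data-parallel operation. The item update is symmetric, with the roles of the two embedding tables swapped.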
These properties make it possible to shard a large embedding table across all available devices without replicating it and without running into fault-tolerance issues.
To make full use of the available TPU memory, the team presents a distributed algorithm that learns the matrix factorization parameters with Alternating Least Squares (ALS). The method uniformly shards both the user and item embedding tables across the TPU cores. Data batches are fed from host CPUs to their connected TPU devices; in a pod configuration, multiple hosts (each connected to 8 TPU cores) run an identical computational flow in parallel, each on a distinct batch.
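The uniform sharding described above can be simulated on a single machine with a short NumPy sketch. It assumes a contiguous block layout (shard k owns rows k * rows_per_shard through (k + 1) * rows_per_shard - 1), with zero padding so that every core holds the same number of rows. The article only says the tables are sharded uniformly, so this layout and the helper names `shard_table`, `locate` and `sharded_gather` are hypothetical. On a real TPU pod each shard would sit on a different core and the gather would require cross-device communication, which this single-array version does not model.

```python
import numpy as np


def shard_table(table, num_shards):
    """Split a (num_rows, d) table into num_shards equally sized contiguous blocks.

    Rows are zero-padded so that every shard has the same number of rows.
    Returns an array of shape (num_shards, rows_per_shard, d).
    """
    num_rows, d = table.shape
    rows_per_shard = -(-num_rows // num_shards)  # ceiling division
    padded = np.zeros((rows_per_shard * num_shards, d), dtype=table.dtype)
    padded[:num_rows] = table
    return padded.reshape(num_shards, rows_per_shard, d)


def locate(row_ids, rows_per_shard):
    """Map global row ids to (owning shard index, offset within that shard)."""
    return row_ids // rows_per_shard, row_ids % rows_per_shard


def sharded_gather(shards, row_ids):
    """Look up global rows by routing each id to the shard that owns it."""
    shard_idx, offset = locate(row_ids, shards.shape[1])
    return shards[shard_idx, offset]


# Usage: a 10-row table split across 4 simulated cores (3 rows each, 2 of them padding).
table = np.arange(20, dtype=np.float32).reshape(10, 2)
shards = shard_table(table, num_shards=4)
print(shards.shape)  # (4, 3, 2)
rows = sharded_gather(shards, np.array([0, 5, 9]))
```

With a contiguous layout, finding a row's owner is a single integer division, so no lookup table is needed. As a rough scale check, if a table had exactly 365,000,000 rows, splitting it over 256 cores would give 1,425,782 rows per core and only 192 padding rows in total.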
To evaluate ALX at scale, the team created WebGraph, a large-scale link prediction dataset built from Common Crawl web data, along with several WebGraph variants derived from the locality and sparsity properties of its sub-graphs. These datasets will also be open-sourced.
The team measured how training time on the WebGraph variants scales as the number of available TPU cores increases. With 256 TPU cores, one epoch on the largest variant, WebGraph-sparse (a 365M x 365M sparse matrix), takes about 20 minutes, which indicates that ALX can comfortably scale to matrices of up to 1B x 1B.
Overall, the study demonstrates the applicability of TPUs for accelerating large-scale matrix factorization. The Google team hopes their work will inspire further research on, and improvements to, scalable methods and implementations of large-scale matrix factorization.
The paper ALX: Large Scale Matrix Factorization on TPUs is on arXiv.
Author: Hecate He | Editor: Michael Sarazen