AI Machine Learning & Data Science Research

DeepMind Releases New JAX Libraries for Neural Networks and Reinforcement Learning

DeepMind announced yesterday the release of Haiku and RLax — new JAX libraries designed for neural networks and reinforcement learning respectively.

Introduced in 2018 by Google, JAX is a numerical computing library that combines NumPy, automatic differentiation, and GPU/TPU support. At its core, JAX specializes and translates high-level Python and NumPy functions into intermediate representations that can be transformed and lifted back into Python functions. This speeds up scientific computing and machine learning with the familiar NumPy API, while exposing additional APIs for special accelerator operations when needed.
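To make the "transform and lift back" idea concrete, here is a minimal sketch of JAX's two best-known transformations, jax.grad and jax.jit, applied to an ordinary NumPy-style function:

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Plain NumPy-style code: JAX traces this function and can
    # transform the traced computation.
    return jnp.sum(w ** 2)

# jax.grad lifts the derivative back into an ordinary Python function;
# jax.jit compiles it for the available backend (CPU, GPU, or TPU).
grad_loss = jax.jit(jax.grad(loss))

g = grad_loss(jnp.array([1.0, 2.0, 3.0]))  # gradient of sum(w^2) is 2w
```

Because the transformed function is still just a Python callable, it composes freely with the rest of the program, which is the property Haiku and RLax build on.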

Haiku, a simple neural network library for JAX, was developed by some of the authors of Sonnet, a neural network library for TensorFlow with near-universal adoption at DeepMind, and builds on Sonnet's design. Haiku has shown favourable results in large-scale experiments in image and language processing, generative models, and reinforcement learning.

There are already a number of neural network libraries for JAX, including Google’s homegrown and flexibility-focused Flax. So what makes Haiku stand out?

First of all, Haiku is designed to make managing model parameters and other model state simpler — it works well with the rest of JAX and can be composed with other libraries. Haiku preserves Sonnet's module-based programming model for state management while retaining access to JAX's function transformations. The APIs and abstractions in Haiku are also similar to those in Sonnet, which makes transitioning from TensorFlow and Sonnet to JAX and Haiku easy by design. Haiku can also simplify other aspects of JAX, offering for example a simple model for working with random numbers.

The concurrent release, RLax (pronounced ‘relax’), is a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning (RL) agents. The operations and functions RLax provides are not complete algorithms, but rather implementations of RL specific mathematical operations that are needed when building fully-functional agents.
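As an example of the kind of RL-specific mathematical operation such a library provides, here is the one-step Q-learning temporal-difference error, sketched in plain jax.numpy (this is a hand-written illustration of the math, not RLax's own implementation):

```python
import jax.numpy as jnp

def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    """One-step Q-learning TD error for a single transition.

    q_tm1:      action values at the previous step, shape [num_actions]
    a_tm1:      action taken at the previous step (integer index)
    r_t:        reward received
    discount_t: discount factor for this step
    q_t:        action values at the current step, shape [num_actions]
    """
    # Bootstrapped target: r_t + discount * max_a Q(s_t, a)
    target = r_t + discount_t * jnp.max(q_t)
    # TD error: target minus the value of the action actually taken.
    return target - q_tm1[a_tm1]
```

A full agent would wrap an operation like this in a loss, differentiate it with jax.grad, and compile the update step with jax.jit.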

RLax can be installed directly from GitHub with the ready-made command: pip install git+git://github.com/deepmind/rlax.git. All RLax code can then be compiled for different hardware (e.g. CPU, GPU, TPU) using JAX’s jax.jit function. The DeepMind project page also gives an example of using RLax functions to implement a Q-learning agent capable of learning to play Catch, a common RL unit-test for agents.

Currently, JAX remains a research project and not an official Google product. Although Haiku has been tested for several months and researchers have reproduced a number of experiments at scale, DeepMind stresses it is an alpha, and invites interested developers to test the library for themselves and let the team know “if things don’t look right.”

The project pages for the libraries are on GitHub at Haiku and RLax.


Journalist: Yuan Yuan | Editor: Michael Sarazen
