DeepMind, Google’s UK-based AI research lab, has added Jraph to its growing collection of open-source libraries built around JAX, accompanying the release with a blog post surveying the machine learning framework’s development and ecosystem.
JAX is a Python library that Google researchers developed and introduced in 2018 for high-performance numerical computing. JAX combines a NumPy-compatible API with automatic differentiation and GPU/TPU support. In a new blog post, DeepMind researchers look at how JAX and its emergent ecosystem of open-source libraries have served and accelerated an increasing number of machine learning projects.
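As a minimal illustration of that combination (the function and values below are illustrative, not taken from the post), JAX exposes a NumPy-style API whose functions can be transformed with `jax.grad` for differentiation and `jax.jit` for XLA compilation:

```python
import jax
import jax.numpy as jnp

# A plain Python function written with JAX's NumPy-style API.
def loss(x):
    return jnp.sum(x ** 2)

# `jax.grad` returns a new function that computes the gradient of `loss`;
# `jax.jit` compiles it with XLA so it can run efficiently on CPU/GPU/TPU.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([1.0, 2.0, 3.0])
print(grad_loss(x))  # gradient of sum(x^2) is 2*x -> [2. 4. 6.]
```

Because transformations like `grad` and `jit` compose, the same pattern scales from toy functions like this one to full training loops.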
DeepMind’s open-sourced ecosystem of JAX libraries so far includes Haiku for neural network modules, Optax for gradient processing and optimization, RLax for RL algorithms, Chex for reliable code and testing, and the recently released Jraph for graph neural networks.
- Developed by some of the authors of Sonnet, a neural network library for TensorFlow, Haiku is a neural network library designed to simplify managing model parameters and other model state. It lets users work with familiar object-oriented programming models while “harnessing the power and simplicity of JAX’s pure functional paradigm.”
- Optax is a gradient processing and optimization library for JAX that provides building blocks such as gradient transformations and composition operators so users can implement many standard optimizers in just a single line of code.
- RLax is a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning (RL) agents. Its use cases include Acme, DeepMind’s RL research framework.
- Chex is a library of utilities designed to help users write reliable JAX code. It comprises a collection of testing utilities used by library authors to verify that the common building blocks are correct and robust, and by end-users to check experimental code.
- Jraph is a lightweight library that provides a data structure for graphs, a set of utilities for users working with graphs, and a ‘zoo’ of forkable graph neural network models.
Supporting rapidly evolving AI research requires balancing rapid prototyping and swift iteration with the need to run experiments at a scale appropriate to real-world production systems. DeepMind researchers highlight several approaches that have enabled the core JAX libraries to keep pace with new research directions:
- Extract the most important and critical building blocks developed in each research project into well-tested, efficient components
- Give each library a clearly defined scope; the libraries are interoperable but independent
- Keep the JAX ecosystem consistent with the design of existing TensorFlow libraries such as Sonnet
DeepMind is continuously updating the JAX ecosystem to accelerate machine learning research. The open-sourced libraries can be found on the project GitHub.
NeurIPS 2020 will host JAX MD: A Framework for Differentiable Physics, which applies JAX to molecular dynamics simulations. A Spotlight Oral is scheduled on Wednesday December 9 from 23:00 – 23:10 EST and a Poster Session on Thursday December 10 from 00:00 – 02:00 EST.
Reporter: Fangyu Cai | Editor: Michael Sarazen