
Google Introduces Flax: A Neural Network Library for JAX

Google recently introduced Flax, a neural network library for JAX that is designed for flexibility.

In optimization theory, a loss or cost function measures the distance between a model's fitted or predicted values and the real values. For most machine learning models, improving performance means minimizing the loss function.
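As a concrete instance, here is mean squared error, one common choice of loss (used purely as an illustration; the article does not single out a particular loss). For a model f with parameters θ and n training pairs (x_i, y_i), training looks for the parameters that make the average squared gap between predicted and real values as small as possible:

```latex
L(\theta) = \frac{1}{n} \sum_{i=1}^{n} \bigl( f_\theta(x_i) - y_i \bigr)^2,
\qquad
\theta^{*} = \arg\min_{\theta} L(\theta)
```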

But deep neural networks can have millions of parameters, and performing gradient descent to minimize the loss means computing its gradient with respect to every one of them, which can be prohibitively resource-consuming. Traditional approaches include deriving and coding the gradients by hand, or implementing the model within the syntactic and semantic constraints of a machine learning framework like TensorFlow.
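To make the "derive and code it by hand" route concrete, here is a minimal plain-Python sketch, assuming a one-variable linear model w * x + b trained with mean squared error. The gradient formulas are derived by hand and written out in the docstring, so any change to the model would mean re-deriving them. The function names, toy data, learning rate and step count are illustrative choices, not taken from the article.

```python
def mse_loss(w, b, xs, ys):
    """Mean squared error of the linear model w * x + b."""
    n = len(xs)
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n


def mse_grads(w, b, xs, ys):
    """Hand-derived gradients of mse_loss with respect to w and b.

    dL/dw = (2/n) * sum((w*x + b - y) * x)
    dL/db = (2/n) * sum(w*x + b - y)
    """
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    dw = 2.0 * sum(r * x for r, x in zip(residuals, xs)) / n
    db = 2.0 * sum(residuals) / n
    return dw, db


def gradient_descent(xs, ys, lr=0.05, steps=2000):
    """Plain gradient descent: repeatedly step against the gradient."""
    w, b = 0.0, 0.0
    for _ in range(steps):
        dw, db = mse_grads(w, b, xs, ys)
        w -= lr * dw
        b -= lr * db
    return w, b


# Toy data generated from y = 3x + 1 (no noise).
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [3.0 * x + 1.0 for x in xs]
w, b = gradient_descent(xs, ys)
print(round(w, 3), round(b, 3))  # prints 3.0 1.0
```

The hand-written mse_grads is exactly the part an automatic-differentiation library removes: it is correct only for this one model and loss.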

But what if it were possible to simply write down loss functions with a NumPy-style library and have the gradients computed automatically? That’s a job for JAX, the library Google introduced in 2018 that brings together Autograd and XLA (Accelerated Linear Algebra) with just-in-time compilation. JAX can automatically differentiate native Python and NumPy code through a large subset of Python features such as ifs, loops, recursion and closures. It also enables fast scientific computing by compiling code for accelerators such as GPUs and TPUs and parallelising it across multiple devices.
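Here is a minimal sketch of what this looks like in practice, assuming a recent JAX install. The toy loss below is hypothetical, written only for illustration: it uses an ordinary Python for loop and if statement, yet jax.grad still returns its derivative, and jax.jit then compiles that gradient function with XLA. The if branches on the data values, which stay plain Python floats; branching on a traced value inside jax.jit would instead need something like jax.lax.cond.

```python
import jax


def loss(w, xs, ys):
    """Toy squared-error loss written with an ordinary Python loop and if."""
    total = 0.0
    for x, y in zip(xs, ys):  # native Python loop over the data
        pred = w * x
        if x > 0:  # branch on a plain Python float from the data
            total = total + (pred - y) ** 2
        else:  # down-weight non-positive inputs (arbitrary, for illustration)
            total = total + 0.5 * (pred - y) ** 2
    return total / len(xs)


xs = [1.0, 2.0, -1.0]
ys = [2.0, 4.0, -2.0]

# jax.grad builds a new function that returns d(loss)/dw (first argument only).
dloss_dw = jax.grad(loss)
print(dloss_dw(1.0, xs, ys))  # about -3.6667, i.e. -11/3

# jax.jit compiles the gradient with XLA; xs and ys are closed over as constants.
grad_fn = jax.jit(lambda w: dloss_dw(w, xs, ys))
print(grad_fn(2.0))  # 0.0 at the minimum w = 2
```

Nothing in loss mentions derivatives: the gradient comes from tracing the Python code, which is what lets users "simply write down" the loss.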

Taking this one step further, Google recently introduced Flax, a neural network library for JAX that is designed for flexibility. Developers can start training neural networks with Flax by forking one of the examples in its official GitHub repository. To modify a model, they no longer need to add features to the framework; they can simply modify the training loop (for example, the train_step function) to achieve the same result. At its core, Flax is built around parameterised functions called Modules, which override the apply method and can be used as normal functions.

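from flax import nn  # early Flax API, later superseded by flax.linen
import jax.numpy as jnp

class Linear(nn.Module):
  # A Module is a parameterised function; apply defines its computation.
  def apply(self, x, num_features, kernel_init_fn):
    input_features = x.shape[-1]
    # Declare (or look up) the learned weight matrix 'W'.
    W = self.param('W', (input_features, num_features), kernel_init_fn)
    # The linear transformation itself: multiply the inputs by W.
    return jnp.dot(x, W)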

Flax code used to define a learned linear transformation.
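The idea of changing behaviour by editing the training loop can be illustrated without Flax at all. Below is a minimal plain-JAX sketch, not Flax's own API: a linear layer written as a pure parameterised function (mirroring the Linear module above), a mean-squared-error loss, and a jitted train_step that performs one step of plain SGD. The names init_linear, linear_apply, mse and train_step, the learning rate and the toy data are all hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value chosen for this toy problem


def init_linear(key, input_features, num_features):
    """Create parameters for a linear layer, analogous to self.param('W', ...)."""
    W = jax.random.normal(key, (input_features, num_features)) * 0.1
    return {"W": W}


def linear_apply(params, x):
    """Apply the layer: the same computation as the Linear module above."""
    return jnp.dot(x, params["W"])


def mse(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((linear_apply(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One SGD step: compute loss and gradients, then update every parameter."""
    loss, grads = jax.value_and_grad(mse)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Toy regression data: targets come from a fixed "true" weight matrix.
key = jax.random.PRNGKey(0)
x_key, init_key = jax.random.split(key)
x = jax.random.normal(x_key, (64, 3))
true_W = jnp.array([[1.0], [-2.0], [0.5]])
y = x @ true_W

params = init_linear(init_key, input_features=3, num_features=1)
for step in range(200):
    params, loss = train_step(params, x, y)
print(float(loss))  # small after training
```

Because the parameters live in a plain dictionary, they form a JAX pytree, so a single jax.tree_util.tree_map call updates every array however many layers are added. Swapping in a different optimiser or loss means editing train_step, not the library.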

The Flax release has created a buzz on social media. Anima Anandkumar, Director of Machine Learning Research at NVIDIA, tweeted the Flax GitHub link, adding: “We used CGD for training GANs and for constrained problems in RL. This library will be very useful.” Google Brain Research Scientist David Ha (Twitter handle @hardmaru) also endorsed the new repository.

For those interested in trying Flax, three examples are currently available: MNIST, a database of handwritten digits commonly used for digit recognition tasks; ResNet, a deep residual learning architecture for image recognition, trained here on ImageNet, a setup often used to measure large-scale cluster computing capability; and the 1 Billion Word Language Model Benchmark, a standard training and test setup for language modeling experiments.

The Flax team is also calling on developers to help build additional end-to-end examples, such as translation, semantic segmentation, GANs and VAEs.

The Flax repository from Google Research is available on GitHub.

Author: Hecate He | Editor: Michael Sarazen
