A new Google Research study proposes a unified, efficient and modular approach to the implicit differentiation of optimization problems that combines the benefits of implicit differentiation with those of automatic differentiation (autodiff). The researchers say solvers equipped with implicit differentiation via the proposed framework can make the differentiation process more efficient and convenient for end-users.
Autodiff is a revolutionary technique in machine learning (ML) for solving optimization problems. A key advantage of autodiff is that it frees human experts from the tedious burden of manually computing the derivatives of a system’s complex optimization functions.
In most cases, however, the solutions of optimization problems are not given by explicit formulas of their inputs, and so autodiff cannot be used to differentiate them directly. A popular way to circumvent this problem is to treat an optimization problem’s solution as a function implicitly defined by certain optimality conditions. Deriving such implicit differentiation formulas, however, has traditionally required tedious case-by-case mathematical derivations.
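As a brief illustration of the underlying idea (a standard result, not notation taken from the paper): if the solution $x^\star(\theta)$ is defined implicitly by an optimality condition $F(x^\star(\theta), \theta) = 0$, then under standard regularity assumptions the implicit function theorem gives

$$
\partial x^\star(\theta) = -\big[\partial_1 F(x^\star(\theta), \theta)\big]^{-1}\, \partial_2 F(x^\star(\theta), \theta),
$$

so the Jacobian of the solution can be recovered by differentiating only the optimality condition $F$ — which is exactly where autodiff can take over from manual derivation.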
In the paper Efficient and Modular Implicit Differentiation, the Google researchers adopt a novel strategy that combines implicit differentiation and autodiff in a unified, efficient and modular framework for the implicit differentiation of optimization problems. The team says the proposed efficient and modular approach can be added on top of any state-of-the-art solver, enables the recovery of many previously proposed implicit differentiation methods, and can easily create new ones.
The researchers summarize their contributions as:
- Delineate extremely general principles for implicitly differentiating through an optimization problem solution. The approach can be seen as “hybrid” in the sense that it combines implicit differentiation with autodiff of the optimality conditions.
- Show how to instantiate the framework in order to recover many recently proposed implicit differentiation schemes, thereby providing a unifying perspective. The team also obtains new implicit differentiation schemes, such as one based on the mirror descent fixed point.
- On the theoretical side, provide new bounds on the Jacobian error when the optimization problem is only solved approximately.
- Describe a JAX implementation and provide a blueprint for implementing our approach in other frameworks. The team is in the process of open-sourcing a full-fledged library for implicit differentiation in JAX.
- Implement four illustrative applications, demonstrating our framework’s ease of use.
The team outlines the general principles of the proposed framework, which make it easy and convenient to add implicit differentiation on top of existing solvers. Specifically, the user defines a mapping function capturing the optimality conditions of the problem being solved by the algorithm. The method provides reusable building blocks to express this mapping function, then leverages autodiff of the mapping to automatically differentiate the optimization problem’s solution.
The researchers then explain their method for differentiating a root and a fixed point, and how the vector-Jacobian product (VJP), Jacobian-vector product (JVP) and pre-processing and post-processing mappings are computed. This analysis of the proposed framework’s mechanism demonstrates both its efficiency and how it can be added on top of any state-of-the-art solver.
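To make the root-differentiation idea concrete, here is a minimal, self-contained JAX sketch — not the authors’ implementation — that applies the implicit function theorem to a toy one-dimensional ridge problem. The objective, the closed-form solver and the regularization strength `lam` are all illustrative assumptions:

```python
import jax

lam = 0.1  # illustrative ridge strength

def optimality_fun(x, theta):
    # Stationarity condition of f(x, theta) = 0.5*(x - theta)**2 + 0.5*lam*x**2
    return (x - theta) + lam * x

def solver(theta):
    # Closed-form minimizer, standing in for any iterative solver.
    return theta / (1.0 + lam)

def implicit_jacobian(theta):
    x_star = solver(theta)
    # Implicit function theorem: dx*/dtheta = -(dF/dx)^{-1} dF/dtheta,
    # where both partials of F are obtained by autodiff.
    dF_dx = jax.grad(optimality_fun, argnums=0)(x_star, theta)
    dF_dtheta = jax.grad(optimality_fun, argnums=1)(x_star, theta)
    return -dF_dtheta / dF_dx

print(implicit_jacobian(2.0))  # matches d/dtheta [theta/(1+lam)] = 1/(1+lam)
```

Note that the solver itself is never differentiated — only the optimality condition is, which is what allows the framework to wrap arbitrary black-box solvers.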
The proposed framework’s implementation is based on JAX, and the team provides various optimality condition mappings together with their corresponding code. The examples include the stationary point condition, the KKT conditions, the proximal gradient fixed point, the projected gradient fixed point and the mirror descent fixed point.
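As a hedged sketch of what one such mapping looks like, the projected gradient fixed point states that a constrained minimizer satisfies $x^\star = \mathrm{proj}_C(x^\star - \eta \nabla f(x^\star, \theta))$. The example below checks this for a toy quadratic objective with a box constraint; the objective, constraint set and step size `eta` are illustrative choices, not taken from the paper’s code:

```python
import jax.numpy as jnp

eta = 0.5  # illustrative step size

def proj_box(x):
    # Euclidean projection onto the box [0, 1]^n
    return jnp.clip(x, 0.0, 1.0)

def T(x, theta):
    # Projected-gradient map for f(x, theta) = 0.5 * ||x - theta||^2
    grad_f = x - theta
    return proj_box(x - eta * grad_f)

def fixed_point_residual(x, theta):
    # Optimality condition F(x, theta) = x - T(x, theta); zero at the solution
    return x - T(x, theta)

theta = jnp.array([-0.5, 0.3, 1.7])
x_star = proj_box(theta)  # closed-form solution for this simple objective
print(fixed_point_residual(x_star, theta))  # ~0 at the optimum
```

Expressing optimality as a residual of this form is what lets the same differentiation machinery cover the stationary point, proximal gradient and mirror descent cases as well.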
The team conducted a series of experiments to demonstrate their modular framework’s ease of formulation and performance in solving bi-level optimization problems — which typically involve computing the derivatives of a nested optimization problem in order to solve an outer one. The experiments also included task-driven dictionary learning (Task-driven DictL), a way to learn sparse codes for input data such that the codes solve an outer learning problem.
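The bi-level pattern can be sketched end to end with a toy hyperparameter-learning problem: an inner ridge fit whose regularizer depends on an outer parameter `theta`, and an outer validation loss. Everything here — the data, the objectives and the closed-form inner solver — is an illustrative assumption, not the paper’s experimental setup:

```python
import jax
import jax.numpy as jnp

x_tr, y_tr = 2.0, 1.0    # one "training" point (illustrative)
x_val, y_val = 2.0, 1.0  # one "validation" point (illustrative)

def inner_optimality(w, theta):
    # Stationarity of the inner objective
    # 0.5*(w*x_tr - y_tr)**2 + 0.5*exp(theta)*w**2
    return (w * x_tr - y_tr) * x_tr + jnp.exp(theta) * w

def inner_solver(theta):
    # Closed form here; in practice any iterative inner solver would do.
    return x_tr * y_tr / (x_tr**2 + jnp.exp(theta))

def hypergradient(theta):
    w_star = inner_solver(theta)
    # Implicit function theorem on the inner problem:
    # dw*/dtheta = -(dF/dw)^{-1} dF/dtheta
    dF_dw = jax.grad(inner_optimality, argnums=0)(w_star, theta)
    dF_dtheta = jax.grad(inner_optimality, argnums=1)(w_star, theta)
    dw_dtheta = -dF_dtheta / dF_dw
    # Chain rule with the outer (validation) loss
    outer = lambda w: 0.5 * (w * x_val - y_val) ** 2
    return jax.grad(outer)(w_star) * dw_dtheta

print(hypergradient(0.0))
```

The hypergradient is obtained without unrolling or differentiating through the inner solver’s iterations, which is the efficiency advantage the experiments below quantify.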
In an experiment on breast cancer survival prediction from gene expression data, the Task-driven DictL approach achieved performance competitive with state-of-the-art L1 or L2 regularized logistic regression baselines.
In a data distillation experiment conducted on the MNIST dataset — aiming to learn a small synthetic training dataset such that a model trained on this learned dataset achieves low loss on the original training set — the team used gradient descent with implicit differentiation. The results show that the proposed implicit differentiation approach was four times faster than a baseline method that differentiates through the algorithm’s unrolled iterates.
The Google team’s analytical and empirical analysis shows that the proposed modular approach for implicit differentiation is generic and can easily exploit and improve the efficiency of state-of-the-art solvers.
The paper Efficient and Modular Implicit Differentiation is on arXiv.
Author: Hecate He | Editor: Michael Sarazen, Chain Zhang