Discovering a system’s causal relationships and structure is a crucial yet challenging problem in scientific disciplines ranging from medicine and biology to economics. Researchers typically adopt the graphical formalism of causal Bayesian networks (CBNs) and search for the graph structure that best describes these relationships, but the unsupervised, score-based approaches commonly used for this search can quickly become computationally prohibitive.
A research team from DeepMind, Mila – University of Montreal and Google Brain challenges the conventional causal induction approach in their new paper Learning to Induce Causal Structure, proposing a neural network architecture that learns the graph structure of observational and/or interventional data via supervised training on synthetic graphs. The team’s proposed Causal Structure Induction via Attention (CSIvA) method effectively makes causal induction a black-box problem and generalizes favourably to new synthetic and naturalistic graphs.
The team summarizes their main contributions as:
- We tackle causal structure induction with a supervised approach (CSIvA) that maps datasets composed of both observational and interventional samples to structures.
- We introduce a variant of a transformer architecture whose attention mechanism is structured to discover relationships among variables across samples.
- We show that CSIvA generalizes to novel structures, whether or not training and test distributions match. Most importantly, training on synthetic data transfers effectively to naturalistic CBNs.
The team first trains their model on synthetic data generated from many different CBNs so that it captures the relationship between datasets and graph structures, then leverages this learned mapping to induce the structures underlying datasets of interest. The model, a novel variant of the conventional transformer architecture, takes as input a dataset of observational and interventional samples drawn from the same CBN and outputs the predicted graph structure of that CBN.
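To make the supervised setup concrete, here is a minimal sketch of how one (dataset, graph) training pair might be generated. It is not the authors' data pipeline: the linear-Gaussian mechanisms, the hard interventions, and all names (`random_dag`, `sample_row`, `make_training_pair`) are assumptions chosen for illustration.

```python
import random

def random_dag(num_vars, edge_prob, rng):
    """Sample a random DAG as an adjacency matrix; adj[i][j] == 1 means i -> j.

    Edges only run from lower to higher index, so the graph is acyclic
    by construction.
    """
    adj = [[0] * num_vars for _ in range(num_vars)]
    for i in range(num_vars):
        for j in range(i + 1, num_vars):
            if rng.random() < edge_prob:
                adj[i][j] = 1
    return adj

def sample_row(adj, weights, rng, intervened=None, value=0.0):
    """Ancestral sampling of one row from a linear-Gaussian CBN (an assumption).

    If `intervened` is an index, that variable is clamped to `value`
    (a hard intervention) and its parents are ignored.
    """
    n = len(adj)
    x = [0.0] * n
    for j in range(n):  # index order is a valid topological order here
        if j == intervened:
            x[j] = value
            continue
        mean = sum(weights[i][j] * x[i] for i in range(j) if adj[i][j])
        x[j] = mean + rng.gauss(0.0, 1.0)
    return x

def make_training_pair(num_vars=5, num_obs=20, num_int=20, edge_prob=0.4, seed=0):
    """Return (dataset, adj): the model input and its supervised target.

    Each dataset entry is (row, target), where target is None for an
    observational sample or the index of the intervened variable.
    """
    rng = random.Random(seed)
    adj = random_dag(num_vars, edge_prob, rng)
    weights = [[rng.uniform(-2.0, 2.0) for _ in range(num_vars)]
               for _ in range(num_vars)]
    dataset = [(sample_row(adj, weights, rng), None) for _ in range(num_obs)]
    for _ in range(num_int):
        target = rng.randrange(num_vars)
        clamp = rng.gauss(0.0, 2.0)
        dataset.append((sample_row(adj, weights, rng, target, clamp), target))
    return dataset, adj
```

Calling `make_training_pair` with many different seeds would yield a collection of datasets paired with their ground-truth graphs, which is the kind of supervision the paragraph above describes.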
The proposed method maps datasets to graph structures with a transformer-based attention mechanism that alternates between attending to the different variables in the graph and attending to the different samples of a variable. In this way the model analyzes the data and computes a distribution over candidate graphs. The decoder used to output the CBN graph structure is an autoregressive generative model that operates on the inferred structure.
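The alternating pattern can be sketched in a few lines. This is a simplified illustration under stated assumptions, not the CSIvA architecture: it uses single-head dot-product attention with identity projections (no learned weights, residual connections, or normalization), and the names `self_attention` and `alternating_attention` are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(vectors):
    """Single-head scaled dot-product self-attention over a list of vectors.

    Queries, keys and values are the inputs themselves (identity
    projections), which is a simplification made for this sketch.
    """
    dim = len(vectors[0])
    scale = math.sqrt(dim)
    out = []
    for q in vectors:
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in vectors]
        weights = softmax(scores)
        out.append([sum(w * v[d] for w, v in zip(weights, vectors))
                    for d in range(dim)])
    return out

def alternating_attention(emb, num_layers=2):
    """emb[s][v] is the embedding of variable v in sample s.

    Each layer first attends across variables within every sample,
    then across samples within every variable.
    """
    num_samples, num_vars = len(emb), len(emb[0])
    for _ in range(num_layers):
        # Step 1: for each sample, let its variables attend to each other.
        emb = [self_attention(row) for row in emb]
        # Step 2: for each variable, let its samples attend to each other.
        cols = [self_attention([emb[s][v] for s in range(num_samples)])
                for v in range(num_vars)]
        emb = [[cols[v][s] for v in range(num_vars)]
               for s in range(num_samples)]
    return emb
```

The output keeps the same samples-by-variables layout as the input, so a downstream decoder could summarize it per variable before predicting edges.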
The resulting CSIvA can thus be viewed as a meta-learning approach, as the model itself learns the relationships between datasets and their underlying structures. This design also enables the model to generalize well to data from naturalistic CBNs.
In their empirical study, the team compared their supervised learning paradigm to unsupervised baseline methods, including non-linear ICP (Heinze-Deml et al., 2018a) and DAG-GNN (Yu et al., 2019), on classic benchmarks such as the Sachs (Sachs et al., 2005) and Asia (Lauritzen & Spiegelhalter, 1988) datasets. In these tests, CSIvA outperformed all baselines, indicating its ability to effectively induce causal structures in naturalistic CBNs.
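The summary above does not say how predicted and ground-truth graphs are scored. One common choice for this kind of comparison, shown here purely as an assumption, is the Hamming distance between the two adjacency matrices. The helper name `hamming_distance` is hypothetical.

```python
def hamming_distance(pred, true):
    """Count the ordered pairs (i, j) on which two adjacency matrices disagree.

    Under this plain definition a reversed edge counts as two errors
    (one missing edge plus one extra edge).
    """
    n = len(true)
    if len(pred) != n or any(len(r) != n for r in pred) or any(len(r) != n for r in true):
        raise ValueError("both graphs must be square matrices of the same size")
    return sum(1 for i in range(n) for j in range(n) if pred[i][j] != true[i][j])

# Ground truth: A -> B -> C. Prediction: A -> B and C -> B (last edge reversed).
true_graph = [[0, 1, 0],
              [0, 0, 1],
              [0, 0, 0]]
pred_graph = [[0, 1, 0],
              [0, 0, 0],
              [0, 1, 0]]
distance = hamming_distance(pred_graph, true_graph)  # 2: one missing, one extra edge
```

Lower is better under this measure, and a distance of zero means the predicted structure matches the ground truth exactly.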
Overall, the results show the proposed CSIvA generalizes well to out-of-distribution data even when trained only on synthetic data, taking a significant step forward in causal graph structure inference.
The paper Learning to Induce Causal Structure is on arXiv.
Author: Hecate He | Editor: Michael Sarazen