Toward Causal Representation Learning
Introduction
In this blog post, I will delve into the paper titled “Toward Causal Representation Learning” (2021). The primary objective of this paper is to address open challenges in machine learning and propose various approaches through which causality can aid in overcoming them.
It is my observation, shared by the causal learning community, that the machine learning community often underestimates the potential contributions of causality. While current trends in machine learning focus on larger models and datasets, there’s a growing recognition that these factors alone are insufficient. Thus, papers like this serve the vital purpose of drawing attention to the importance of causality in advancing the field.
Although causal learning encompasses a wide array of topics, the authors of this paper have chosen to focus specifically on Causal Representation Learning. Throughout this blog post, I will offer my insights on this choice and other aspects discussed in the paper.
Questions
To structure this post, I’ll start by presenting a series of questions related to the topic at hand, which I’ll subsequently address. The answers are my own, based on my understanding of the subject. I invite you to reflect upon these questions first. Do you agree with my insights?
- What is Causal Representation Learning?
- What is Causal Inference?
- What is Causal Discovery?
- What is Causal Reasoning?
- What does causality add to machine learning?
- What is the difference between Statistical Models, Causal Graphical Models and Structural Causal Models?
- Is Causal Representation Learning compatible with Reinforcement Learning?
- Is Causal Representation Learning useful for Multitask Learning? How so?
Answers
1. What is Causal Representation Learning?
Causal Representation Learning (CRL) involves discovering high-level causal variables (e.g., concepts) from low-level observations (e.g., sensors).
The motivation behind this subfield of Causal Learning arises from the recognition that established areas of Causality (such as Causal Inference and Causal Discovery) often assume the availability of causal variables, typically determined through human reasoning. CRL offers the opportunity to automate the identification of these causal variables, thereby facilitating the integration of more traditional methods in Causality.
Traditional causal discovery and reasoning assume that the units are random variables connected by a causal graph. However, real-world observations are usually not structured into those units, to begin with, for example, objects in images [162]. Hence, the emerging field of causal representation learning strives to learn these variables from data, much like machine learning went beyond symbolic AI in not requiring that the symbols that algorithms manipulate be given a priori (see [34]).
CRL also addresses the learning of disentangled representations. This entails obtaining a latent vector (from observations) whose components represent the disentangled properties of the scene. Disentanglement can be encouraged by enforcing statistical independence among the encoded variables, or by using the Sparse Mechanism Shift hypothesis as a learning signal, so that a change in the distribution corresponds to an intervention on only a few of the disentangled variables.
In the context of learning transferable mechanisms (knowledge transfer), modularity is emphasized. Assuming that the mechanisms generating observations are modular, models aiming to predict observations should also exhibit modularity to mirror the environment. Modular models can share components across tasks where the mechanisms remain consistent.
In other words, if the world is indeed modular, in the sense that components/mechanisms of the world play roles across a range of environments, tasks, and settings, then it would be prudent for a model to employ corresponding modules [85].
2. What is Causal Inference?
Causal inference entails comprehending and identifying (measuring) cause-and-effect relationships between variables, transcending mere statistical associations. It centers on discerning how changes in a variable $X$ affect another variable $Y$. The concept of cause-and-effect introduces a sense of directionality, indicating that alterations in $X$ cause changes in $Y$, as opposed to the reverse.
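To make the gap between association and causation concrete, here is a minimal sketch of my own (not taken from the paper). It assumes a hypothetical linear model in which a confounder $Z$ drives both $X$ and $Y$; the coefficients, the `simulate` function, and the variable names are illustrative assumptions.

```python
import numpy as np

def simulate(n, rng, do_x=None):
    """Sample from a toy linear model with Z -> X, Z -> Y and X -> Y."""
    z = rng.normal(size=n)                        # confounder
    if do_x is None:
        x = 2.0 * z + rng.normal(size=n)          # observational: X listens to Z
    else:
        x = np.full(n, float(do_x))               # intervention do(X = do_x): X ignores Z
    y = 1.0 * x + 3.0 * z + rng.normal(size=n)    # true causal effect of X on Y is 1.0
    return x, y

rng = np.random.default_rng(0)
n = 200_000

# Statistical association: regression slope of Y on X in observational data.
x_obs, y_obs = simulate(n, rng)
slope_obs = np.cov(x_obs, y_obs)[0, 1] / np.var(x_obs, ddof=1)

# Causal effect: difference in mean outcome between do(X=1) and do(X=0).
_, y_do1 = simulate(n, rng, do_x=1.0)
_, y_do0 = simulate(n, rng, do_x=0.0)
effect_do = y_do1.mean() - y_do0.mean()

print(f"observational slope: {slope_obs:.2f}")
print(f"interventional effect: {effect_do:.2f}")
```

Under these assumptions the observational slope comes out near 2.2, while the interventional contrast recovers the true effect of about 1.0. The difference is entirely due to the confounder.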
3. What is Causal Discovery?
Causal discovery focuses on uncovering the causal structure or causal graph from observational data, identifying the direct causal relationships between variables in a system. While Causal Inference attempts to measure the causal association between two variables, Causal Discovery aims to extract a model or function that governs this association.
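As a rough illustration of the kind of signal that discovery methods can exploit, the sketch below checks a (conditional) independence pattern in observational data. It is only one ingredient of constraint-based approaches, not a full discovery algorithm, and the ground-truth graph $X \rightarrow Z \leftarrow Y$, the helper names, and the coefficients are my own assumptions.

```python
import numpy as np

def residual(a, b):
    """Part of `a` that a least-squares line on `b` cannot explain."""
    slope, intercept = np.polyfit(b, a, 1)
    return a - (slope * b + intercept)

def partial_corr(a, b, given):
    """Correlation between a and b after regressing `given` out of both."""
    return float(np.corrcoef(residual(a, given), residual(b, given))[0, 1])

rng = np.random.default_rng(0)
n = 100_000

# Hypothetical ground truth: X -> Z <- Y (Z is a collider).
x = rng.normal(size=n)
y = rng.normal(size=n)
z = x + y + 0.5 * rng.normal(size=n)

corr_xy = float(np.corrcoef(x, y)[0, 1])       # X and Y are marginally uncorrelated
pcorr_xy_given_z = partial_corr(x, y, z)       # but become dependent once Z is known

print(f"corr(X, Y)     = {corr_xy:.3f}")
print(f"corr(X, Y | Z) = {pcorr_xy_given_z:.3f}")
```

Two variables that are independent on their own but become dependent when conditioning on a third are the signature of a collider, which lets the edges into $Z$ be oriented without any intervention. In this linear-Gaussian toy setting, partial correlation serves as the conditional independence test.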
4. What is Causal Reasoning?
Causal reasoning is about drawing conclusions on the effects of interventions, counterfactuals, and potential outcomes. It typically assumes that a causal model (such as a causal graphical model or structural causal model) is already available.
5. What does causality add to machine learning?
Robustness
In the i.i.d. (independent and identically distributed) setting, statistical models can excel within the distribution they were trained on, which suffices for many use cases. However, because they capture correlation rather than causation, they can perform poorly under interventions, which typically alter the distribution.
Mere statistical correlations lack robustness against distribution shifts or out-of-distribution data.
Conversely, a Causal Model can effectively capture distribution shifts through the concept of interventions.
In my view, Causal Models therefore offer a more principled route to out-of-distribution generalization.
Classic Deep Learning attempts to enhance this robustness by employing techniques such as data augmentation, pretraining, self-supervision, and architectures with appropriate inductive biases concerning the perturbation of interest. While these solutions seem to address the out-of-distribution problem by incorporating out-of-distribution data points (or similar) into the dataset (in-distribution), they do not fundamentally resolve the challenges of out-of-distribution generalization; rather, they expand the represented distribution within the dataset.
So far, there has been no definitive consensus on how to solve these problems, although progress has been made using data augmentation, pretraining, self-supervision, and architectures with suitable inductive biases with respect to a perturbation of interest [60], [64], [137], [170], [206], [233]. It has been argued [188] that such fixes may not be sufficient, and generalizing well outside the i.i.d. setting requires learning not mere statistical associations between variables, but an underlying causal model.
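The following sketch is my own toy illustration of this argument, not an experiment from the paper. It assumes a hypothetical feature `s` that is a near-copy of the label during training but is decoupled from it at test time (as if someone intervened on `s`), and compares a regressor that uses it with one that only uses the causal feature.

```python
import numpy as np

def make_data(n, rng, spurious_linked):
    """Y is caused by x_cause; s is only a downstream correlate of Y."""
    x_cause = rng.normal(size=n)
    y = 2.0 * x_cause + rng.normal(size=n)
    if spurious_linked:
        s = y + 0.1 * rng.normal(size=n)   # training regime: s nearly copies Y
    else:
        s = rng.normal(size=n)             # shifted regime: the link is broken
    return x_cause, s, y

def fit(features, y):
    """Ordinary least squares without intercept (all variables are zero-mean)."""
    w, *_ = np.linalg.lstsq(features, y, rcond=None)
    return w

def mse(features, y, w):
    return float(np.mean((features @ w - y) ** 2))

rng = np.random.default_rng(0)
xc, s, y = make_data(50_000, rng, spurious_linked=True)
w_full = fit(np.column_stack([xc, s]), y)      # uses the spurious feature
w_causal = fit(xc[:, None], y)                 # uses only the cause

mse_full_train = mse(np.column_stack([xc, s]), y, w_full)
mse_causal_train = mse(xc[:, None], y, w_causal)

xc_t, s_t, y_t = make_data(50_000, rng, spurious_linked=False)
mse_full_shift = mse(np.column_stack([xc_t, s_t]), y_t, w_full)
mse_causal_shift = mse(xc_t[:, None], y_t, w_causal)

print(f"in-distribution MSE: full={mse_full_train:.2f}, causal={mse_causal_train:.2f}")
print(f"shifted MSE:         full={mse_full_shift:.2f}, causal={mse_causal_shift:.2f}")
```

In distribution, the model that leans on the spurious feature looks far better. After the shift its error grows sharply, while the causal-only model keeps roughly the same error, because the mechanism it relies on did not change.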
Learning Reusable Mechanisms (Transfer Learning)
Causal Models exhibit modularity. As certain world mechanisms remain consistent across various tasks, we can capitalize on this knowledge to expedite learning of new tasks. In essence, this entails achieving faster transfer learning by organizing the understanding of the world in a more intelligent manner.
In a modular representation of the world where the modules correspond to physical causal mechanisms, many modules can be expected to behave similarly across different tasks and environments. An agent facing a new environment or task may thus only need to adapt a few modules in its internal representation of the world.
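A small sketch of what reusing a mechanism can look like, under assumptions of my own: two hypothetical environments differ only in the distribution of the cause $X$, while the mechanism producing $Y$ from $X$ stays fixed. Fitting in the causal direction recovers the same module in both environments, whereas a model fitted in the anti-causal direction has to be re-learned.

```python
import numpy as np

def sample_env(n, rng, x_mean, x_std):
    """Environments differ in p(X); the mechanism Y := 1.5 * X + 0.5 + noise is shared."""
    x = rng.normal(loc=x_mean, scale=x_std, size=n)
    y = 1.5 * x + 0.5 + rng.normal(size=n)
    return x, y

rng = np.random.default_rng(0)
x_a, y_a = sample_env(100_000, rng, x_mean=0.0, x_std=1.0)   # environment A
x_b, y_b = sample_env(100_000, rng, x_mean=3.0, x_std=2.0)   # environment B

# Causal direction: predict Y from X. Returns [slope, intercept].
causal_a = np.polyfit(x_a, y_a, 1)
causal_b = np.polyfit(x_b, y_b, 1)

# Anti-causal direction: predict X from Y.
anti_a = np.polyfit(y_a, x_a, 1)
anti_b = np.polyfit(y_b, x_b, 1)

print("causal fit      A:", np.round(causal_a, 2), " B:", np.round(causal_b, 2))
print("anti-causal fit A:", np.round(anti_a, 2), " B:", np.round(anti_b, 2))
```

With this setup, only the module describing $p(X)$ needs updating when moving from A to B; the fitted mechanism for $Y$ given $X$ carries over unchanged.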
Answering Counterfactual Questions
This represents a novel addition to machine learning, primarily because the concept of counterfactuals has not been formalized within it previously. Answering a counterfactual involves asking what would have happened if a certain action, represented by $X=0$, had been taken instead of the one actually taken, represented by $X=1$. To obtain a meaningful answer, the remaining variables, including the random (noise) ones, must be held at the values they had in the scenario under consideration. This requires modeling the interactions between causal variables with explicit equations, so that individual variables can be manipulated. Of the three model classes discussed in this post, only the Structural Causal Model, with its structural equations, meets this requirement.
Counterfactual problems involve reasoning about why things happened, imagining the consequences of different actions in hindsight, and determining which actions would have achieved the desired outcome.
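The standard recipe for computing a counterfactual has three steps: abduction (infer the noise from what was observed), action (change the variable of interest), and prediction (re-evaluate the equations). Here is a minimal sketch on a hypothetical one-equation model $Y := 2X + U_Y$; the coefficient and the function names are assumptions for illustration only.

```python
BETA = 2.0  # assumed strength of the X -> Y mechanism

def f_y(x, u_y):
    """Structural equation for Y: determined by its cause X and its noise U_Y."""
    return BETA * x + u_y

def counterfactual_y(x_obs, y_obs, x_cf):
    """What would Y have been for this same unit if X had been x_cf?"""
    u_y = y_obs - BETA * x_obs      # 1. abduction: recover this unit's noise value
    x = x_cf                        # 2. action: replace X by the counterfactual value
    return f_y(x, u_y)              # 3. prediction: re-run the mechanism, noise held fixed

# We observed X=1 and Y=2.5. What if X had been 0 instead?
y_cf = counterfactual_y(x_obs=1.0, y_obs=2.5, x_cf=0.0)
print(y_cf)  # 0.5
```

Note that the answer is specific to the observed unit: a different observed $Y$ for the same $X$ would imply a different noise value and therefore a different counterfactual.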
6. What is the difference between Statistical Models, Causal Graphical Models and Structural Causal Models?
Statistical models exclusively capture correlations and are tailored to a specific distribution. In contrast, the other two model classes can represent interventional distributions, and thus encompass multiple distributions.
A Causal Graphical Model comprises a directed graph whose edges denote direct cause-and-effect relationships between variables.
Structural Causal Models go one step further by adding structural equations that specify how each variable is determined by its direct causes and an unobserved noise variable. Consequently, structural causal models can represent counterfactual scenarios: the values of the unobserved (random) variables are inferred from what was observed and held fixed while the variable of interest is changed.
Once a causal model is available, either by external human knowledge or a learning process, causal reasoning allows drawing conclusions on the effect of interventions, counterfactuals, and potential outcomes. In contrast, statistical models only allow reasoning about the outcome of i.i.d. experiments.
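To see how these levels relate in practice, here is a small sketch of a toy structural causal model written by me for this post. The class `ToySCM`, the altitude/temperature variables `A` and `T`, and the coefficients are all illustrative assumptions. The graph can be read off the equations, and interventions are implemented by overriding an equation with a constant.

```python
import numpy as np

class ToySCM:
    """Variables in topological order; each maps to (parents, structural function)."""

    def __init__(self, equations):
        self.equations = equations

    def edges(self):
        """The causal graph implied by the equations."""
        return [(p, name) for name, (parents, _) in self.equations.items() for p in parents]

    def sample(self, n, rng, do=None):
        """Draw n samples; `do` maps variable names to fixed intervention values."""
        do = do or {}
        values = {}
        for name, (parents, f) in self.equations.items():
            if name in do:
                values[name] = np.full(n, float(do[name]))   # cut the incoming edges
            else:
                noise = rng.normal(size=n)
                values[name] = f(*(values[p] for p in parents), noise)
        return values

# Hypothetical example: altitude A causes temperature T, not the other way around.
scm = ToySCM({
    "A": ((), lambda u: u),
    "T": (("A",), lambda a, u: -0.5 * a + u),
})

rng = np.random.default_rng(0)
obs = scm.sample(100_000, rng)
corr_obs = float(np.corrcoef(obs["A"], obs["T"])[0, 1])      # the statistical view

mean_a_do_t = float(scm.sample(100_000, rng, do={"T": 5.0})["A"].mean())
mean_t_do_a = float(scm.sample(100_000, rng, do={"A": 2.0})["T"].mean())

print("graph:", scm.edges())
print(f"corr(A, T) = {corr_obs:.2f}")
print(f"E[A | do(T=5)] = {mean_a_do_t:.2f}, E[T | do(A=2)] = {mean_t_do_a:.2f}")
```

The correlation is symmetric and says nothing about direction. The interventional samples are not symmetric: setting `T` leaves `A` untouched, while setting `A` shifts `T`. That asymmetry is the extra information a causal model carries beyond the observational distribution.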
7. Is Causal Representation Learning compatible with Reinforcement Learning?
In this study, the authors don’t directly discuss how Causal Representation Learning (CRL) applies to Reinforcement Learning (RL). Instead, they talk about how causality fits into RL. They argue that RL, with its ability to intervene in the environment, is closely linked to causality. They also mention two aspects of Causal Learning in RL: Causal Discovery and Causal Inference. In simple terms, Causal Discovery means figuring out causal relationships from intervention data, while Causal Inference means learning how to plan and act based on a given causal model.
With this in mind, the paper focuses on exploring specific aspects of causality that could benefit RL.
World Model
When implemented effectively, world models built upon generative models can provide a rich, diverse imagined space in which an agent can try out actions before taking them in the real environment. Such a space holds considerable potential for various applications and explorations.
This would take the field a step closer to a form of artificial intelligence that involves thinking in the sense of Konrad Lorenz, that is, acting in an imagined space.
Generalization, Robustness, and Fast Transfer
Causal reasoning holds promise in addressing two key challenges in Reinforcement Learning (RL): sample complexity and poor generalization to changes in the environment.
By employing RL to extract insights about the world via interventions and learning the invariances within a causal graph structure, we can potentially achieve the robustness and generalization needed. This approach reduces the need for purely random exploration, as the agent actively seeks out new valuable information.
Moreover, if this information is encapsulated within Independent Causal Mechanisms, when deployed in new environments, it’s likely that only a few mechanisms would require updating, streamlining the adaptation process.
Counterfactuals
Counterfactual reasoning has been shown to enhance sample efficiency in RL. With the capability to analyze what actions were taken and their outcomes, we may discover more efficient methods to learn and accomplish a desired task.
Offline RL
Offline RL already breaks the i.i.d. (independent and identically distributed) assumption. Causal learning aims to acquire knowledge that remains useful in non-i.i.d. settings, which suggests it could be useful for offline RL.
8. Is Causal Representation Learning useful for Multitask Learning? How so?
Once again, the authors refrain from directly addressing the intersection of Causal Representation Learning (CRL) and Multitask Learning. However, they underscore the inherent advantages of causality itself. Given the modularity and knowledge transfer capabilities of Causal Learning, it appears to be a well-suited framework for application in Multitask Learning. In Multitask Learning, the assumption is already made that some shared structure must exist between the tasks being learned. By explicitly leveraging this knowledge, one may uncover more efficient solutions overall.