02 - Anchor Regression as a diluted form of causality

Apr 3, 2023

Here we summarise the work of Rothenhäusler et al. (2021). We recommend reading the original paper for a more complete overview of the ideas sketched below.

Introduction

We are still interested in worst-case risk minimisation as described in the previous chapter. Given $X \in \mathbb{R}^d$ and $Y \in \mathbb{R}$, we aim to minimise

$$ \arg\min_{\mathbf{b}} \sup_{Q \in \mathbb{Q}} \mathbb{E}_Q[(Y - X\mathbf{b})^2]. $$

We use the squared loss and again assume a linear SCM, but now with a slightly more elaborate structure. We assume that some of the variables are known to be exogenous; we call these anchor variables, denoted $A$. Anchor variables generalise the framework of instrumental variables (IV) regression, which assumes the presence of an exogenous variable that affects $Y$ only through $X$ (the exclusion restriction). The anchor framework relaxes this exclusion restriction and only assumes the exogeneity of $A$.

More concretely, the distribution of $(A, X, Y)$ is entailed by the following SCM:

$$ \begin{pmatrix} X \\ Y \\ H \end{pmatrix} = \mathbf{B} \begin{pmatrix} X \\ Y \\ H \end{pmatrix} + \varepsilon + \mathbf{M} A, $$

where $H$ represents hidden confounders, $\varepsilon$ is a noise vector independent of $A$, and $\mathbf{M}$ encodes the linear effect of the anchors on the system.

Assuming the graph is acyclic, the matrix $(I - \mathbf{B})$ is invertible, and the SCM can be written as:

$$ \begin{pmatrix} X \\ Y \\ H \end{pmatrix} = (I - \mathbf{B})^{-1}(\varepsilon + \mathbf{M} A). $$

In this framework, we allow interventions only on the anchor variables $A$. The intervened SCM is:

$$ \begin{pmatrix} X \\ Y \\ H \end{pmatrix} = (I - \mathbf{B})^{-1}(\varepsilon + \nu), $$

where $\nu$ is the hard intervention replacing the anchor term $\mathbf{M}A$. We denote by $P^\nu$ the interventional distribution, and by $\mathbb{E}_\nu$ its corresponding expectation (also referred to as the test distribution). The training distribution is denoted by $P$.
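To make the training and interventional distributions concrete, here is a minimal simulation sketch of the SCM above; all coefficients ($\mathbf{B}$, $\mathbf{M}$, the noise scales and the particular $\nu$) are hypothetical choices, not taken from the paper.

```python
# Minimal sketch of the linear anchor SCM (all coefficients are hypothetical).
import numpy as np

rng = np.random.default_rng(0)

# Structural coefficients for the system (X, Y, H), ordered as in the text.
# B[i, j] is the effect of variable j on variable i; the graph is acyclic.
B = np.array([
    [0.0, 0.0, 0.8],   # X <- H
    [0.6, 0.0, 0.5],   # Y <- X and Y <- H
    [0.0, 0.0, 0.0],   # H is a source node
])
M = np.array([1.0, 0.0, 0.0])  # here the anchor A shifts X only

def sample(n, nu=None):
    """Sample (X, Y, H) from (I - B)^{-1}(eps + M A), or from the
    intervened SCM (I - B)^{-1}(eps + nu) when nu is given."""
    eps = rng.normal(size=(n, 3))
    shift = np.outer(rng.normal(size=n), M) if nu is None else np.tile(nu, (n, 1))
    return (eps + shift) @ np.linalg.inv(np.eye(3) - B).T

train = sample(100_000)                               # training distribution P
test = sample(100_000, nu=np.array([3.0, 0.0, 0.0]))  # test distribution P^nu
print("train means:", train.mean(axis=0).round(2))    # ~ (0, 0, 0)
print("test means: ", test.mean(axis=0).round(2))     # shifted by (I - B)^{-1} nu
```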

Because of the presence of hidden confounders $H$, the causal parameters are non-identifiable. Thus, learning a causal model that remains stable under all possible interventions is impossible (see previous chapter). Interestingly, Rothenhäusler et al. (2021) also show that causal parameters might even be sub-optimal for prediction in the presence of hidden confounders.

Bounded Interventions

The core idea of anchor regression is to restrict the strength of interventions on $A$. This is formalised as:

$$ \mathbb{Q}^{\text{anchor}} = \left\{ P^\nu \mid \nu \nu^\top \preceq \gamma \, \mathbb{E}_P[A A^\top] \right\}. $$

By constraining the set of possible interventions, we may obtain better generalisation performance. In practice, interventions are typically bounded, making this assumption reasonable.
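Concretely, membership in $\mathbb{Q}^{\text{anchor}}$ is a positive semidefiniteness condition: $P^\nu$ is admissible iff $\gamma \, \mathbb{E}_P[AA^\top] - \nu\nu^\top \succeq 0$. A small sketch of this check, with a hypothetical anchor covariance and candidate interventions:

```python
# Check whether an intervention nu lies in the bounded set Q^anchor,
# i.e. whether gamma * E_P[A A^T] - nu nu^T is positive semidefinite.
import numpy as np

def in_anchor_set(nu, cov_A, gamma, tol=1e-10):
    gap = gamma * cov_A - np.outer(nu, nu)
    return np.linalg.eigvalsh(gap).min() >= -tol

cov_A = np.eye(2)  # suppose E_P[A A^T] = I (hypothetical)
print(in_anchor_set(np.array([1.0, 1.0]), cov_A, gamma=4.0))  # True:  |nu|^2 = 2 <= 4
print(in_anchor_set(np.array([3.0, 0.0]), cov_A, gamma=4.0))  # False: |nu|^2 = 9 > 4
```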

Robustness Under Diluted Causality

Rothenhäusler et al. (2021) show that the worst-case risk over the set $\mathbb{Q}^{\text{anchor}}$ takes a particularly simple form as a causal regularisation of Empirical Risk Minimisation (ERM):

$$ \sup_{P^\nu \in \mathbb{Q}^{\text{anchor}}} \mathbb{E}_\nu[(Y - X\mathbf{b})^2] = \mathbb{E}_P[(Y - X\mathbf{b})^2] + (\gamma - 1) \, \mathbb{E}_P\left[(P_A(Y - X\mathbf{b}))^2\right], $$

where $P_A(\cdot) = \mathbb{E}[\cdot \mid A]$ denotes the projection onto $A$; under the linear SCM above, this coincides with the linear projection onto the span of $A$. The parameter $\gamma$ captures the strength of the interventions in $\mathbb{Q}^{\text{anchor}}$ against which we want to be robust.

  • The first term on the right-hand side is the standard ERM loss.
  • The second term is a causal regularisation term.

Minimising this second term alone recovers the two-stage least squares formulation of instrumental variables. Special cases include (see the sketch after this list):

  • $\gamma = 1$: standard ERM,
  • $\gamma \to \infty$: IV regression,
  • $\gamma = 0$: partialling-out regression (OLS after projecting the anchors out of both $X$ and $Y$), another causal inference technique.
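In finite samples, $P_A$ is replaced by the projection matrix $\Pi_A = A(A^\top A)^{-1} A^\top$ onto the column span of the anchor matrix, and Rothenhäusler et al. (2021) show that the penalised objective is minimised by ordinary least squares on the transformed data $W_\gamma X$ and $W_\gamma Y$, with $W_\gamma = I - (1 - \sqrt{\gamma}) \Pi_A$. A minimal sketch, with a hypothetical data-generating process:

```python
# Anchor regression as OLS on transformed data (simulated data are hypothetical).
import numpy as np

def anchor_regression(X, Y, A, gamma):
    """OLS after transforming by W = I - (1 - sqrt(gamma)) * Pi_A."""
    n = len(Y)
    Pi_A = A @ np.linalg.solve(A.T @ A, A.T)       # projection onto span(A)
    W = np.eye(n) - (1.0 - np.sqrt(gamma)) * Pi_A  # gamma = 1 recovers plain OLS
    return np.linalg.lstsq(W @ X, W @ Y, rcond=None)[0]

rng = np.random.default_rng(1)
n = 2_000
A = rng.normal(size=(n, 1))                        # a single anchor
H = rng.normal(size=n)                             # hidden confounder
X = (2.0 * A[:, 0] + H + rng.normal(size=n)).reshape(-1, 1)
Y = 0.5 * X[:, 0] + H + rng.normal(size=n)         # true causal effect: 0.5

for gamma in (0.0, 1.0, 10.0, 1e6):
    print(f"gamma = {gamma:g}: b = {anchor_regression(X, Y, A, gamma)}")
```

Up to sampling noise, $\gamma = 1$ returns the OLS fit (about $2/3$ for this design, biased by $H$), $\gamma \to \infty$ approaches the IV solution ($0.5$, the causal effect), and $\gamma = 0$ the partialled-out fit (about $1$ here). In practice one would avoid materialising the $n \times n$ matrix $W_\gamma$ explicitly.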

Discussion

We have seen how restricting the set of potential interventions to bounded ones may lead to better generalisation, via a simple causal regularisation of ERM. So far, this framework has considered only the squared loss. In the next chapter, we will see how this regularisation idea can be extended to a wider class of algorithms.

👉 Anchor regression generalisation for multivariate algorithms