Robust Statistical Workflow with PyStan

Stan, with its implementation of dynamic Hamiltonian Monte Carlo, is an extremely powerful tool for specifying and then fitting complex Bayesian models. In order to ensure a robust analysis, however, that power must be complemented with responsibility.

In particular, while dynamic implementations of Hamiltonian Monte Carlo, i.e. implementations where the integration time is dynamic, do perform well over a large class of models, their success is not guaranteed. When they do fail, however, their failures manifest in diagnostics that are readily checked.

By acknowledging and respecting these diagnostics you can ensure that Stan is accurately fitting the Bayesian posterior and hence accurately characterizing your model. And only with an accurate characterization of your model can you properly utilize its insights.

A Little Bit About Markov Chain Monte Carlo

Hamiltonian Monte Carlo is an implementation of Markov chain Monte Carlo, an algorithm which approximates expectations with respect to a given target distribution, $\pi$, $$ \mathbb{E}_{\pi} [ f ] = \int \mathrm{d}q \, \pi (q) \, f(q), $$ using the states of a Markov chain, $\{q_{0}, \ldots, q_{N} \}$, $$ \mathbb{E}_{\pi} [ f ] \approx \hat{f}_{N} = \frac{1}{N + 1} \sum_{n = 0}^{N} f(q_{n}). $$ Typically the target distribution is taken to be the posterior distribution of our specified model.

These estimators are guaranteed to be accurate only asymptotically, as the Markov chain grows to be infinitely long, $$ \lim_{N \rightarrow \infty} \hat{f}_{N} = \mathbb{E}_{\pi} [ f ]. $$
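
As a minimal illustration of the estimator above, the following sketch averages $f(q) = q^{2}$ over the states of a Markov chain targeting a standard normal, for which the exact expectation is 1. It uses a random walk Metropolis transition, which is much simpler than the dynamic Hamiltonian Monte Carlo used in Stan; the function names, step size, and chain length are illustrative assumptions and not part of this case study.

```python
import math
import random

def metropolis_chain(log_density, q0, n_steps, step_size, rng):
    """Return the states q_0, ..., q_N of a random walk Metropolis chain."""
    states = [q0]
    q = q0
    for _ in range(n_steps):
        proposal = q + rng.uniform(-step_size, step_size)
        log_accept = log_density(proposal) - log_density(q)
        # Accept with probability min(1, pi(proposal) / pi(q))
        if log_accept >= 0.0 or rng.random() < math.exp(log_accept):
            q = proposal
        states.append(q)
    return states

def mcmc_estimate(f, states):
    """Average f over all N + 1 states of the chain."""
    return sum(f(q) for q in states) / len(states)

def log_std_normal(q):
    """Log density of a standard normal, up to an additive constant."""
    return -0.5 * q * q

rng = random.Random(194838)
states = metropolis_chain(log_std_normal, 0.0, 20000, 2.0, rng)
estimate = mcmc_estimate(lambda q: q * q, states)
print(estimate)
```

Rerunning with a larger `n_steps` should bring the estimate closer to 1, which is the asymptotic guarantee expressed by the limit above.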

To be useful in applied analyses, we need these Markov chain Monte Carlo estimators to converge to the true expectation values sufficiently quickly that they are reasonably accurate before we exhaust our finite computational resources. This fast convergence requires strong ergodicity conditions to hold, typically a condition called geometric ergodicity between the Markov transition and target distribution. In particular, geometric ergodicity is a sufficient condition for Markov chain Monte Carlo estimators to follow a central limit theorem, which ensures not only that they are unbiased after only a finite number of iterations but also that we can empirically quantify their precision, $$ \hat{f}_{N} - \mathbb{E}_{\pi} [ f ] \sim \mathcal{N} \! \left( 0, \sqrt{ \mathrm{Var}[f] / N_{\mathrm{eff}}} \right). $$
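
For the special case where $f$ is a parameter itself, the scale in the central limit theorem above reduces to the posterior standard deviation divided by $\sqrt{N_{\mathrm{eff}}}$. The sketch below applies that formula to three (sd, n_eff) pairs copied from the centered-parameterization summary printed later in this case study; the helper name `mcmc_standard_error` is an illustrative assumption.

```python
import math

def mcmc_standard_error(posterior_sd, n_eff):
    """Standard error of an MCMC estimate of a mean: sqrt(Var[f] / N_eff)."""
    return posterior_sd / math.sqrt(n_eff)

# (sd, n_eff) pairs copied from the summary table of the centered fit
summary = {"mu": (2.97, 482.0), "tau": (2.94, 117.0), "lp__": (6.55, 24.0)}

standard_errors = {name: round(mcmc_standard_error(sd, n_eff), 2)
                   for name, (sd, n_eff) in summary.items()}
print(standard_errors)
```

The resulting values, 0.14, 0.27, and 1.34, agree with the se_mean column of that summary.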

Unfortunately proving geometric ergodicity theoretically is infeasible for any nontrivial problem. Instead we must rely on empirical diagnostics that identify obstructions to geometric ergodicity, and hence well-behaved Markov chain Monte Carlo estimators. For a general Markov transition and target distribution, the best known diagnostic is the split $\hat{R}$ statistic over an ensemble of Markov chains initialized from diffuse points in parameter space. To do any better we need to exploit the particular structure of a given transition or target distribution.

Hamiltonian Monte Carlo, for example, is especially powerful in this regard as its failures to be geometrically ergodic with respect to any target distribution manifest in distinct behaviors that have been developed into sensitive diagnostics. One of these behaviors is the appearance of divergences that indicate the Hamiltonian Markov chain has encountered regions of high curvature in the target distribution which it cannot adequately explore. Another is the energy Bayesian fraction of missing information, or E-BFMI, which quantifies the efficacy of the momentum resampling in between Hamiltonian trajectories.

For more details on Markov chain Monte Carlo and Hamiltonian Monte Carlo see "A Conceptual Introduction to Hamiltonian Monte Carlo" arXiv:1701.02434 (https://arxiv.org/abs/1701.02434).

In this case study I will demonstrate the recommended Stan workflow in Python where we not only fit a model but also scrutinize these diagnostics and ensure an accurate fit.

Setting Up The PyStan Environment

We begin by importing the PyStan module as well as the matplotlib module for basic graphics facilities.

In [1]:
import pystan
import matplotlib
import matplotlib.pyplot as plot

Unfortunately diagnostics are a bit ungainly to check in PyStan 2.16.0, so to facilitate the workflow I have included a utility module with some useful functions.

In [2]:
import stan_utility
help(stan_utility)
Help on module stan_utility:

NAME
    stan_utility

FILE
    /Users/Betancourt/Documents/Research/Code/betanalpha/jupyter_case_studies/pystan_workflow/stan_utility.py

FUNCTIONS
    check_div(fit)
        Check transitions that ended with a divergence
    
    check_energy(fit)
        Checks the energy Bayesian fraction of missing information (E-BFMI)
    
    check_treedepth(fit, max_depth=10)
        Check transitions that ended prematurely due to maximum tree depth limit
    
    compile_model(filename, model_name=None, **kwargs)
        This will automatically cache models - great if you're just running a
        script on the command line.
        
        See http://pystan.readthedocs.io/en/latest/avoiding_recompilation.html
    
    partition_div(fit)
        Returns parameter arrays separated into divergent and non-divergent transitions


Specifying and Fitting A Model in Stan

To demonstrate the recommended Stan workflow let's consider a hierarchical model of the eight schools dataset infamous in the statistical literature,

$$\mu \sim \mathcal{N}(0, 5)$$$$\tau \sim \text{Half-Cauchy}(0, 5)$$$$\theta_{n} \sim \mathcal{N}(\mu, \tau)$$$$y_{n} \sim \mathcal{N}(\theta_{n}, \sigma_{n}),$$

where $n \in \left\{1, \ldots, 8 \right\}$ and the $\left\{ y_{n}, \sigma_{n} \right\}$ are given as data.
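
To make the generative structure concrete, here is a rough sketch that draws one configuration from the joint model by sampling each line in turn. It assumes, as in Stan, that the second argument of each normal is a standard deviation, and it uses the $\sigma_{n}$ values listed in the data section of this case study; the function name `simulate_eight_schools` is hypothetical.

```python
import math
import random

def simulate_eight_schools(sigma, rng):
    """Draw one (mu, tau, theta, y) configuration from the joint model."""
    mu = rng.gauss(0.0, 5.0)
    # Half-Cauchy(0, 5): absolute value of a Cauchy draw via its inverse CDF
    tau = abs(5.0 * math.tan(math.pi * (rng.random() - 0.5)))
    theta = [rng.gauss(mu, tau) for _ in sigma]
    y = [rng.gauss(t, s) for t, s in zip(theta, sigma)]
    return mu, tau, theta, y

sigma = [15, 10, 16, 11, 9, 11, 10, 18]
rng = random.Random(194838)
mu, tau, theta, y = simulate_eight_schools(sigma, rng)
print(mu, tau)
print(y)
```

Repeating the draw many times gives a feel for how heavy the half-Cauchy tail on $\tau$ is relative to the measurement scales $\sigma_{n}$.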

For more information on the eight schools dataset see "Bayesian Data Analysis" by Gelman et al.

Specifying the Model with a Stan Program

In particular, let's implement a centered parameterization of the model, which is known to frustrate even sophisticated samplers like Hamiltonian Monte Carlo. In Stan the centered parameterization is specified with the Stan program

In [3]:
with open('eight_schools_cp.stan', 'r') as file:
    print(file.read())
data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta[J];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}

Note that we have specified the Stan program in its own file. We strongly recommend keeping your workflow modular by separating the Stan program from the Python environment in this way. Not only does it make it easier to identify and read through the Stan-specific components of your analysis, it also makes it easy to share your models with Stan users exploiting workflows in other environments, such as R and the command line.

Given the Stan program we then use the compile_model function of our stan_utility module to compile the Stan program into a C++ executable,

In [4]:
model = stan_utility.compile_model('eight_schools_cp.stan')
Using cached StanModel

This is not technically necessary, but it allows us to cache the executable and run this model with Stan multiple times without having to recompile between each run.

Specifying the Data

Similarly, we strongly recommend that you specify the data in its own file.

Data specified in a Python dictionary can be immediately converted to an external Stan data file using PyStan's stan_rdump function,

In [5]:
data = dict(J = 8, y = [28,  8, -3,  7, -1,  1, 18, 12], 
            sigma = [15, 10, 16, 11,  9, 11, 10, 18])

pystan.stan_rdump(data, 'eight_schools.data.R')

At the same time, an existing Stan data file can be read into the Python environment using the read_rdump function,

In [6]:
data = pystan.read_rdump('eight_schools.data.R')

Fitting the Model

With the model and data specified we can now turn to Stan to quantify the resulting posterior distribution with Hamiltonian Monte Carlo,

In [7]:
fit = model.sampling(data=data, seed=194838)

We recommend explicitly specifying the seed of Stan's random number generator, as we have done here, so that we can reproduce these results exactly in the future, at least when using the same machine, operating system, and interface. This is especially helpful for the more subtle pathologies that may not always be found, which results in seemingly stochastic behavior.

By default the sampling method runs 4 Markov chains of Hamiltonian Monte Carlo in parallel, each initialized from a diffuse initial condition to maximize the probability that at least one of the chains might encounter a pathological neighborhood of the posterior, if it exists. Each of those chains proceeds with 1000 warmup iterations and 1000 sampling iterations, totaling 4000 sampling iterations available for diagnostics and analysis.

Validating a Fit in Stan

We are now ready to validate the fit using information contained in the fit object.

Checking Split $\hat{R}$ and Effective Sample Sizes

The first diagnostics we check are universal to Markov chain Monte Carlo and are displayed using the print method of the fit object,

In [8]:
print(fit)
Inference for Stan model: anon_model_71b609c34d59a40b345b5328a36dbb39.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         4.28    0.14   2.97  -1.79   2.51   4.38   6.05  10.55  482.0   1.01
tau        3.46    0.27   2.94   0.65   1.26   2.66   4.65  11.19  117.0   1.06
theta[0]   5.89     0.2    5.2  -2.97   3.04   4.93   8.22  18.53  652.0   1.01
theta[1]   4.67    0.15   4.29  -3.76   2.23   3.96   7.03  14.08  780.0    1.0
theta[2]   3.83    0.16   4.75  -6.22   1.31   4.21   6.38  12.71  909.0    1.0
theta[3]   4.66    0.15   4.38  -4.14   2.38   4.74   7.01  14.33  892.0    1.0
theta[4]   3.61    0.16   4.38  -6.53   1.37   3.88   5.91  11.83  742.0   1.01
theta[5]   3.86    0.16   4.54  -6.14   1.62   3.93   6.33  12.83  846.0    1.0
theta[6]   5.96    0.22   4.75  -1.96   2.91   4.84   8.35  17.36  461.0   1.01
theta[7]   4.64    0.14   4.93  -5.22   2.16   4.82   7.18   14.9 1250.0    1.0
lp__     -13.81    1.34   6.55  -25.5 -18.62 -14.24  -9.09  -2.42   24.0   1.18

Samples were drawn using NUTS at Wed Jul 26 23:44:33 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Firstly we want to ensure that the split $\hat{R}$ for each parameter is close to 1. Empirically we have found that Rhat > 1.1 is usually indicative of problems in the fit. Here all of the parameters look good except for the log posterior density, lp__, which should inspire some small hesitation in our fit.
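
For readers who want to see what the statistic measures, the following is a minimal sketch of split $\hat{R}$ following the classic definition in "Bayesian Data Analysis": each chain is split in half and the between-half variance is compared to the within-half variance. Stan's own implementation may differ in detail, and the function name and synthetic chains here are assumptions made for illustration.

```python
import math
import random
from statistics import mean, variance

def split_rhat(chains):
    """Split R-hat for a list of equal-length chains of scalar draws."""
    half = len(chains[0]) // 2
    # Split every chain into its first and second half
    halves = []
    for chain in chains:
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    within = mean([variance(h) for h in halves])           # W
    between = half * variance([mean(h) for h in halves])   # B
    var_plus = (half - 1) / half * within + between / half
    return math.sqrt(var_plus / within)

rng = random.Random(1)
# Four synthetic chains that agree with each other
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
# The same chains, except that the first is stuck in a different region
stuck = [[x + 3.0 for x in mixed[0]]] + mixed[1:]

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
print(rhat_mixed, rhat_stuck)
```

The chains that agree give a value very close to 1, while the ensemble with one displaced chain lands well above the 1.1 threshold.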

Then we want to consider the effective sample size, or n_eff. Because the effective sample size is itself estimated from the fit output, it cannot be taken at face value when it is very small relative to the number of transitions. When n_eff / n_transitions < 0.001 the estimators that we use are often biased and can significantly overestimate the true effective sample size.
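
This ratio check is simple to script. The sketch below applies it to three n_eff values copied from the summary above, with 4000 post-warmup transitions; the helper name `flag_low_n_eff` is an illustrative assumption and not part of stan_utility.

```python
def flag_low_n_eff(n_eff_by_param, n_transitions, threshold=0.001):
    """Return the n_eff / n_transitions ratios that fall below the threshold."""
    flagged = {}
    for name, n_eff in n_eff_by_param.items():
        ratio = n_eff / n_transitions
        if ratio < threshold:
            flagged[name] = ratio
    return flagged

# n_eff values copied from the summary of the centered fit
n_eff = {"mu": 482.0, "tau": 117.0, "lp__": 24.0}
ratios = {name: value / 4000 for name, value in n_eff.items()}
flagged = flag_low_n_eff(n_eff, 4000)
print(ratios)
print(flagged)
```

Even the smallest ratio here, 0.006 for lp__, is above 0.001, so nothing is flagged by this particular check.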

Both large split $\hat{R}$ and low effective samples per transition are consequences of poorly mixing Markov chains. Improving the mixing of the Markov chains almost always requires tweaking the model specification, for example with a reparameterization or stronger priors.

Checking the Tree Depth

The dynamic implementation of Hamiltonian Monte Carlo used in Stan has a maximum trajectory length built in to avoid infinite loops that can occur for non-identified models. For sufficiently complex models, however, Stan can saturate this threshold even if the model is identified, which limits the efficacy of the sampler.

We can check whether that threshold was hit using one of our utility functions,

In [9]:
stan_utility.check_treedepth(fit)
0 of 4000 iterations saturated the maximum tree depth of 10 (0%)

We're good here, but if our fit had saturated the threshold then we would have wanted to rerun with a larger maximum tree depth,

fit = model.sampling(data=data, seed=194838, control=dict(max_treedepth=15))

and then check whether the fit still saturated this larger threshold with

stan_utility.check_treedepth(fit, 15)
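
The logic behind such a check is just a count. The sketch below works on a plain list of per-iteration tree depths; in PyStan 2 these are assumed to come from the `treedepth__` entries returned by `fit.get_sampler_params(inc_warmup=False)`, and the depths used here are made-up values for illustration.

```python
def count_saturated(treedepths, max_depth=10):
    """Count iterations whose tree depth reached the maximum."""
    n_saturated = sum(1 for depth in treedepths if depth >= max_depth)
    return n_saturated, len(treedepths)

# Hypothetical tree depths for eight iterations
depths = [3, 4, 10, 5, 10, 2, 3, 4]
n_saturated, n_total = count_saturated(depths, max_depth=10)
print("{} of {} iterations saturated the maximum tree depth of {} ({}%)".format(
    n_saturated, n_total, 10, 100.0 * n_saturated / n_total))
```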

Checking the E-BFMI

Hamiltonian Monte Carlo proceeds in two phases -- the algorithm first simulates a Hamiltonian trajectory that rapidly explores a slice of the target parameter space before resampling the auxiliary momenta to allow the next trajectory to explore another slice of the target parameter space. Unfortunately, the jumps between these slices induced by the momenta resamplings can be short, which often leads to slow exploration.

We can identify this problem by consulting the energy Bayesian Fraction of Missing Information,

In [10]:
stan_utility.check_energy(fit)
Chain 2: E-BFMI = 0.177681346951
E-BFMI below 0.2 indicates you may need to reparameterize your model

The stan_utility module uses the threshold of 0.2 to diagnose problems, although this is based on preliminary empirical studies and should be taken only as a very rough recommendation. In particular, this diagnostic comes out of recent theoretical work and will be better understood as we apply it to more and more problems. For further discussion see Sections 4.2 and 6.1 of "A Conceptual Introduction to Hamiltonian Monte Carlo" arXiv:1701.02434 (https://arxiv.org/abs/1701.02434).
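
As a rough sketch of the quantity being checked, the E-BFMI of a single chain compares the squared changes in energy between successive iterations to the overall spread of the energies. The function below takes a plain list of energies; in PyStan 2 these are assumed to be the `energy__` sampler parameter of each chain, and the two synthetic energy sequences are illustrative stand-ins, not output from Stan.

```python
import random

def e_bfmi(energies):
    """Energy Bayesian fraction of missing information for one chain."""
    mean_energy = sum(energies) / len(energies)
    # Squared energy changes induced by the momentum resamplings
    numer = sum((energies[i] - energies[i - 1]) ** 2
                for i in range(1, len(energies)))
    # Overall spread of the energies across the chain
    denom = sum((e - mean_energy) ** 2 for e in energies)
    return numer / denom

rng = random.Random(0)
# Energies that decorrelate completely at every iteration
fast = [rng.gauss(0.0, 1.0) for _ in range(5000)]
# Energies that change only slowly from one iteration to the next
slow = [0.0]
for _ in range(4999):
    slow.append(0.99 * slow[-1] + rng.gauss(0.0, 1.0))

ebfmi_fast = e_bfmi(fast)
ebfmi_slow = e_bfmi(slow)
print(ebfmi_fast, ebfmi_slow)
```

The slowly changing sequence falls far below the 0.2 threshold, mirroring the situation where momentum resampling moves the chain only a short distance between energy levels.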

As with split $\hat{R}$ and effective sample size per transition, the problems indicated by low E-BFMI are remedied by tweaking the specification of the model. Unfortunately the exact tweaks required depend on the exact structure of the model and, consequently, there are no generic solutions.

Checking Divergences

Finally, we can check divergences which indicate pathological neighborhoods of the posterior that the simulated Hamiltonian trajectories are not able to explore sufficiently well. For this fit we have a significant number of divergences

In [11]:
stan_utility.check_div(fit)
202.0 of 4000 iterations ended with a divergence (5.05%)
Try running with larger adapt_delta to remove the divergences

indicating that the Markov chains did not completely explore the posterior and that our Markov chain Monte Carlo estimators will be biased.

Divergences, however, can sometimes be false positives. To verify that we have real fitting issues we can rerun with a larger target acceptance probability, adapt_delta, which will force more accurate simulations of Hamiltonian trajectories and reduce the false positives.

In [12]:
fit = model.sampling(data=data, seed=194838, control=dict(adapt_delta=0.9))

Checking again,

In [13]:
sampler_params = fit.get_sampler_params(inc_warmup=False)
stan_utility.check_div(fit)
45.0 of 4000 iterations ended with a divergence (1.125%)
Try running with larger adapt_delta to remove the divergences

we see that while the divergences were reduced they did not completely vanish. In order to argue that divergences are only false positives, the divergences have to be completely eliminated for some adapt_delta sufficiently close to 1. Here we could continue increasing adapt_delta, where we would see that the divergences do not completely vanish, or we can analyze the existing divergences graphically.

If the divergences are not false positives then they will tend to concentrate in the pathological neighborhoods of the posterior. Falsely positive divergent iterations, however, will follow the same distribution as non-divergent iterations.

Here we will use the partition_div function of the stan_utility module to separate divergent and non-divergent iterations, but note that this function works only if your model parameters are reals, vectors, or arrays of reals. More robust functionality is planned for future releases of PyStan.
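
Conceptually the partition is a filter on a per-iteration divergence flag. The sketch below shows the idea for a single scalar parameter using plain lists; it is not the stan_utility implementation, the function name is hypothetical, and the draws and flags are made-up values. The flags are written as floats because the divergence counts reported above are printed as floats.

```python
def partition_by_divergence(draws, divergent):
    """Split draws into (non-divergent, divergent) lists using 0/1 flags."""
    nondiv = [d for d, flag in zip(draws, divergent) if not flag]
    div = [d for d, flag in zip(draws, divergent) if flag]
    return nondiv, div

# Hypothetical tau draws with their divergence flags
tau_draws = [2.1, 0.4, 3.7, 0.3, 5.2, 1.8]
divergent = [0.0, 1.0, 0.0, 1.0, 0.0, 0.0]

nondiv_tau, div_tau = partition_by_divergence(tau_draws, divergent)
print(nondiv_tau)
print(div_tau)
```

In this toy data the flagged draws sit at the smallest values of tau, which is the kind of concentration that would suggest the divergences are not false positives.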

In [14]:
light="#DCBCBC"
light_highlight="#C79999"
mid="#B97C7C"
mid_highlight="#A25050"
dark="#8F2727"
dark_highlight="#7C0000"
green="#00FF00"

nondiv_params, div_params = stan_utility.partition_div(fit)

plot.scatter([x[0] for x in nondiv_params['theta']], nondiv_params['tau'], \
             color = mid_highlight, alpha=0.05)
plot.scatter([x[0] for x in div_params['theta']], div_params['tau'], \
             color = green, alpha=0.5)

plot.gca().set_xlabel("theta_1")
plot.gca().set_ylabel("tau")

plot.show()
WARNING:root:`dtypes` ignored when `permuted` is False.

One of the challenges with a visual analysis of divergences is determining exactly which parameters to examine. Consequently visual analyses are most useful when there are already components of the model about which you are suspicious, as in this case where we know that the correlation between random effects (theta_1 through theta_8) and the hierarchical standard deviation, tau, can be problematic.

Indeed we see the divergences clustering towards small values of tau where the posterior abruptly stops. This abrupt stop is indicative of a transition into a pathological neighborhood that Stan was not able to penetrate.

In order to avoid this issue we have to consider a modification to our model, and in this case we can appeal to a non-centered parameterization of the same model that does not suffer these issues.

A Successful Fit

Multiple diagnostics have indicated that our fit of the centered parameterization of our hierarchical model is not to be trusted, so let's instead consider the complementary non-centered parameterization,

In [15]:
with open('eight_schools_ncp.stan', 'r') as file:
    print(file.read())
data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta_tilde[J];
}

transformed parameters {
  real theta[J];
  for (j in 1:J)
    theta[j] = mu + tau * theta_tilde[j];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta_tilde ~ normal(0, 1);
  y ~ normal(theta, sigma);
}
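
The transformed parameters block is what keeps this program equivalent to the original model: if theta_tilde follows a standard normal then mu + tau * theta_tilde follows a normal with location mu and scale tau. The following sketch checks that numerically for fixed, arbitrarily chosen values of mu and tau; the helper name and those values are assumptions for illustration only.

```python
import random
from statistics import mean, stdev

def noncentered_theta(mu, tau, theta_tilde):
    """Map standardized effects to theta = mu + tau * theta_tilde."""
    return [mu + tau * t for t in theta_tilde]

rng = random.Random(194838)
mu, tau = 4.0, 3.0  # fixed, illustrative values
theta_tilde = [rng.gauss(0.0, 1.0) for _ in range(20000)]
theta = noncentered_theta(mu, tau, theta_tilde)

theta_mean = mean(theta)
theta_sd = stdev(theta)
print(theta_mean, theta_sd)
```

The sampler now explores mu, tau, and theta_tilde, whose prior does not depend on tau, rather than theta directly.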

In [16]:
model = stan_utility.compile_model('eight_schools_ncp.stan')
fit = model.sampling(data=data, seed=194838)
Using cached StanModel
In [17]:
print(fit)

stan_utility.check_treedepth(fit)
stan_utility.check_energy(fit)
stan_utility.check_div(fit)
Inference for Stan model: anon_model_b4ca739f9fe7ffcdbf0d530f00d0a587.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

                 mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu               4.37    0.05   3.39  -2.49    2.1   4.51   6.69  10.94 4000.0    1.0
tau              3.59    0.06   3.19   0.13   1.27   2.79   4.95  11.68 2761.0    1.0
theta_tilde[0]   0.31    0.02   0.96  -1.61  -0.34   0.32   0.95   2.16 4000.0    1.0
theta_tilde[1]    0.1    0.01   0.92  -1.72  -0.51   0.12   0.72    1.9 3766.0    1.0
theta_tilde[2]  -0.08    0.02   0.98   -2.0  -0.74  -0.08   0.55   1.91 4000.0    1.0
theta_tilde[3]   0.05    0.01   0.94  -1.86  -0.57   0.07   0.68   1.87 4000.0    1.0
theta_tilde[4]  -0.16    0.02   0.95  -2.03  -0.79  -0.17   0.47   1.71 3681.0    1.0
theta_tilde[5]  -0.09    0.01   0.92  -1.95  -0.71   -0.1   0.52   1.76 3787.0    1.0
theta_tilde[6]   0.37    0.02   0.97  -1.62  -0.26   0.38   1.01   2.28 4000.0    1.0
theta_tilde[7]   0.09    0.02   0.96  -1.77  -0.56   0.11   0.74   1.92 4000.0    1.0
theta[0]         6.15    0.09   5.53  -3.49   2.76   5.66   8.88  19.12 4000.0    1.0
theta[1]          4.9    0.07   4.56  -4.39   2.04    4.9   7.77  14.09 4000.0    1.0
theta[2]         3.91    0.08   5.37  -7.72   0.98   4.19   7.18  13.58 4000.0    1.0
theta[3]         4.63    0.08    4.8  -5.02    1.7   4.68   7.57  14.11 4000.0    1.0
theta[4]         3.62    0.08   4.79  -7.34   0.89   3.97   6.71  12.27 4000.0    1.0
theta[5]         3.97    0.07   4.73   -6.0   1.22   4.15   6.96   12.7 4000.0    1.0
theta[6]         6.37    0.08   5.07  -2.22   3.04   5.97   9.02  18.49 4000.0    1.0
theta[7]         4.92    0.08   5.32  -5.75   1.82   4.86   7.91  15.68 4000.0    1.0
lp__            -6.91    0.06   2.34 -12.17  -8.29  -6.56  -5.23  -3.32 1754.0    1.0

Samples were drawn using NUTS at Wed Jul 26 23:44:34 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).
0 of 4000 iterations saturated the maximum tree depth of 10 (0%)
0.0 of 4000 iterations ended with a divergence (0.0%)

With this more appropriate implementation of our model all of the diagnostics are clean and we can now utilize Markov chain Monte Carlo estimators of expectations, such as parameter means and variances, to accurately characterize our model's posterior distribution.

Acknowledgements

I thank Sean Talts for helping to make the functions in stan_utility more Pythonic and compatible with both Python 2 and Python 3, and Sean Talts and Maggie Lieu for helpful comments on the notebook.
