MCMC - Performance & Stan

Lecture 20

Dr. Colin Rundel

.pytensorrc setup

On the departmental server, to avoid warnings about missing BLAS libraries when using PyMC / PyTensor, create ~/.pytensorrc with the following contents:

[blas]
ldflags = -lflexiblas
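To confirm that the setting is picked up, the configured flags can be inspected from Python (a quick check; assumes PyTensor is installed):

import pytensor
print(pytensor.config.blas__ldflags)  # should show "-lflexiblas" if ~/.pytensorrc was read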

Example - Gaussian Process

Data

d
            x         y
0    0.000258 -0.126679
1    0.019467  0.125848
2    0.021670  0.222179
3    0.022712  0.228689
4    0.029758  0.487060
..        ...       ...
245  1.925221 -0.090028
246  1.936691 -0.295139
247  1.948376 -0.106351
248  1.958606  0.029466
249  1.995473 -0.071615

[250 rows x 2 columns]

GP model

X = d.x.to_numpy().reshape(-1,1)
y = d.y.to_numpy()

with pm.Model() as model:
  l = pm.Gamma("l", alpha=2, beta=1)
  s = pm.HalfCauchy("s", beta=5)
  nug = pm.HalfCauchy("nug", beta=5)

  cov = s**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=l)
  gp = pm.gp.Marginal(cov_func=cov)

  like = gp.marginal_likelihood(
    "y", X=X, y=y, sigma=nug
  )

MAP estimates

with model:
  gp_map = pm.find_MAP()
logp = 189.21, ||grad|| = 0.091468
pprint(gp_map)
{'l': array(0.18938),
 'l_log__': array(-1.66401),
 'nug': array(0.09815),
 'nug_log__': array(-2.32123),
 's': array(0.37139),
 's_log__': array(-0.99049)}
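The *_log__ entries are the same parameters on PyMC’s unconstrained (log-transformed) sampling scale; exponentiating recovers the constrained values. A quick check using the output above:

np.exp(gp_map["l_log__"])  # array(0.18938), matching gp_map["l"]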

Full Posterior Sampling

with model:
  post_nuts = pm.sample(cores=2)
                                                                                                              
  Progress               Draws   Divergences   Step size   Grad evals   Sampling Speed   Elapsed   Remaining
  ━━━━━━━━━━━━━━━━━━━━   1999    0             0.473       3            83.89 draws/s    0:00:23   0:00:00
  ━━━━━━━━━━━━━━━━━━━━   1999    0             0.562       7            82.32 draws/s    0:00:24   0:00:00
                                                                                                              
az.summary(post_nuts)
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
l    0.210  0.039   0.136    0.283      0.002    0.002     690.0     503.0    1.0
s    0.554  0.296   0.208    1.028      0.013    0.024     687.0     572.0    1.0
nug  0.099  0.005   0.090    0.107      0.000    0.000    1069.0    1041.0    1.0
# gp_map w/o logged parameters
{k: v for k, v in gp_map.items() if "log" not in k}
{'l': array(0.18938), 's': array(0.37139), 'nug': array(0.09815)}

Trace plots

ax = az.plot_trace(post_nuts)
plt.gcf().set_layout_engine("constrained")
plt.show()

Conditional Predictions (MAP)

X_new = np.linspace(0, 2.2, 221).reshape(-1, 1)

with model:
  y_pred = gp.conditional("y_pred", X_new)
  pred_map = pm.sample_posterior_predictive(
    [gp_map], var_names=["y_pred"]
  )
Sampling ... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:00

Conditional Predictions (full posterior)

with model:
  pred_post = pm.sample_posterior_predictive(
    post_nuts.sel(draw=slice(None,None,10)), var_names=["y_pred"]
  )
Sampling ... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:00

Conditional Predictions (posterior + nugget)

with model:
  y_star = gp.conditional("y_star", X_new, pred_noise=True)
  predn_post = pm.sample_posterior_predictive(
    post_nuts.sel(draw=slice(None,None,10)), var_names=["y_star"]
  )
Sampling ... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:00
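
A minimal sketch of how these posterior predictive draws might be summarized and plotted (assumes matplotlib / numpy are already imported; variable names follow the code above):

y_star = predn_post.posterior_predictive["y_star"].stack(sample=("chain", "draw"))

plt.plot(d.x, d.y, ".", alpha=0.3)
plt.plot(X_new.ravel(), y_star.mean(dim="sample"))
plt.fill_between(
  X_new.ravel(),
  y_star.quantile(0.05, dim="sample"),
  y_star.quantile(0.95, dim="sample"),
  alpha=0.25
)
plt.show()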

Sampler Backends

Alternative NUTS sampler backends

Beyond supporting different step methods, PyMC can also use different implementations of the NUTS sampler to run your model.

These are selected via the nuts_sampler argument of pm.sample(), which currently supports the following (an example follows the list below):

  • pymc - standard NUTS sampler using PyMC’s PyTensor backend

  • blackjax - uses the blackjax library which is a collection of samplers written for JAX

  • numpyro - probabilistic programming library inspired by Pyro, built using JAX

  • nutpie - provides a wrapper to the nuts-rs Rust library (slight variation on NUTS implementation)
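
For example, rerunning the GP model above with the numpyro backend only requires changing the nuts_sampler argument (full timing comparisons appear later in this lecture):

with model:
  post = pm.sample(nuts_sampler="numpyro", cores=2)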

Notes on installation

Before using the above sampler backends, you will need to install the relevant packages.

For example, to use the blackjax sampler, you will need to install the blackjax package and its dependencies (e.g. jax). Similarly, for numpyro, you will need to install the numpyro package and its dependencies.

Many of these packages have optional extras that also need to be specified if you want full functionality (e.g. GPU support). Some common examples:

  • uv add "jax[cuda]" to add CUDA support to JAX

  • uv add "nutpie[all]" to add both pymc and stan support for nutpie
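
After installing, a quick way to confirm that JAX sees the intended devices (output will depend on the machine):

import jax
print(jax.devices())  # e.g. [CudaDevice(id=0), CudaDevice(id=1)] with GPU support, [CpuDevice(id=0)] otherwise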

Sampler backend comparison

The four backends differ in their underlying implementation and parallelism model:

Backend    Language              Parallelism
pymc       C (Aesara/PyTensor)   1 core per chain
blackjax   JAX (XLA-compiled)    Multiple cores / GPU across chains
numpyro    JAX (XLA-compiled)    Multiple cores / GPU across chains
nutpie     Rust (nuts-rs)        1 core per chain
  • JAX-based samplers (blackjax, numpyro) JIT-compile the model and can exploit multi-core CPUs or GPUs, but have higher compilation overhead on first run
  • pymc and nutpie run each chain on a single core; chains are run in parallel via Python multiprocessing
  • On small models the JAX compilation cost can dominate; on large models or with GPU hardware the JAX backends tend to win

The nutpie sampler

nutpie wraps the nuts-rs Rust implementation of NUTS:

  • Written in Rust for low-overhead, cache-friendly execution, with no Python interpreter overhead per leapfrog step
  • Uses a different mass-matrix adaptation scheme than PyMC’s default (warmup adaptation also makes use of gradient information), which can shorten warmup and improve sampling efficiency
  • Supports both PyMC models and Stan models
  • Supports pre-compiling the model which separates the compilation cost from sampling

Performance

start = time.time()
with model:
    post_nuts = pm.sample(
      nuts_sampler="pymc", chains=4, progressbar=False
    )
print(f"pymc: {time.time() - start:.1f}s")
pymc: 31.3s
start = time.time()
with model:
    post_blackjax = pm.sample(
      nuts_sampler="blackjax", chains=4, progressbar=False
    )
print(f"blackjax: {time.time() - start:.1f}s")
blackjax: 27.8s
start = time.time()
with model:
    post_pyro = pm.sample(
      nuts_sampler="numpyro", chains=4, progressbar=False
    )
print(f"numpyro: {time.time() - start:.1f}s")
numpyro: 26.5s
start = time.time()
with model:
    post_nutpie = pm.sample(
      nuts_sampler="nutpie", chains=4, progressbar=False
    )
print(f"nutpie: {time.time() - start:.1f}s")
nutpie: 29.8s
import nutpie
compiled = nutpie.compile_pymc_model(model)
start = time.time()
post_nutpie2 = nutpie.sample(compiled, chains=4, progress_bar=False)
print(f"nutpie (compiled): {time.time() - start:.1f}s")
nutpie (compiled): 13.0s

Results

az.summary(post_nuts)
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
l    0.212  0.041   0.137    0.286      0.001    0.001    1558.0     835.0   1.01
s    0.569  0.339   0.210    1.107      0.013    0.033    1453.0     916.0   1.01
nug  0.099  0.005   0.091    0.107      0.000    0.000    2050.0    1749.0   1.00
az.summary(post_blackjax)
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
l    0.211  0.040   0.140    0.283      0.001    0.001    1599.0    1379.0    1.0
s    0.555  0.313   0.197    1.051      0.009    0.024    1696.0    1596.0    1.0
nug  0.099  0.005   0.090    0.108      0.000    0.000    2247.0    1707.0    1.0
az.summary(post_pyro)
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
l    0.211  0.041   0.139    0.290      0.001    0.001    1603.0    1535.0    1.0
s    0.566  0.325   0.201    1.092      0.010    0.020    1512.0    1303.0    1.0
nug  0.099  0.005   0.091    0.108      0.000    0.000    2306.0    2134.0    1.0
az.summary(post_nutpie.posterior[["l","s","nug"]])
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
l    0.211  0.042   0.140    0.292      0.001    0.001    1150.0    1173.0   1.01
s    0.568  0.350   0.194    1.070      0.014    0.043    1226.0    1095.0   1.01
nug  0.099  0.005   0.090    0.107      0.000    0.000    2573.0    2511.0   1.00
az.summary(post_nutpie2.posterior[["l","s","nug"]])
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
l    0.211  0.040   0.143    0.290      0.001    0.001    1408.0    1439.0    1.0
s    0.561  0.317   0.207    1.047      0.009    0.023    1470.0    1516.0    1.0
nug  0.099  0.005   0.090    0.108      0.000    0.000    2534.0    2384.0    1.0

Stan

Stan in Python & R

At the moment, both Python and R offer two interfaces to Stan:

  • pystan & RStan - native language interface to the underlying Stan C++ libraries

    • The former does not play nicely with Jupyter (or Quarto or Positron) - see here for a fix
  • CmdStanPy & CmdStanR - wrappers around the CmdStan command line interface

    • Interface is through files (e.g. ./model.stan)

Any of the above tools will require a modern C++ toolchain with C++17 support.
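
For the CmdStan-based interfaces, the CmdStan binaries also need to be downloaded and built once; with CmdStanPy this can be done from Python (a minimal sketch):

import cmdstanpy
cmdstanpy.install_cmdstan()  # downloads and builds CmdStan using the local C++ toolchain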

Stan process

Stan file basics

Stan code is divided up into specific blocks depending on usage - all of the following blocks are optional, but any that are present must appear in the order given below.

functions {
  // user-defined functions
}
data {
  // declares the required data for the model
}
transformed data {
   // allows the definition of constants and transformations of the data
}
parameters {
   // declares the model’s parameters
}
transformed parameters {
   // variables defined in terms of data and parameters
}
model {
   // defines the log probability function
}
generated quantities {
   // derived quantities based on parameters, data, and random number generation
}

GP model in Stan

      Lec20/gp.stan

data {
  int<lower=1> N;
  array[N] real x;
  vector[N] y;
}
parameters {
  real<lower=0> l;
  real<lower=0> s;
  real<lower=0> nug;
}
model {
  // Covariance
  matrix[N, N] K = gp_exp_quad_cov(x, s, l);
  K = add_diag(K, nug^2);
  matrix[N, N] L = cholesky_decompose(K);
  
  // priors
  l ~ gamma(2, 1);
  s ~ cauchy(0, 5);
  nug ~ cauchy(0, 1);
  
  // model
  y ~ multi_normal_cholesky(rep_vector(0, N), L);
}

Fit

from cmdstanpy import CmdStanModel
d_stan = d.to_dict('list')
d_stan["N"] = len(d["x"])
gp = CmdStanModel(stan_file='Lec20/gp.stan')
gp_fit = gp.sample(data=d_stan, show_progress=False)
gp_fit.summary()
            Mean      MCSE    StdDev       MAD          5%         50%         95%  ESS_bulk  ESS_tail    R_hat
lp__  416.696000  0.031684  1.228810  1.015580  414.302000  416.996000  418.034000   1529.35   2550.13  1.00266
l       0.211213  0.001178  0.041281  0.038427    0.150708    0.207788    0.283732   1302.66   1274.87  1.00239
s       0.568497  0.010401  0.334768  0.198640    0.265488    0.481437    1.158120   1270.61   1198.52  1.00352
nug     0.098575  0.000104  0.004480  0.004533    0.091598    0.098449    0.106188   1894.82   1899.83  1.00129
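
If needed, the CmdStanPy fit can be converted to an ArviZ InferenceData object for the usual diagnostics (az.plot_trace() on the next slide also accepts the fit object directly):

idata = az.from_cmdstanpy(posterior=gp_fit)
az.summary(idata, var_names=["l", "s", "nug"])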

Trace plots

ax = az.plot_trace(gp_fit, compact=False)
plt.show()

Diagnostics

gp_fit.divergences
array([0, 0, 0, 0])
gp_fit.max_treedepths
array([0, 0, 0, 0])
gp_fit.method_variables().keys()
dict_keys(['lp__', 'accept_stat__', 'stepsize__', 'treedepth__', 'n_leapfrog__', 'divergent__', 'energy__'])
print(gp_fit.diagnose())
Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.

Checking sampler transitions for divergences.
No divergent transitions found.

Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Rank-normalized split effective sample size satisfactory for all parameters.

Rank-normalized split R-hat values satisfactory for all parameters.

Processing complete, no problems detected.

nutpie & stan

The nutpie package can also be used to compile and run Stan models; it uses the bridgestan package to interface with Stan.

import nutpie
m = nutpie.compile_stan_model(filename="Lec20/gp.stan")
m = m.with_data(x=d["x"],y=d["y"],N=len(d["x"]))
gp_fit_nutpie = nutpie.sample(m, chains=4)
az.summary(gp_fit_nutpie)
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
l    0.211  0.042   0.140    0.292      0.001    0.001    1342.0    1219.0    1.0
s    0.564  0.333   0.205    1.100      0.011    0.022    1359.0    1136.0    1.0
nug  0.098  0.004   0.090    0.107      0.000    0.000    2353.0    2351.0    1.0

Trace plots

Performance

t_stan_gp = timeit.repeat(
    lambda: gp.sample(data=d_stan, show_progress=False),
    repeat=3, number=1
)
34.81s ± 0.68s (3 runs, 1 loop each)
t_stan_nutpie_gp = timeit.repeat(
    lambda: nutpie.sample(m, chains=4, progress_bar=False),
    repeat=3, number=1
)
14.83s ± 0.59s (3 runs, 1 loop each)

Posterior predictive model

      Lec20/gp2.stan

functions {
  // From https://mc-stan.org/docs/stan-users-guide/gaussian-processes.html#predictive-inference-with-a-gaussian-process
  vector gp_pred_rng(
    array[] real x2,
    vector y1,
    array[] real x1,
    real alpha,
    real rho,
    real sigma,
    real delta
  ) {
    int N1 = rows(y1);
    int N2 = size(x2);
    vector[N2] f2;
    {
      matrix[N1, N1] L_K;
      vector[N1] K_div_y1;
      matrix[N1, N2] k_x1_x2;
      matrix[N1, N2] v_pred;
      vector[N2] f2_mu;
      matrix[N2, N2] cov_f2;
      matrix[N2, N2] diag_delta;
      matrix[N1, N1] K;
      K = gp_exp_quad_cov(x1, alpha, rho);
      for (n in 1:N1) {
        K[n, n] = K[n, n] + square(sigma);
      }
      L_K = cholesky_decompose(K);
      K_div_y1 = mdivide_left_tri_low(L_K, y1);
      K_div_y1 = mdivide_right_tri_low(K_div_y1', L_K)';
      k_x1_x2 = gp_exp_quad_cov(x1, x2, alpha, rho);
      f2_mu = (k_x1_x2' * K_div_y1);
      v_pred = mdivide_left_tri_low(L_K, k_x1_x2);
      cov_f2 = gp_exp_quad_cov(x2, alpha, rho) - v_pred' * v_pred;
      diag_delta = diag_matrix(rep_vector(delta, N2));

      f2 = multi_normal_rng(f2_mu, cov_f2 + diag_delta);
    }
    return f2;
  }
}
data {
  int<lower=1> N;      // number of observations
  array[N] real x;         // univariate covariate
  vector[N] y;         // target variable
  int<lower=1> Np;     // number of test points
  array[Np] real xp;       // univariate test points
}
transformed data {
  real delta = 1e-9;
}
parameters {
  real<lower=0> l;
  real<lower=0> s;
  real<lower=0> nug;
}
model {
  // Covariance
  matrix[N, N] K = gp_exp_quad_cov(x, s, l);
  K = add_diag(K, nug^2);
  matrix[N, N] L = cholesky_decompose(K);
  
  // priors
  l ~ gamma(2, 1);
  s ~ cauchy(0, 5);
  nug ~ cauchy(0, 1);
  
  // model
  y ~ multi_normal_cholesky(rep_vector(0, N), L);
}
generated quantities {
  vector[Np] f = gp_pred_rng(xp, y, x, s, l, nug, delta);
}

Posterior predictive fit

d_stan = d.to_dict('list')
d_stan["N"] = len(d_stan["x"])
d_stan["xp"] = np.linspace(0, 2.2, 221)
d_stan["Np"] = len(d_stan["xp"])
gp2 = CmdStanModel(stan_file='Lec20/gp2.stan')
gp2_fit = gp2.sample(data=d_stan, show_progress=False)
gp2_fit.summary()
              Mean      MCSE    StdDev       MAD          5%         50%         95%  ESS_bulk  ESS_tail     R_hat
lp__    416.632000  0.035060  1.311560  1.051900  414.023000  416.983000  418.035000   1515.66   1890.34  1.000970
l         0.211678  0.000984  0.040204  0.037695    0.150728    0.208051    0.285263   1729.96   1521.48  1.000200
s         0.564591  0.007880  0.311079  0.196101    0.268283    0.481254    1.135830   1788.47   1577.89  1.000080
nug       0.098698  0.000110  0.004787  0.004621    0.091180    0.098506    0.107040   1935.62   1578.61  1.000380
f[1]      0.134059  0.000817  0.048615  0.049339    0.052479    0.135168    0.212925   3570.28   3764.86  0.999776
...            ...       ...       ...       ...         ...         ...         ...       ...       ...       ...
f[217]    0.089114  0.004508  0.281242  0.268842   -0.350164    0.083123    0.573051   3903.52   3875.60  1.001710
f[218]    0.093848  0.004769  0.297292  0.282442   -0.372340    0.085499    0.605130   3895.54   3896.30  1.001680
f[219]    0.097935  0.005028  0.313258  0.296261   -0.393446    0.088212    0.643016   3885.86   3916.85  1.001640
f[220]    0.101391  0.005284  0.329082  0.310729   -0.412781    0.090177    0.666344   3877.83   3877.07  1.001600
f[221]    0.104235  0.005536  0.344706  0.321935   -0.429439    0.091224    0.689960   3871.86   3937.23  1.001550

[225 rows x 10 columns]

Draws

gp2_fit.stan_variable("f").shape
(4000, 221)
np.mean(gp2_fit.stan_variable("f"), axis=0)
array([ 0.13406,  0.19034,  0.24532,  0.29852,  0.34948,  0.39774,  0.44286,  0.48445,  0.52213,  0.55558,  0.5845 ,  0.60865,  0.62782,  0.64189,  0.65074,  0.65435,  0.65272,  0.64592,  0.63406,
        0.6173 ,  0.59586,  0.56997,  0.53995,  0.5061 ,  0.4688 ,  0.42842,  0.38539,  0.34013,  0.29308,  0.2447 ,  0.19545,  0.14579,  0.09617,  0.04704, -0.00117, -0.04803, -0.09315, -0.13615,
       -0.1767 , -0.21446, -0.24917, -0.28057, -0.30847, -0.33267, -0.35307, -0.36957, -0.38213, -0.39073, -0.39543, -0.39628, -0.39339, -0.38691, -0.37701, -0.36389, -0.34778, -0.32892, -0.30758,
       -0.28403, -0.25857, -0.23149, -0.20309, -0.17367, -0.14353, -0.11296, -0.08225, -0.05166, -0.02146,  0.00811,  0.03682,  0.06447,  0.09085,  0.11581,  0.13918,  0.16084,  0.18068,  0.1986 ,
        0.21452,  0.2284 ,  0.24019,  0.24987,  0.25743,  0.26288,  0.26622,  0.26749,  0.26674,  0.26399,  0.25932,  0.25278,  0.24444,  0.23438,  0.22269,  0.20945,  0.19476,  0.17871,  0.16141,
        0.14299,  0.12354,  0.10321,  0.08212,  0.06042,  0.03825,  0.01577, -0.00685, -0.02943, -0.0518 , -0.07376, -0.09512, -0.11567, -0.1352 , -0.15351, -0.17039, -0.18563, -0.19904, -0.21044,
       -0.21965, -0.22654, -0.23096, -0.23283, -0.23208, -0.22867, -0.22262, -0.21398, -0.20282, -0.18927, -0.17349, -0.15569, -0.1361 , -0.11497, -0.0926 , -0.06927, -0.04531, -0.02104,  0.00324,
        0.02721,  0.05058,  0.07306,  0.09442,  0.11442,  0.13288,  0.14963,  0.16454,  0.17754,  0.18854,  0.19755,  0.20455,  0.20958,  0.21269,  0.21397,  0.2135 ,  0.21139,  0.20776,  0.20272,
        0.19641,  0.18895,  0.18047,  0.17107,  0.16089,  0.15002,  0.13856,  0.12662,  0.11428,  0.10161,  0.0887 ,  0.07561,  0.06242,  0.04917,  0.03594,  0.02278,  0.00974, -0.00312, -0.01574,
       -0.02806, -0.04004, -0.0516 , -0.06269, -0.07325, -0.08323, -0.09256, -0.10118, -0.10903, -0.11607, -0.12224, -0.1275 , -0.13178, -0.13507, -0.13733, -0.13854, -0.13869, -0.13777, -0.13578,
       -0.13276, -0.12872, -0.1237 , -0.11775, -0.11094, -0.10332, -0.09497, -0.08598, -0.07643, -0.06641, -0.05603, -0.04538, -0.03456, -0.02366, -0.01279, -0.00202,  0.00854,  0.01882,  0.02875,
        0.03826,  0.04728,  0.05577,  0.06369,  0.071  ,  0.07769,  0.08373,  0.08911,  0.09385,  0.09793,  0.10139,  0.10423])

Plot
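
The figure itself is not reproduced here; a sketch of how it could be drawn from the draws above (assumes matplotlib / numpy are already imported):

f = gp2_fit.stan_variable("f")  # (4000, 221) draws of the latent GP at the test points xp

plt.plot(d.x, d.y, ".", alpha=0.3)
plt.plot(d_stan["xp"], f.mean(axis=0))
plt.fill_between(
  d_stan["xp"],
  np.quantile(f, 0.05, axis=0),
  np.quantile(f, 0.95, axis=0),
  alpha=0.25
)
plt.show()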

Tennis Model Performance

The Data

  • Jeff Sackmann’s ATP tennis dataset — match records from 1968 to present
  • Each match record contains winning and losing player names; these are encoded as integer IDs
  • Goal is to estimate latent player skill levels from match outcomes using a hierarchical Bradley-Terry model
    • \(\text{logit} \, P(\text{player } i \text{ beats player } j) = \text{skill}_i - \text{skill}_j\)
  • Benchmark datasets created by filtering from different start_year values (1970, 1980, 1990, 2000, 2010, 2020), giving datasets of increasing size

Testing Setup

All benchmarks were run on a departmental server with no CPU or GPU constraints:

  • CPU - AMD Ryzen 9 7950X (16-core / 32-thread)
  • GPU - 2× NVIDIA RTX A4000 (16 GB VRAM each)
  • Each sampler run with 2 chains, 1000 draws, 1000 tuning steps
  • JAX-based samplers (numpyro, blackjax) benchmarked in both cpu_parallel and gpu_parallel modes
  • nutpie benchmarked with both PyMC and Stan compiled models
  • Timing measured as wall-clock seconds per sampling run (including compile time)
  • All CPU and GPU results are using 64-bit precision

The Models

PyMC

with pm.Model() as model:
    player_sd = pm.HalfNormal("player_sd", sigma=1.0)

    player_skills_raw = pm.Normal(
        "player_skills_raw", 0., sigma=1.,
        shape=(n_players,)
    )
    player_skills = pm.Deterministic(
        "player_skills", player_skills_raw * player_sd
    )
    logit_p = player_skills[winner_ids] - player_skills[loser_ids]

    win_lik = pm.Bernoulli(
        "win_lik", logit_p=logit_p,
        observed=np.ones(n_matches)
    )
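
The exact benchmarking harness is not shown here; a hypothetical sketch consistent with the Testing Setup above (2 chains, 1000 draws, 1000 tuning steps) might look like:

import time

timings = {}
for backend in ["pymc", "blackjax", "numpyro", "nutpie"]:
    start = time.time()
    with model:
        pm.sample(
            nuts_sampler=backend, chains=2,
            draws=1000, tune=1000, progressbar=False
        )
    timings[backend] = time.time() - start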

Stan

parameters {
    vector[n_players] player_skills_raw;
    real<lower=0> player_sd;
}
transformed parameters {
    vector[n_players] player_skills =
        player_skills_raw * player_sd;
}
model {
    player_skills_raw ~ std_normal();
    player_sd ~ normal(0, 1);
    vector[n_matches] mu;
    for (n in 1:n_matches)
        mu[n] = player_skills[winner_ids[n]]
              - player_skills[loser_ids[n]];
    1 ~ bernoulli_logit(mu);
}
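
The snippet above omits the data block; a minimal version consistent with the variable names used (a sketch, not the lecture’s actual file) would be:

data {
    int<lower=1> n_players;
    int<lower=1> n_matches;
    array[n_matches] int<lower=1, upper=n_players> winner_ids;
    array[n_matches] int<lower=1, upper=n_players> loser_ids;
}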

Results

Results (log scale)

Results (relative to NumPyro GPU)