PyMC version: 5.28.1
Lecture 18
PyMC is a probabilistic programming library for Python that allows users to build Bayesian models with a simple Python API and fit them using Markov chain Monte Carlo (MCMC) methods.
Modern - Includes state-of-the-art inference algorithms, including MCMC (NUTS) and variational inference (ADVI).
User friendly - Write your models using friendly Python syntax. Learn Bayesian modeling from the many example notebooks.
Fast - Uses PyTensor as its computational backend to compile to C and JAX, run your models on the GPU, and benefit from complex graph optimizations.
Batteries included - Includes probability distributions, Gaussian processes, ABC, SMC and much more. It integrates nicely with ArviZ for visualizations and diagnostics, as well as Bambi for high-level mixed-effect models.
Community focused - Ask questions on Discourse, join Meetup events, follow us on Twitter, and start contributing.
ArviZ is a Python package for exploratory analysis of Bayesian models. It includes functions for posterior analysis, data storage, sample diagnostics, model checking, and comparison.
Interoperability - Integrates with all major probabilistic programming libraries: PyMC, CmdStanPy, PyStan, Pyro, NumPyro, and emcee.
Large Suite of Visualizations - Provides over 25 plotting functions for all parts of Bayesian workflow: visualizing distributions, diagnostics, and model checking. See the gallery for examples.
State of the Art Diagnostics - Latest published diagnostics and statistics are implemented, tested and distributed with ArviZ.
Flexible Model Comparison - Includes functions for comparing models with information criteria and cross-validation (both approximate and brute force).
Built for Collaboration - Designed for flexible cross-language serialization using netCDF or Zarr formats. ArviZ also has a Julia version that uses the same data schema.
Labeled Data - Builds on top of xarray to work with labeled dimensions and coordinates.
All models are built on PyMC's Model class (pm.Model). Unlike the libraries we have seen previously, PyMC makes heavy use of Python's context managers: components are added to a model inside a with block.
PyMC requires a model context to register components - creating a distribution outside a with block raises an error:
PyMC random variables are implemented as PyTensor TensorVariable objects. Some useful attributes and functions:
[RNG(<Generator(PCG64) at 0x10F407920>), Constant(<pytensor.tensor.type_other.NoneTypeT object at 0x145527380>, data=None), TensorConstant(TensorType(int8, shape=()), data=array(0, dtype=int8)), TensorConstant(TensorType(int8, shape=()), data=array(1, dtype=int8))]
This context construction makes it possible to add further components to an existing model that has been bound to a name (e.g. norm) by entering it again with a later with statement.
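The model-context mechanics can be sketched in plain Python. ToyModel and toy_variable below are hypothetical names and are not PyMC's implementation. They only mimic the behaviour described above:
- variables register with whichever model is currently active;
- creating a variable with no active model raises an error;
- re-entering a named model adds more components to it.

```python
class ToyModel:
    """Hypothetical stand-in for pm.Model: collects variables created inside `with`."""

    _stack = []  # models currently entered via `with`, innermost last

    def __init__(self, name=""):
        self.name = name
        self.variables = {}

    def __enter__(self):
        ToyModel._stack.append(self)
        return self

    def __exit__(self, exc_type, exc, tb):
        ToyModel._stack.pop()
        return False  # never swallow exceptions raised inside the block

    @classmethod
    def get_context(cls):
        if not cls._stack:
            raise TypeError("No model on context stack")
        return cls._stack[-1]


def toy_variable(name, value):
    """Hypothetical analogue of pm.Normal(name, ...): registers with the active model."""
    model = ToyModel.get_context()
    model.variables[name] = value
    return value


with ToyModel("norm") as norm:
    toy_variable("x", 0.0)

# Re-entering the same (named) model adds further components to it
with norm:
    toy_variable("y", 1.0)

print(sorted(norm.variables))  # ['x', 'y']

try:
    toy_variable("z", 2.0)  # no active model, so this raises
except TypeError as err:
    print("error:", err)
```

A module-level stack is what lets nested or repeated with blocks resolve to the innermost active model.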
Note that we defined \(y|x \sim \mathcal{N}(x, 1)\), so what is happening when we use pm.draw(norm.y)?
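Each time we ask for a draw of y, PyMC first draws a new value of x and then draws y conditional on that value (ancestral sampling). The resulting draws of y therefore come from the marginal distribution of y, not from \(\mathcal{N}(x, 1)\) for a single fixed x.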
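A plain-Python sketch of this ancestral sampling follows. The prior on x is not restated here, so x ~ N(0, 1) is an assumption. Under that assumption, x is redrawn every time, so y = x + noise has marginal distribution N(0, 2): its variance combines the spread of x with the unit noise.

```python
import random
import statistics


def draw_y(rng):
    """One ancestral draw: sample x from its (assumed) prior, then y given that x."""
    x = rng.gauss(0.0, 1.0)   # assumption: x ~ N(0, 1)
    return rng.gauss(x, 1.0)  # y | x ~ N(x, 1)


rng = random.Random(42)
ys = [draw_y(rng) for _ in range(100_000)]

# Sample mean should be near 0 and sample variance near 1 + 1 = 2
print(round(statistics.fmean(ys), 2), round(statistics.variance(ys), 2))
```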
We will now build a basic model where we know what the solution should look like and compare the results.
Note the use of observed: it fixes the variable's value to the provided data (i.e. the variable is not sampled). This is how we condition a model on observed data.
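For a model of this kind, the "known solution" can come from conjugacy. The sketch below assumes a Beta prior on p and a binomial likelihood. The observed_data group in the output records a single observed count of 5, but the number of trials and the prior are not shown. The values n = 10 and Beta(1, 1) used here are therefore placeholders. The mean and standard deviation of the MCMC draws of p can be compared against these closed-form values.

```python
def beta_binomial_posterior(a, b, k, n):
    """Conjugate update: Beta(a, b) prior with k successes in n trials -> Beta(a + k, b + n - k)."""
    return a + k, b + (n - k)


def beta_mean(a, b):
    return a / (a + b)


def beta_var(a, b):
    return a * b / ((a + b) ** 2 * (a + b + 1))


# Hypothetical setup: flat Beta(1, 1) prior, k = 5 successes out of n = 10 trials
a_post, b_post = beta_binomial_posterior(1, 1, k=5, n=10)
print(a_post, b_post, beta_mean(a_post, b_post), beta_var(a_post, b_post) ** 0.5)
```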
pm.sample() results (an ArviZ InferenceData object):
posterior
<xarray.Dataset> Size: 40kB
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
p (chain, draw) float64 32kB 0.3347 0.3435 0.2629 ... 0.3276 0.3486
Attributes:
created_at: 2026-03-17T13:49:16.801350+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 0.19043397903442383
tuning_steps: 1000
sample_stats
<xarray.Dataset> Size: 528kB
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables: (12/18)
index_in_trajectory (chain, draw) int64 32kB 3 -1 1 1 -1 ... -1 -1 2 0 -1
max_energy_error (chain, draw) float64 32kB -0.5154 ... -0.06161
lp (chain, draw) float64 32kB -3.221 -3.164 ... -3.138
tree_depth (chain, draw) int64 32kB 2 1 2 1 1 1 ... 1 1 1 2 1 1
divergences (chain, draw) int64 32kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
step_size_bar (chain, draw) float64 32kB 1.372 1.372 ... 1.348
... ...
perf_counter_diff (chain, draw) float64 32kB 5.317e-05 ... 2.275e-05
diverging (chain, draw) bool 4kB False False ... False False
process_time_diff (chain, draw) float64 32kB 5.4e-05 ... 2.2e-05
smallest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan
reached_max_treedepth (chain, draw) bool 4kB False False ... False False
energy_error (chain, draw) float64 32kB -0.4137 ... -0.06161
Attributes:
created_at: 2026-03-17T13:49:16.806783+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 0.19043397903442383
tuning_steps: 1000
observed_data
<xarray.Dataset> Size: 16B
Dimensions: (x_dim_0: 1)
Coordinates:
* x_dim_0 (x_dim_0) int64 8B 0
Data variables:
x (x_dim_0) int64 8B 5
Attributes:
created_at: 2026-03-17T13:49:16.808502+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
Xarray makes working with labelled multi-dimensional arrays in Python simple, efficient, and fun!
Xarray introduces labels in the form of dimensions, coordinates and attributes on top of raw NumPy-like arrays, which allows for a more intuitive, more concise, and less error-prone developer experience. The package includes a large and growing library of domain-agnostic functions for advanced analytics and visualization with these data structures.
Xarray is inspired by and borrows heavily from pandas, the popular data analysis package focused on labelled tabular data. It integrates tightly with dask for parallel computing.
trace
<xarray.Dataset> Size: 40kB
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
p (chain, draw) float64 32kB 0.3347 0.3435 0.2629 ... 0.3276 0.3486
Attributes:
created_at: 2026-03-17T13:49:16.801350+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 0.19043397903442383
tuning_steps: 1000
Posterior values, or subsets of them, can be converted to pandas DataFrames via the to_dataframe() method.
Standard MCMC diagnostic statistics are available via ArviZ's summary() function (az.summary()).
Individual functions are also available for each statistic, such as az.rhat(), az.ess() and az.mcse().
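To see what the r_hat column measures, here is a minimal from-scratch sketch of the classic split-R-hat. ArviZ's az.rhat() defaults to a rank-normalized version of this statistic, so its numbers will differ slightly. Values near 1 mean the chains agree with each other. Chains stuck in different regions, simulated below, push the statistic well above 1. This is the kind of failure the r_hat values near 3 in the first 100-coefficient summary indicate.

```python
import random
import statistics


def split_rhat(chains):
    """Classic split-R-hat for a list of equal-length chains.

    Not rank-normalized, so it will not match ArviZ's default output exactly.
    """
    # Split every chain in half so drift within a chain also inflates R-hat
    halves = []
    for chain in chains:
        half = len(chain) // 2
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    n = len(halves[0])
    # W: mean within-chain variance; B: n times the variance of the chain means
    w = statistics.fmean(statistics.variance(h) for h in halves)
    b = n * statistics.variance([statistics.fmean(h) for h in halves])
    var_plus = (n - 1) / n * w + b / n
    return (var_plus / w) ** 0.5


rng = random.Random(0)
mixed = [[rng.gauss(0, 1) for _ in range(1000)] for _ in range(4)]
stuck = [[rng.gauss(mu, 1) for _ in range(1000)] for mu in (0, 0, 5, 5)]
print(split_rhat(mixed))  # close to 1.00
print(split_rhat(stuck))  # far above 1.01
```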
We want to fit a linear regression model to the following synthetic data.
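As a reference point for the posterior, a least-squares fit can be computed in closed form. The data below are hypothetical: 11 evenly spaced x values on [0, 1] (matching the 11 observations in the outputs), scattered around the true line 6x + 2 with unit Gaussian noise. The lecture's noise level is not stated. With weakly informative priors such as Normal(0, 50), the posterior means of m and b should land close to these estimates.

```python
import random


def ols_fit(x, y):
    """Closed-form least-squares slope and intercept for simple linear regression."""
    n = len(x)
    x_bar = sum(x) / n
    y_bar = sum(y) / n
    sxy = sum((xi - x_bar) * (yi - y_bar) for xi, yi in zip(x, y))
    sxx = sum((xi - x_bar) ** 2 for xi in x)
    m = sxy / sxx
    return m, y_bar - m * x_bar


# Hypothetical synthetic data: y = 6x + 2 + N(0, 1) noise at 11 points on [0, 1]
rng = random.Random(1234)
x = [i / 10 for i in range(11)]
y = [6 * xi + 2 + rng.gauss(0, 1) for xi in x]
m_hat, b_hat = ols_fit(x, y)
print(f"m = {m_hat:.3f}, b = {b_hat:.3f}")
```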
import matplotlib.pyplot as plt
# Thin chain 0 to every 10th draw to keep the plot readable
post_m = trace.posterior['m'].sel(chain=0, draw=slice(0, None, 10))
post_b = trace.posterior['b'].sel(chain=0, draw=slice(0, None, 10))
plt.figure(layout="constrained")
plt.scatter(x, y, s=30, label='data')
# One faint line per retained posterior draw of (m, b)
for m, b in zip(post_m.values, post_b.values):
    plt.plot(x, m*x + b, c='gray', alpha=0.1)
plt.plot(x, 6*x + 2, label='true regression line', lw=3., c='red')
plt.legend(loc='best')
plt.show()
Draws for observed variables (posterior predictive draws) can also be generated via the pm.sample_posterior_predictive() function.
posterior_predictive
<xarray.Dataset> Size: 360kB
Dimensions: (chain: 4, draw: 1000, y_dim_0: 11)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* y_dim_0 (y_dim_0) int64 88B 0 1 2 3 4 5 6 7 8 9 10
Data variables:
y (chain, draw, y_dim_0) float64 352kB 1.91 0.5098 ... 7.131 9.175
Attributes:
created_at: 2026-03-17T13:49:20.418944+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
observed_data
<xarray.Dataset> Size: 176B
Dimensions: (y_dim_0: 11)
Coordinates:
* y_dim_0 (y_dim_0) int64 88B 0 1 2 3 4 5 6 7 8 9 10
Data variables:
y (y_dim_0) float64 88B 2.471 1.409 4.633 3.487 ... 6.816 5.157 9.15
Attributes:
created_at: 2026-03-17T13:49:20.420160+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
Observed y values: [2.47144, 1.40902, 4.63271, 3.48735, 3.67941, 5.88716, 6.45959, 5.56348, 6.8157, 5.15732, 9.15004]
By wrapping m*x + b in pm.Deterministic(), we can track this derived quantity in the trace and generate posterior predictive draws for the mean function (rather than just the observation-level predictions).
import pymc as pm
with pm.Model() as lm2:
    m = pm.Normal('m', mu=0, sigma=50)
    b = pm.Normal('b', mu=0, sigma=50)
    sigma = pm.HalfNormal('sigma', sigma=5)
    # Deterministic records the mean function m*x + b in the trace as "y_hat"
    y_hat = pm.Deterministic("y_hat", m*x + b)
    likelihood = pm.Normal('y', mu=y_hat, sigma=sigma, observed=y)
    trace = pm.sample(random_seed=1234, progressbar=False)
    pp = pm.sample_posterior_predictive(
        trace, var_names=["y_hat"], progressbar=False
    )
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
m 5.629 1.329 3.145 8.093 0.035 0.029 1418.0 1755.0 1.00
b 2.156 0.775 0.670 3.572 0.020 0.018 1458.0 1568.0 1.00
sigma 1.364 0.405 0.799 2.116 0.013 0.017 1192.0 1541.0 1.01
y_hat[0] 2.156 0.775 0.670 3.572 0.020 0.018 1458.0 1568.0 1.00
y_hat[1] 2.719 0.667 1.499 3.982 0.017 0.015 1544.0 1633.0 1.00
y_hat[2] 3.282 0.568 2.180 4.291 0.014 0.012 1712.0 1919.0 1.00
y_hat[3] 3.845 0.487 2.914 4.743 0.011 0.010 2067.0 2144.0 1.00
y_hat[4] 4.408 0.432 3.566 5.204 0.008 0.008 2869.0 2429.0 1.00
y_hat[5] 4.971 0.414 4.155 5.741 0.006 0.008 4130.0 2799.0 1.00
y_hat[6] 5.534 0.438 4.719 6.382 0.007 0.009 4394.0 2752.0 1.00
y_hat[7] 6.097 0.497 5.170 7.061 0.008 0.011 3573.0 2550.0 1.00
y_hat[8] 6.660 0.581 5.615 7.809 0.011 0.012 2821.0 2475.0 1.00
y_hat[9] 7.222 0.681 6.026 8.583 0.014 0.014 2372.0 2506.0 1.00
y_hat[10] 7.785 0.791 6.289 9.284 0.017 0.016 2113.0 2517.0 1.00
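The two kinds of posterior predictive draw can be contrasted without PyMC. For each posterior draw of (m, b, sigma), the mean function is m*x + b, which is what the y_hat Deterministic records. An observation-level draw adds Normal noise with scale sigma on top of it. The "posterior draws" below are hypothetical stand-ins loosely scaled to the summary table above, not the lecture's trace. They are also drawn independently, whereas real posterior draws of m and b are correlated.

```python
import random
import statistics


def predictive_draws(posterior, x, rng):
    """Return (mean-function draws, observation-level draws), one row per posterior draw."""
    mean_rows, obs_rows = [], []
    for m, b, sigma in posterior:
        mu = [m * xi + b for xi in x]                          # like y_hat
        mean_rows.append(mu)
        obs_rows.append([rng.gauss(mi, sigma) for mi in mu])  # like observed y
    return mean_rows, obs_rows


rng = random.Random(7)
x = [i / 10 for i in range(11)]
# Hypothetical posterior draws, roughly matching the means/sds in the table above
posterior = [
    (rng.gauss(5.6, 1.3), rng.gauss(2.2, 0.8), abs(rng.gauss(1.4, 0.4)))
    for _ in range(4000)
]
mean_rows, obs_rows = predictive_draws(posterior, x, rng)

# At x = 0.5 (index 5) the observation-level spread also includes sigma, so it is wider
sd_mean = statistics.stdev(row[5] for row in mean_rows)
sd_obs = statistics.stdev(row[5] for row in obs_rows)
print(round(sd_mean, 2), round(sd_obs, 2))
```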
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
beta[0] -230.110 1051.471 -2122.764 1588.034 506.354 264.540 5.0 17.0 3.11
beta[1] 189.346 1534.946 -1467.009 3398.762 731.692 624.755 5.0 11.0 3.08
beta[2] -297.668 457.031 -1001.502 479.395 215.023 73.083 5.0 11.0 2.19
beta[3] -358.025 1054.157 -1871.283 1643.062 511.156 275.699 5.0 11.0 2.79
beta[4] -460.791 658.009 -1687.435 373.611 303.674 170.469 5.0 11.0 2.58
... ... ... ... ... ... ... ... ... ...
beta[96] -1016.168 686.708 -2248.531 -145.503 330.527 140.027 5.0 13.0 2.98
beta[97] -221.856 482.609 -1240.211 619.665 208.244 117.138 5.0 14.0 2.06
beta[98] -192.017 716.281 -1649.817 779.600 337.012 188.262 5.0 11.0 2.75
beta[99] -671.608 833.620 -2339.851 328.992 381.139 145.968 5.0 25.0 2.32
sigma 2.115 1.304 0.202 4.396 0.510 0.142 7.0 17.0 1.68
[101 rows x 9 columns]
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
beta[0] 0.585 7.763 -13.738 15.280 0.103 0.125 5836.0 3052.0 1.00
beta[1] 1.061 6.676 -11.240 13.635 0.094 0.120 5007.0 2739.0 1.00
beta[2] -0.271 7.631 -14.916 13.594 0.101 0.151 5684.0 2330.0 1.00
beta[3] -1.331 7.216 -15.468 11.740 0.112 0.135 4229.0 1168.0 1.00
beta[4] 0.772 7.525 -13.477 14.330 0.098 0.123 5841.0 3134.0 1.00
... ... ... ... ... ... ... ... ... ...
beta[96] -0.111 6.579 -12.839 12.116 0.100 0.106 4317.0 2591.0 1.00
beta[97] -0.874 6.673 -13.067 11.592 0.094 0.105 4989.0 2864.0 1.00
beta[98] 1.392 6.728 -10.798 14.804 0.109 0.117 3787.0 2665.0 1.00
beta[99] -1.180 6.845 -14.337 11.251 0.093 0.119 5449.0 2653.0 1.00
sigma 2.145 1.083 0.606 4.136 0.128 0.058 64.0 98.0 1.07
[101 rows x 9 columns]
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
beta[10] 4.042 6.840 -9.114 16.509 0.102 0.127 4538.0 2360.0 1.0
beta[20] -4.228 7.184 -17.563 9.451 0.110 0.131 4258.0 1975.0 1.0
beta[30] 5.523 6.625 -6.166 18.190 0.128 0.134 2738.0 1067.0 1.0
beta[40] -4.958 8.103 -20.187 10.227 0.111 0.159 5370.0 2778.0 1.0
beta[50] 5.478 6.546 -6.357 17.987 0.094 0.102 4855.0 2791.0 1.0
beta[60] -5.773 6.937 -18.517 7.836 0.105 0.115 4403.0 2725.0 1.0
beta[70] 4.751 7.121 -8.624 18.226 0.094 0.120 5739.0 2989.0 1.0
beta[80] -7.812 6.105 -19.841 2.702 0.098 0.102 3869.0 2779.0 1.0
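import matplotlib.pyplot as plt
import seaborn as sns
def plot_slope(trace, prior="beta", chain=0):
    # Long-format DataFrame: one row per (chain, draw, coefficient index)
    post = (trace.posterior[prior]
            .to_dataframe()
            .reset_index()
            .query(f"chain == {chain}")
    )
    # Use the variable name passed in, rather than hard-coding "beta"
    sns.catplot(
        x=f"{prior}_dim_0", y=prior, data=post,
        kind="boxen", linewidth=0, color='blue',
        aspect=2, showfliers=False
    )
    plt.tight_layout()
    plt.xticks(range(0,110,10))
    plt.show()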
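Using a Laplace distribution as the prior on the coefficients is the Bayesian analogue of L1 (Lasso) regularization: for a fixed noise scale, the MAP estimate coincides with the frequentist Lasso solution, with a penalty strength set by the Laplace scale and the noise variance.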
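The correspondence follows from writing out the negative log posterior. Assume a Gaussian likelihood with known noise scale \(\sigma\) and independent \(\text{Laplace}(0, s)\) priors on each coefficient \(\beta_j\); the scale \(s\) is notation introduced here:

\[
-\log p(\beta \mid y) = \frac{1}{2\sigma^2}\lVert y - X\beta \rVert_2^2 + \frac{1}{s}\lVert \beta \rVert_1 + \text{const.}
\]

Multiplying by \(2\sigma^2\) does not change the minimizer, so the MAP estimate solves the Lasso problem

\[
\hat\beta_{\text{MAP}} = \arg\min_{\beta}\; \lVert y - X\beta \rVert_2^2 + \lambda \lVert \beta \rVert_1, \qquad \lambda = \frac{2\sigma^2}{s}.
\]

The same argument with \(\mathcal{N}(0, s^2)\) priors gives the ridge penalty \(\lambda \lVert \beta \rVert_2^2\) with \(\lambda = \sigma^2 / s^2\). A smaller prior scale therefore means stronger shrinkage. Only the L1 penalty can set coefficients exactly to zero at the MAP.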
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
beta[0] 0.133 0.768 -1.636 1.532 0.019 0.025 757.0 1910.0 1.02
beta[1] 0.289 0.715 -1.072 1.757 0.032 0.019 430.0 1677.0 1.01
beta[2] -0.105 0.774 -1.857 1.275 0.013 0.026 3599.0 1912.0 1.14
beta[3] -0.355 0.764 -1.776 1.141 0.071 0.020 93.0 1647.0 1.03
beta[4] 0.098 0.776 -1.422 1.636 0.018 0.041 1729.0 1493.0 1.14
... ... ... ... ... ... ... ... ... ...
beta[96] 0.031 0.670 -1.232 1.421 0.014 0.039 2116.0 1987.0 1.15
beta[97] -0.088 0.663 -1.468 1.120 0.020 0.027 1090.0 427.0 1.02
beta[98] 0.422 0.782 -1.020 1.827 0.090 0.020 83.0 480.0 1.04
beta[99] -0.399 0.721 -1.686 1.015 0.031 0.025 492.0 2385.0 1.02
sigma 0.746 0.450 0.295 1.585 0.080 0.048 17.0 8.0 1.17
[101 rows x 9 columns]
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
beta[10] 7.986 1.447 5.751 10.352 0.317 0.144 27.0 9.0 1.09
beta[20] -7.990 1.544 -10.314 -5.499 0.369 0.177 23.0 578.0 1.12
beta[30] 8.706 0.910 6.931 10.416 0.033 0.033 948.0 2160.0 1.04
beta[40] -9.160 1.743 -11.683 -5.894 0.307 0.031 41.0 242.0 1.07
beta[50] 8.727 1.231 6.742 10.640 0.282 0.151 26.0 9.0 1.10
beta[60] -9.199 1.079 -11.193 -7.066 0.039 0.038 795.0 2021.0 1.03
beta[70] 8.503 1.094 6.439 10.509 0.041 0.046 896.0 2280.0 1.03
beta[80] -9.800 0.854 -11.451 -8.204 0.061 0.024 237.0 1933.0 1.02
This example is based on the PyMC "Out-Of-Sample Predictions" example notebook.
x1 x2 y
0 -3.207674 0.859021 0
1 0.128200 2.827588 0
2 1.481783 -0.116956 0
3 0.305238 -1.378604 0
4 1.727488 -0.926357 1
.. ... ... ..
245 -2.182813 3.314672 0
246 -2.362568 2.078652 0
247 0.114571 2.249021 0
248 2.093975 -1.212528 1
249 1.241667 -2.363412 0
[250 rows x 3 columns]
import numpy as np
import patsy
import pymc as pm
from sklearn.model_selection import train_test_split
# Design matrices for y ~ x1 + x2 + x1:x2 (patsy adds an Intercept column)
y, X = patsy.dmatrices("y ~ x1 * x2", data=df)
X_lab = X.design_info.column_names
y_lab = y.design_info.column_names
y = np.asarray(y).flatten()
X = np.asarray(X)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.7, random_state=1234
)
with pm.Model(coords={"coeffs": X_lab}) as model:
    # data containers
    X = pm.Data("X", X_train)
    y = pm.Data("y", y_train)
    # priors
    b = pm.Normal("b", mu=0, sigma=3, dims="coeffs")
    # linear model
    mu = X @ b
    # link function
    p = pm.Deterministic("p", pm.math.invlogit(mu))
    # likelihood
    obs = pm.Bernoulli("obs", p=p, observed=y)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
b[Intercept] -0.600 0.307 -1.194 -0.050 0.007 0.005 2127.0 2351.0 1.0
b[x1] 1.477 0.358 0.831 2.170 0.008 0.006 1849.0 2268.0 1.0
b[x2] -1.201 0.304 -1.774 -0.637 0.007 0.005 1910.0 2048.0 1.0
b[x1:x2] 2.658 0.500 1.800 3.688 0.013 0.010 1632.0 2144.0 1.0
p[0] 0.224 0.074 0.084 0.354 0.002 0.001 1912.0 2306.0 1.0
... ... ... ... ... ... ... ... ... ...
p[170] 0.404 0.089 0.234 0.566 0.002 0.001 2496.0 2485.0 1.0
p[171] 0.014 0.019 0.000 0.043 0.000 0.001 1739.0 1815.0 1.0
p[172] 0.985 0.015 0.959 1.000 0.000 0.001 2046.0 1982.0 1.0
p[173] 0.545 0.067 0.414 0.660 0.001 0.001 3491.0 3288.0 1.0
p[174] 1.000 0.000 1.000 1.000 0.000 0.000 1702.0 2679.0 1.0
[179 rows x 9 columns]
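The link function can be made concrete by hand. The sketch below plugs the posterior means of b from the table above into the design row that patsy builds for y ~ x1 * x2 ([1, x1, x2, x1*x2]). It then applies the inverse logit. This plug-in value is only an approximation: because invlogit is nonlinear, it is not the posterior mean of p. The plotting code below computes that posterior mean by averaging p over all draws.

```python
import math

# Posterior means of b from the summary table above
B_MEAN = {"Intercept": -0.600, "x1": 1.477, "x2": -1.201, "x1:x2": 2.658}


def invlogit(z):
    """Numerically stable logistic function 1 / (1 + exp(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def plug_in_p(x1, x2, coef=B_MEAN):
    """P(y = 1) at (x1, x2) using point estimates of the coefficients."""
    row = {"Intercept": 1.0, "x1": x1, "x2": x2, "x1:x2": x1 * x2}
    mu = sum(coef[name] * row[name] for name in coef)
    return invlogit(mu)


print(round(plug_in_p(0.0, 0.0), 3))  # invlogit(-0.6), below 0.5
print(round(plug_in_p(1.0, 1.0), 3))  # the positive interaction pushes p above 0.5
```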
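import matplotlib.pyplot as plt
# Posterior mean of p for each training observation (average over chains and draws)
p_post = post.posterior["p"].mean(dim=["chain", "draw"])
fig, axes = plt.subplots(1, 2, figsize=(10, 4), layout="constrained")
# Columns 1 and 2 of the design matrix hold x1 and x2 (column 0 is the intercept)
sc = axes[0].scatter(X_train[:,1], X_train[:,2], c=p_post, cmap='coolwarm', vmin=0, vmax=1, s=15)
axes[0].set(xlabel="x1", ylabel="x2", title="Posterior mean p (train)")
cb = plt.colorbar(sc, ax=axes[0])
# Classify each point with a 0.5 threshold on its posterior mean probability
axes[1].scatter(X_train[:,1], X_train[:,2], c=(p_post > 0.5).astype(int), cmap='coolwarm', vmin=0, vmax=1, s=15)
axes[1].set(xlabel="x1", ylabel="x2", title="Predicted class (train)")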
plt.show()<xarray.Dataset> Size: 6MB
Dimensions: (chain: 4, draw: 1000, coeffs: 4, p_dim_0: 175)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* coeffs (coeffs) <U9 144B 'Intercept' 'x1' 'x2' 'x1:x2'
* p_dim_0 (p_dim_0) int64 1kB 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
Data variables:
b (chain, draw, coeffs) float64 128kB -0.05825 1.681 ... -1.83 3.949
p (chain, draw, p_dim_0) float64 6MB 0.2694 0.5574 1.0 ... 0.5946 1.0
Attributes:
created_at: 2026-03-17T13:49:27.035568+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 0.3888070583343506
tuning_steps: 1000array([0, 1, 2, 3])
array([ 0, 1, 2, ..., 997, 998, 999], shape=(1000,))
array(['Intercept', 'x1', 'x2', 'x1:x2'], dtype='<U9')
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
168, 169, 170, 171, 172, 173, 174])array([[[-0.05825, 1.68124, -1.05657, 2.45565],
[-0.66652, 0.81142, -1.08228, 2.31992],
[-0.33851, 1.39107, -0.95414, 2.96128],
...,
[-0.65008, 1.68385, -1.19153, 2.67779],
[-0.64458, 1.44224, -0.78403, 2.41496],
[-0.42461, 0.91177, -1.0866 , 2.43232]],
[[-0.67913, 1.84803, -1.40907, 2.91414],
[-0.32478, 1.96976, -1.41528, 2.89949],
[-0.62647, 1.44283, -0.99409, 2.6026 ],
...,
[-0.22938, 1.15244, -1.14299, 2.44078],
[-1.07189, 1.88163, -1.1652 , 2.89325],
[-1.05027, 1.81028, -1.42338, 3.02394]],
[[-0.16035, 1.4986 , -1.45435, 3.15169],
[-0.8516 , 0.92087, -0.83495, 2.02336],
[-0.13517, 1.82689, -1.33311, 2.91467],
...,
[-0.48101, 1.22307, -0.89896, 2.16073],
[-0.56095, 1.55998, -1.14908, 2.72059],
[-0.65067, 1.56294, -1.1283 , 3.47324]],
[[-0.05985, 1.15796, -1.50369, 2.67329],
[-1.58097, 1.78383, -0.97018, 2.47163],
[-0.74179, 1.34764, -1.45499, 3.13086],
...,
[-0.70375, 1.59326, -1.28051, 2.40135],
[-0.88113, 2.16069, -1.66928, 3.0906 ],
[-0.7924 , 2.05688, -1.83042, 3.94918]]], shape=(4, 1000, 4))array([[[0.26939, 0.55743, 1. , ..., 0.99523, 0.66339, 1. ],
[0.3015 , 0.72347, 1. , ..., 0.95597, 0.49159, 1. ],
[0.27789, 0.70815, 1. , ..., 0.99551, 0.55582, 1. ],
...,
[0.17754, 0.48494, 1. , ..., 0.99301, 0.53928, 1. ],
[0.19415, 0.46033, 1. , ..., 0.98944, 0.46679, 1. ],
[0.33849, 0.75972, 1. , ..., 0.97452, 0.55346, 1. ]],
[[0.16324, 0.50103, 1. , ..., 0.99528, 0.56986, 1. ],
[0.19929, 0.53569, 1. , ..., 0.99721, 0.65944, 1. ],
[0.20765, 0.5433 , 1. , ..., 0.99049, 0.50309, 1. ],
...,
[0.33623, 0.72802, 1. , ..., 0.98522, 0.6193 , 1. ],
[0.10876, 0.35145, 1. , ..., 0.99449, 0.43151, 1. ],
[0.12437, 0.45407, 1. , ..., 0.99371, 0.47554, 1. ]],
[[0.31959, 0.79845, 1. , ..., 0.9963 , 0.67931, 1. ],
[0.22899, 0.5362 , 1. , ..., 0.94592, 0.41512, 1. ],
[0.25263, 0.63147, 1. , ..., 0.99734, 0.68316, 1. ],
...,
[0.25305, 0.54616, 1. , ..., 0.97965, 0.52628, 1. ],
[0.20863, 0.56331, 1. , ..., 0.9929 , 0.54796, 1. ],
[0.21244, 0.71235, 1. , ..., 0.99743, 0.50126, 1. ]],
[[0.39759, 0.83148, 1. , ..., 0.98786, 0.70792, 1. ],
[0.06752, 0.18242, 1. , ..., 0.98349, 0.29095, 1. ],
[0.23069, 0.73655, 1. , ..., 0.99128, 0.53697, 1. ],
...,
[0.17751, 0.45333, 1. , ..., 0.9863 , 0.54559, 1. ],
[0.11465, 0.40406, 1. , ..., 0.99664, 0.57255, 1. ],
[0.15472, 0.69796, 1. , ..., 0.99878, 0.59458, 1. ]]],
shape=(4, 1000, 175))<xarray.Dataset> Size: 528kB
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables: (12/18)
index_in_trajectory (chain, draw) int64 32kB -1 3 4 5 2 ... -1 -6 -5 -2
max_energy_error (chain, draw) float64 32kB 0.7143 0.8379 ... -0.3487
lp (chain, draw) float64 32kB -54.98 -55.09 ... -55.64
tree_depth (chain, draw) int64 32kB 2 3 3 4 2 3 ... 3 3 4 3 3 2
divergences (chain, draw) int64 32kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
step_size_bar (chain, draw) float64 32kB 0.4975 0.4975 ... 0.5673
... ...
perf_counter_diff (chain, draw) float64 32kB 7.137e-05 ... 6.663e-05
diverging (chain, draw) bool 4kB False False ... False False
process_time_diff (chain, draw) float64 32kB 7.1e-05 ... 6.6e-05
smallest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan
reached_max_treedepth (chain, draw) bool 4kB False False ... False False
energy_error (chain, draw) float64 32kB -0.377 0.2598 ... -0.232
Attributes:
created_at: 2026-03-17T13:49:27.041467+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 0.3888070583343506
tuning_steps: 1000array([0, 1, 2, 3])
array([ 0, 1, 2, ..., 997, 998, 999], shape=(1000,))
array([[-1, 3, 4, ..., -3, 2, 3],
[-3, 1, 4, ..., 6, 5, 2],
[-3, -3, 4, ..., 3, -8, 1],
[ 6, -3, 2, ..., -6, -5, -2]], shape=(4, 1000))array([[ 0.71426, 0.83795, -0.83111, ..., 0.46019, 0.76972, 1.04768],
[-0.1623 , 1.01006, -0.8886 , ..., -0.68675, 0.78752, -0.20025],
[ 0.27642, 0.49482, 1.26843, ..., -0.02894, 0.0413 , 2.42641],
[ 0.54456, 0.69548, -0.50553, ..., 0.74353, 0.3579 , -0.34866]],
shape=(4, 1000))array([[-54.98048, -55.08988, -54.14337, ..., -52.27666, -53.26737, -54.04426],
[-52.79676, -54.86587, -52.27038, ..., -52.91864, -54.16457, -53.16204],
[-55.10906, -54.8415 , -54.83688, ..., -52.10868, -52.0458 , -55.81293],
[-56.0737 , -61.49927, -54.2691 , ..., -52.84386, -54.89327, -55.64337]],
shape=(4, 1000))array([[2, 3, 3, ..., 3, 3, 2],
[3, 3, 3, ..., 3, 3, 2],
[2, 4, 3, ..., 4, 4, 4],
[3, 3, 2, ..., 3, 3, 2]], shape=(4, 1000))array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], shape=(4, 1000))array([[0.4975 , 0.4975 , 0.4975 , ..., 0.4975 , 0.4975 , 0.4975 ],
[0.55867, 0.55867, 0.55867, ..., 0.55867, 0.55867, 0.55867],
[0.53781, 0.53781, 0.53781, ..., 0.53781, 0.53781, 0.53781],
[0.56734, 0.56734, 0.56734, ..., 0.56734, 0.56734, 0.56734]],
shape=(4, 1000))array([[ 3., 7., 7., ..., 7., 7., 3.],
[ 7., 7., 7., ..., 7., 7., 3.],
[ 3., 15., 7., ..., 15., 15., 11.],
[ 7., 7., 3., ..., 7., 7., 3.]], shape=(4, 1000))array([[0.76139, 0.87891, 0.91453, ..., 0.78181, 0.78129, 0.67325],
[0.99838, 0.70948, 0.97681, ..., 1. , 0.78863, 0.98459],
[0.8558 , 0.90322, 0.70386, ..., 0.99929, 0.98286, 0.47758],
[0.74635, 0.83899, 0.86285, ..., 0.74263, 0.89061, 0.99737]],
shape=(4, 1000))array([[0.58754, 0.58754, 0.58754, ..., 0.58754, 0.58754, 0.58754],
[0.58001, 0.58001, 0.58001, ..., 0.58001, 0.58001, 0.58001],
[0.58205, 0.58205, 0.58205, ..., 0.58205, 0.58205, 0.58205],
[0.59406, 0.59406, 0.59406, ..., 0.59406, 0.59406, 0.59406]],
shape=(4, 1000))array([[59.11869, 56.63231, 56.16869, ..., 53.28841, 54.81568, 55.42176],
[53.12271, 56.20037, 56.37673, ..., 55.99144, 55.28081, 54.64459],
[56.06345, 58.0166 , 57.42368, ..., 52.63637, 52.89784, 57.92707],
[57.75146, 62.71908, 62.59779, ..., 53.84075, 55.30436, 56.47349]],
shape=(4, 1000))array([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], shape=(4, 1000))array([[273481.21629, 273481.21639, 273481.21655, ..., 273481.38278,
273481.38294, 273481.3831 ],
[273481.22044, 273481.22073, 273481.22089, ..., 273481.37818,
273481.37834, 273481.37849],
[273481.22544, 273481.22554, 273481.22584, ..., 273481.39374,
273481.39402, 273481.39431],
[273481.22432, 273481.22449, 273481.22471, ..., 273481.386 ,
273481.38615, 273481.38631]], shape=(4, 1000))array([[0.00007, 0.00013, 0.00014, ..., 0.00014, 0.00013, 0.00007],
[0.00026, 0.00014, 0.00014, ..., 0.00014, 0.00013, 0.00007],
[0.00007, 0.00028, 0.00014, ..., 0.00026, 0.00026, 0.0002 ],
[0.00014, 0.00018, 0.00007, ..., 0.00013, 0.00013, 0.00007]],
shape=(4, 1000))array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]], shape=(4, 1000))array([[0.00007, 0.00013, 0.00013, ..., 0.00014, 0.00013, 0.00007],
[0.00025, 0.00014, 0.00014, ..., 0.00013, 0.00013, 0.00007],
[0.00007, 0.00028, 0.00014, ..., 0.00026, 0.00026, 0.0002 ],
[0.00014, 0.00018, 0.00007, ..., 0.00013, 0.00013, 0.00007]],
shape=(4, 1000))array([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], shape=(4, 1000))array([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]], shape=(4, 1000))array([[-0.37703, 0.25981, -0.11623, ..., 0.08085, 0.11853, 0.40197],
[-0.0515 , 0.49763, -0.53393, ..., -0.503 , 0.08932, -0.20025],
[ 0.19402, -0.09699, -0.15399, ..., -0.02229, 0.01047, 1.06569],
[ 0.52102, 0.69548, -0.50553, ..., 0.29807, 0.3579 , -0.23201]],
shape=(4, 1000))<xarray.Dataset> Size: 3kB
Dimensions: (obs_dim_0: 175)
Coordinates:
* obs_dim_0 (obs_dim_0) int64 1kB 0 1 2 3 4 5 6 ... 169 170 171 172 173 174
Data variables:
obs (obs_dim_0) int64 1kB 0 0 1 0 1 1 1 0 1 0 ... 1 1 1 1 0 0 0 1 1 1
Attributes:
created_at: 2026-03-17T13:49:27.043646+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
<xarray.Dataset> Size: 10kB
Dimensions: (X_dim_0: 175, X_dim_1: 4, y_dim_0: 175)
Coordinates:
* X_dim_0 (X_dim_0) int64 1kB 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
* X_dim_1 (X_dim_1) int64 32B 0 1 2 3
* y_dim_0 (y_dim_0) int64 1kB 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
Data variables:
X (X_dim_0, X_dim_1) float64 6kB 1.0 -0.8999 -0.1756 ... -3.397 11.63
y (y_dim_0) float64 1kB 0.0 0.0 1.0 0.0 1.0 ... 0.0 0.0 1.0 1.0 1.0
Attributes:
created_at: 2026-03-17T13:49:27.044271+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
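The dataset above holds the inputs the model was conditioned on: a 175 by 4 matrix X and the 175 responses y. The coefficient labels in the posterior dataset that follows ('Intercept', 'x1', 'x2', 'x1:x2') suggest that the four columns are an intercept, the two predictors, and their interaction (an elementwise product). The rows printed in the notebook fit that reading. The first row is [1, -0.89989, -0.17557, 0.158], and -0.89989 × -0.17557 ≈ 0.158. The sketch below builds such rows in plain Python. `design_row` and `design_matrix` are hypothetical helpers, and the notebook may have used a formula interface to do this instead.

```python
def design_row(x1: float, x2: float) -> list[float]:
    """Hypothetical helper: [Intercept, x1, x2, x1:x2] for one observation."""
    return [1.0, x1, x2, x1 * x2]


def design_matrix(x1s: list[float], x2s: list[float]) -> list[list[float]]:
    """Stack one design row per observation (rows = observations, columns = 4 terms)."""
    return [design_row(a, b) for a, b in zip(x1s, x2s)]


# Predictor values of the first two observations, as printed (rounded) in the notebook's X.
X = design_matrix([-0.89989, -1.67942], [-0.17557, -0.60079])
for row in X:
    print([round(v, 5) for v in row])
```

Because the interaction column is just a product, it has to be recomputed from x1 and x2 whenever new observations, such as a test set, are passed to the model.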
<xarray.Dataset> Size: 6MB
Dimensions: (chain: 4, draw: 1000, coeffs: 4, p_dim_0: 175)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* coeffs (coeffs) <U9 144B 'Intercept' 'x1' 'x2' 'x1:x2'
* p_dim_0 (p_dim_0) int64 1kB 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
Data variables:
b (chain, draw, coeffs) float64 128kB -0.05825 1.681 ... -1.83 3.949
p (chain, draw, p_dim_0) float64 6MB 0.2694 0.5574 1.0 ... 0.5946 1.0
Attributes:
created_at: 2026-03-17T13:49:27.035568+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 0.3888070583343506
tuning_steps: 1000
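In this posterior, b holds the four regression coefficients for each of the 4 × 1000 draws, and p holds 175 fitted probabilities for each draw. This sketch assumes the model sets p to the inverse logit of X·b. That is an assumption, although the printed numbers agree with it: taking the first draw of b (chain 0, draw 0) and the first two rows of X reproduces the leading p values shown in the summary line above (0.2694, 0.5574). The coefficients and predictors in the code are copied, rounded to five decimals, from the notebook's printed output. The helper names are hypothetical.

```python
import math


def inv_logit(eta: float) -> float:
    """Logistic (inverse-logit) function mapping a linear predictor to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-eta))


def fitted_prob(x_row: list[float], b: list[float]) -> float:
    """p_i = inv_logit(x_i . b) for one observation and one posterior draw."""
    eta = sum(x * coef for x, coef in zip(x_row, b))
    return inv_logit(eta)


# Chain 0, draw 0 of b, ordered as [Intercept, x1, x2, x1:x2].
b_draw = [-0.05825, 1.68124, -1.05657, 2.45565]

# First two rows of X: [1, x1, x2, x1*x2].
X_rows = [
    [1.0, -0.89989, -0.17557, 0.15800],
    [1.0, -1.67942, -0.60079, 1.00899],
]

p_draw = [fitted_prob(row, b_draw) for row in X_rows]
print([round(p, 4) for p in p_draw])
```

Repeating this for every draw gives the full (chain, draw, p_dim_0) array. This explains why p is about 45 times larger than b in the dataset above (6MB versus 128kB).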
<xarray.Dataset> Size: 5MB
Dimensions: (chain: 4, draw: 1000, obs_dim_0: 75, p_dim_0: 75)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
* obs_dim_0 (obs_dim_0) int64 600B 0 1 2 3 4 5 6 7 ... 68 69 70 71 72 73 74
* p_dim_0 (p_dim_0) int64 600B 0 1 2 3 4 5 6 7 ... 67 68 69 70 71 72 73 74
Data variables:
obs (chain, draw, obs_dim_0) int64 2MB 1 0 1 0 1 1 0 ... 1 0 1 0 1 1
p (chain, draw, p_dim_0) float64 2MB 1.0 0.03145 ... 1.0 0.7292
Attributes:
created_at: 2026-03-17T13:49:28.609920+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
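The dataset above holds draws for 75 new observations, presumably the held-out test set. For each of the 4 × 1000 posterior draws it stores a probability p and a simulated 0/1 outcome obs. Assuming obs has a Bernoulli likelihood (consistent with its integer 0/1 values), each predictive draw is conceptually a coin flip that uses that draw's probability. A probability pinned at 1.0, like the leading p value in the summary line above, therefore gives the same outcome in every draw. The minimal stdlib sketch below illustrates that step. The helper name and the probabilities are illustrative and are not taken from the fit.

```python
import random


def predictive_draws(p_draws: list[float], rng: random.Random) -> list[int]:
    """Hypothetical helper: one Bernoulli(p) outcome per posterior draw of p."""
    # rng.random() is uniform on [0, 1), so p = 1.0 always gives 1 and p = 0.0 always gives 0.
    return [1 if rng.random() < p else 0 for p in p_draws]


rng = random.Random(663)

# Illustrative probabilities for one new observation across 4000 posterior draws.
p_draws = [0.73, 0.46, 0.70, 0.59] * 1000
obs_draws = predictive_draws(p_draws, rng)

# The share of 1s among the predictive draws tracks the average of p.
share_of_ones = sum(obs_draws) / len(obs_draws)
mean_p = sum(p_draws) / len(p_draws)
print(f"share of 1s = {share_of_ones:.3f}, mean of p = {mean_p:.3f}")

# A probability pinned at 1.0 yields the same outcome in every draw.
print(set(predictive_draws([1.0] * 100, rng)))
```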
<xarray.Dataset> Size: 528kB
Dimensions: (chain: 4, draw: 1000)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables: (12/18)
index_in_trajectory (chain, draw) int64 32kB -1 3 4 5 2 ... -1 -6 -5 -2
max_energy_error (chain, draw) float64 32kB 0.7143 0.8379 ... -0.3487
lp (chain, draw) float64 32kB -54.98 -55.09 ... -55.64
tree_depth (chain, draw) int64 32kB 2 3 3 4 2 3 ... 3 3 4 3 3 2
divergences (chain, draw) int64 32kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
step_size_bar (chain, draw) float64 32kB 0.4975 0.4975 ... 0.5673
... ...
perf_counter_diff (chain, draw) float64 32kB 7.137e-05 ... 6.663e-05
diverging (chain, draw) bool 4kB False False ... False False
process_time_diff (chain, draw) float64 32kB 7.1e-05 ... 6.6e-05
smallest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan
reached_max_treedepth (chain, draw) bool 4kB False False ... False False
energy_error (chain, draw) float64 32kB -0.377 0.2598 ... -0.232
Attributes:
created_at: 2026-03-17T13:49:27.041467+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 0.3888070583343506
tuning_steps: 1000
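The dataset above holds the per-draw NUTS statistics that PyMC records (ArviZ's sample_stats group). Two routine checks are the number of divergent transitions and the share of draws that hit the maximum tree depth. In ArviZ these could be read with, for example, `int(idata.sample_stats["diverging"].sum())`, assuming the InferenceData object is named `idata`. A NUTS trajectory of depth d contains at most 2^d - 1 leapfrog steps. The small depths listed above (2 to 4) therefore mean at most 15 leapfrog steps per draw. The sketch below applies these checks to plain Python lists. `nuts_checks` is a hypothetical helper.

```python
def nuts_checks(tree_depth: list[int], diverging: list[bool],
                reached_max: list[bool]) -> dict[str, float]:
    """Hypothetical helper: summarize per-draw NUTS diagnostics for one chain."""
    n = len(tree_depth)
    return {
        "divergences": sum(diverging),                  # count of divergent transitions
        "frac_max_treedepth": sum(reached_max) / n,     # share of draws at the depth cap
        # A trajectory of depth d has at most 2**d - 1 leapfrog steps.
        "max_leapfrog_steps": max(2 ** d - 1 for d in tree_depth),
        "mean_tree_depth": sum(tree_depth) / n,
    }


# Leading tree_depth values from the summary above; the all-False flags mirror
# the leading diverging and reached_max_treedepth values.
depths = [2, 3, 3, 4, 2, 3]
flags = [False] * len(depths)
result = nuts_checks(depths, flags, flags)
print(result)
```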
<xarray.Dataset> Size: 3kB
Dimensions: (obs_dim_0: 175)
Coordinates:
* obs_dim_0 (obs_dim_0) int64 1kB 0 1 2 3 4 5 6 ... 169 170 171 172 173 174
Data variables:
obs (obs_dim_0) int64 1kB 0 0 1 0 1 1 1 0 1 0 ... 1 1 1 1 0 0 0 1 1 1
Attributes:
created_at: 2026-03-17T13:49:27.043646+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
<xarray.Dataset> Size: 10kB
Dimensions: (X_dim_0: 175, X_dim_1: 4, y_dim_0: 175)
Coordinates:
* X_dim_0 (X_dim_0) int64 1kB 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
* X_dim_1 (X_dim_1) int64 32B 0 1 2 3
* y_dim_0 (y_dim_0) int64 1kB 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
Data variables:
X (X_dim_0, X_dim_1) float64 6kB 1.0 -0.8999 -0.1756 ... -3.397 11.63
y (y_dim_0) float64 1kB 0.0 0.0 1.0 0.0 1.0 ... 0.0 0.0 1.0 1.0 1.0
Attributes:
created_at: 2026-03-17T13:49:27.044271+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
         mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
obs[0]  1.000  0.000   1.000    1.000      0.000      NaN    4000.0    4000.0    NaN
obs[1]  0.017  0.128   0.000    0.000      0.002    0.008    3965.0    3965.0    1.0
obs[2]  0.558  0.497   0.000    1.000      0.008    0.001    3813.0    3813.0    1.0
obs[3]  0.000  0.000   0.000    0.000      0.000      NaN    4000.0    4000.0    NaN
obs[4]  0.245  0.430   0.000    1.000      0.007    0.004    3894.0    3894.0    1.0
...       ...    ...     ...      ...        ...      ...       ...       ...    ...
p[70]   0.181  0.138   0.007    0.444      0.002    0.002    3667.0    3370.0    1.0
p[71]   0.961  0.033   0.901    0.999      0.001    0.001    2179.0    2413.0    1.0
p[72]   0.000  0.000   0.000    0.000      0.000    0.000    1551.0    1793.0    1.0
p[73]   1.000  0.000   1.000    1.000      0.000    0.000    1786.0    1980.0    1.0
p[74]   0.676  0.102   0.486    0.865      0.002    0.001    3477.0    3181.0    1.0

[150 rows x 9 columns]
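For a 0/1 quantity such as obs, the mean in this table is the share of draws equal to 1, and the sd is essentially sqrt(m(1 - m)). For example, obs[2] has mean 0.558 and sd 0.497, and obs[4] has mean 0.245 and sd 0.430. When every draw is identical (obs[0] is always 1 and obs[3] is always 0), the sd is 0. r_hat, which is built from between-chain and within-chain variances, is then undefined, and the table reports it (and mcse_sd) as NaN. The quick stdlib check below tests the sd relation. `bernoulli_sd` is a hypothetical helper.

```python
import math


def bernoulli_sd(mean: float) -> float:
    """Standard deviation of a 0/1 variable whose share of ones is `mean`."""
    return math.sqrt(mean * (1.0 - mean))


# (mean, reported sd) pairs for obs[2] and obs[4] from the summary table.
for m, reported in [(0.558, 0.497), (0.245, 0.430)]:
    print(f"mean={m:.3f}  implied sd={bernoulli_sd(m):.3f}  reported sd={reported:.3f}")
```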
<xarray.DataArray 'p' (p_dim_0: 175)> Size: 1kB
array([0.22354, 0.57537, 1. , 0.00002, 0.96285, 0.93441, 0.76759, 0.44544,
0.99993, 0.51717, 1. , 0.99992, 0.96277, 0.00761, 0.98443, 1. ,
1. , 1. , 0.91063, 0.59602, 0.01477, 0.67008, 1. , 0.91286,
0.01492, 0.99939, 0.74472, 0.99996, 0.97781, 0.49015, 0.00133, 0.54671,
0.81544, 0.99442, 0.14779, 0. , 0.99993, 0.58088, 0.40767, 0. ,
0.93578, 0.47611, 0.008 , 0.33411, 0.07102, 0.00032, 0.50066, 0. ,
0.67693, 0.01106, 0.0376 , 1. , 0. , 0.90967, 0.98833, 0.18546,
0.89679, 0.18902, 0. , 0.11247, 0.57731, 0.00002, 0.99812, 0.96232,
0.0004 , 0.05597, 0. , 0.56356, 0.0745 , 1. , 0.56973, 0.0002 ,
1. , 0.04337, 0.99839, 0.84047, 0.00035, 1. , 0.70715, 0.00002,
0.93892, 0.25884, 0.55489, 0.11686, 0.00011, 0.50344, 0. , 0. ,
1. , 0.24473, 0.00055, 0.82414, 0.00132, 0.81201, 0. , 0.99999,
0.7419 , 0.99982, 0.09778, 0.06814, 0.00011, 0.21825, 1. , 0.98934,
0.9969 , 0.03411, 0.1611 , 0.43356, 0.00738, 1. , 0.03185, 1. ,
0.69623, 0.99907, 0.99812, 1. , 0.01027, 1. , 0.63142, 0.98525,
0.00432, 0.96826, 0.33577, 0.00199, 0.584 , 0.00224, 0.94832, 0.72957,
0.01221, 1. , 0.66046, 0.38862, 0.88489, 0.47824, 0.59226, 0. ,
0.77121, 0.59311, 0.00975, 0.99965, 0.13918, 0.13581, 0.30046, 0.03326,
0.99824, 0. , 0. , 0.98214, 0.24219, 1. , 0.99604, 0.99658,
0.6764 , 0.99849, 0.77462, 0.99031, 0.46254, 0.99807, 0.24091, 0.00289,
0.99998, 0.47801, 0.92261, 0.14541, 1. , 0.99993, 0.81985, 0.98312,
0.80224, 0.69729, 0.40399, 0.01401, 0.98495, 0.545 , 1. ])
Coordinates:
* p_dim_0 (p_dim_0) int64 1kB 0 1 2 3 4 5 6 7 ... 168 169 170 171 172 173 174
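The DataArray above keeps only the p_dim_0 dimension, so it looks like a per-observation summary of p over all 4 × 1000 posterior draws. It is consistent with a mean over both the chain and draw dimensions, which in xarray would be something like `idata.posterior["p"].mean(dim=("chain", "draw"))`, assuming the InferenceData object is called `idata`. The sketch below spells out that reduction on a nested chain × draw × observation list. The helper name and the toy numbers are illustrative.

```python
def mean_over_chain_draw(samples: list[list[list[float]]]) -> list[float]:
    """Average a chain x draw x obs nested list over chains and draws."""
    n_obs = len(samples[0][0])
    totals = [0.0] * n_obs
    count = 0
    for chain in samples:
        for draw in chain:
            for i, value in enumerate(draw):
                totals[i] += value
            count += 1  # one (chain, draw) pair processed
    return [t / count for t in totals]


# Toy example: 2 chains x 2 draws x 3 observations (illustrative numbers).
toy = [
    [[0.2, 0.5, 1.0], [0.4, 0.7, 1.0]],
    [[0.1, 0.6, 1.0], [0.3, 0.4, 1.0]],
]
result = mean_over_chain_draw(toy)
print(result)
```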
<xarray.DataArray 'p' (p_dim_0: 75)> Size: 600B
array([1. , 0.01609, 0.57241, 0. , 0.25357, 1. , 0. , 0.00318,
0.31829, 0.00197, 0.50455, 0.99644, 0.20515, 0.31078, 0.49424, 0.99998,
0.99991, 0.75426, 0.00075, 0.65654, 0.89401, 0.99977, 0.43345, 0.98165,
0.56451, 0.00047, 0.97465, 0.63495, 0.00033, 0.19531, 0.00596, 0.88232,
0.00002, 0.85128, 0.00041, 0.38313, 0. , 0.61304, 0.99987, 0.29577,
1. , 0.13034, 0.0919 , 0.00002, 0.08267, 0.43678, 0.00265, 0.98128,
0.00486, 0. , 0.15431, 0.13207, 0.64057, 0.00584, 0.99999, 0. ,
0.24031, 0.0774 , 0.00348, 0.74317, 0.14325, 0.38558, 0.99972, 0.00002,
0.9803 , 0.9225 , 0.23294, 0.37334, 0.00075, 0.9937 , 0.18148, 0.96076,
0.00006, 1. , 0.67627])
Coordinates:
* p_dim_0 (p_dim_0) int64 600B 0 1 2 3 4 5 6 7 8 ... 67 68 69 70 71 72 73 74
Test data:
0.9377777777777778
Training data:
Sta 663 - Spring 2026