Example of creating and running a Metropolis-Hastings Sampler
In this notebook we will set up a sampler to sample a 2D Gaussian distribution with means \(\bar{x} = 2, \bar{y} = 5\), variances \(\sigma_x^2 = 1\), \(\sigma_y^2 = 2\), and covariance \(\sigma^2_{xy} = \sigma^2_{yx} = 0\), using a prior that is uniform over \(x \in [-20, 20)\), \(y\in[-40, 40)\).
We will use Python’s multiprocessing to evolve 12 chains over 4 cores. We will demonstrate that resuming a sampler from its state attribute yields the same results as if we had run continuously. We then make a plot of the posterior. Finally, we make an animation showing how the 12 chains moved.
%matplotlib notebook
from matplotlib import pyplot
import numpy
import epsie
from epsie.samplers import MetropolisHastingsSampler
import multiprocessing
Create the model to sample
Note: Below we create a class with several functions to draw samples from the prior and to evaluate the log posterior. This isn’t strictly necessary. The only thing the Sampler really requires is a function that it can pass keyword arguments to and get back a tuple of (log likelihood, log prior). However, setting things up as a class will make it convenient to, e.g., draw random samples from the prior for the starting positions, as well as plot the model later on.
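In fact, a bare function is enough. The following is a minimal sketch of such a model function (it is not used in this notebook; the class defined below is used instead), with the same Gaussian likelihood and uniform prior:
import numpy
from scipy import stats

# a minimal stand-alone model: takes parameter values as keyword arguments
# and returns a (log likelihood, log prior) tuple
def simple_model(**kwargs):
    # uniform prior over x in [-20, 20) and y in [-40, 40)
    logp = (stats.uniform(-20., 40.).logpdf(kwargs['x'])
            + stats.uniform(-40., 80.).logpdf(kwargs['y']))
    if logp == -numpy.inf:
        # skip the likelihood evaluation outside of the prior support
        return None, logp
    # 2D Gaussian likelihood with the same means and covariance as below
    logl = stats.multivariate_normal(
        mean=[2., 5.], cov=[[1., 0.], [0., 2.]]).logpdf([kwargs['x'], kwargs['y']])
    return logl, logp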
from scipy import stats
class Model(object):
    def __init__(self):
        # we'll use a 2D Gaussian for the likelihood distribution
        self.params = ['x', 'y']
        self.mean = [2., 5.]
        self.cov = [[1., 0.], [0., 2.]]
        self.likelihood_dist = stats.multivariate_normal(mean=self.mean,
                                                         cov=self.cov)
        # we'll just use a uniform prior
        self.prior_bounds = {'x': (-20., 20.),
                             'y': (-40., 40.)}
        xmin = self.prior_bounds['x'][0]
        dx = self.prior_bounds['x'][1] - xmin
        ymin = self.prior_bounds['y'][0]
        dy = self.prior_bounds['y'][1] - ymin
        self.prior_dist = {'x': stats.uniform(xmin, dx),
                           'y': stats.uniform(ymin, dy)}

    def prior_rvs(self, size=None):
        return {p: self.prior_dist[p].rvs(size=size)
                for p in self.params}

    def logprior(self, **kwargs):
        return sum([self.prior_dist[p].logpdf(kwargs[p]) for p in self.params])

    def loglikelihood(self, **kwargs):
        return self.likelihood_dist.logpdf([kwargs[p] for p in self.params])

    def __call__(self, **kwargs):
        logp = self.logprior(**kwargs)
        if logp == -numpy.inf:
            logl = None
        else:
            logl = self.loglikelihood(**kwargs)
        return logl, logp
model = Model()
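As a quick sanity check (an addition here, not part of the original notebook), we can call the model at a point inside and outside the prior bounds:
# inside the prior bounds: returns finite (logl, logp)
print(model(x=2., y=5.))
# outside the bounds: the log prior is -inf and the likelihood is skipped (logl is None)
print(model(x=100., y=0.))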
Setup and run the sampler
Create a pool of 4 parallel processes, then initialize the sampler using the model we created above.
nchains = 12
nprocs = 4
pool = multiprocessing.Pool(nprocs)
#pool = None
sampler = MetropolisHastingsSampler(model.params, model, nchains, pool=pool)
Now set the starting positions of the chains by drawing random variates from the model’s prior.
sampler.start_position = model.prior_rvs(size=nchains)
Let’s run it!
This will evolve each chain in the collection by 256 steps. This is parallelized over the pool of processes.
sampler.run(256)
Extract results
We can get the history of all of the chains using the .positions attribute. This will return a numpy structured array in which the fields are the parameter names (in this case, 'x' and 'y'), with shape nchains x niterations:
positions = sampler.positions
print('sampler.positions: {}'.format(type(positions)))
print('with fields: {}'.format(positions.dtype.names))
print('and shape:', positions.shape)
sampler.positions: <class 'numpy.ndarray'>
with fields: ('x', 'y')
and shape: (12, 256)
This (or any structured array returned by epsie) can be turned into a dictionary of arrays, where the keys are the parameter names, using epsie.array2dict:
positions = epsie.array2dict(sampler.positions)
print('sampler.positions: {} with keys/values:'.format(type(positions)))
for param in sorted(positions):
    print('"{}": {} with shape {}'.format(param, type(positions[param]), positions[param].shape))
sampler.positions: <class 'dict'> with keys/values:
"x": <class 'numpy.ndarray'> with shape (12, 256)
"y": <class 'numpy.ndarray'> with shape (12, 256)
We can also access the history of log likelihoods and log priors using sampler.stats, as well as the acceptance ratios and which jumps were accepted with sampler.acceptance:
stats = sampler.stats
print('sampler.stats: {}'.format(type(stats)))
print('with fields: {}'.format(stats.dtype.names))
print('and shape:', stats.shape)
sampler.stats: <class 'numpy.ndarray'>
with fields: ('logl', 'logp')
and shape: (12, 256)
acceptance = sampler.acceptance
print('sampler.acceptance: {}'.format(type(acceptance)))
print('with fields: {}'.format(acceptance.dtype.names))
print('and shape:', acceptance.shape)
sampler.acceptance: <class 'numpy.ndarray'>
with fields: ('acceptance_ratio', 'accepted')
and shape: (12, 256)
If the model returned “blobs” (i.e., the model returns a dictionary along with the logl and logp), then we can also access those using sampler.blobs. Similar to positions, this would also be a dictionary of arrays with keys given by the names in the dictionary the model returned. However, because our model above returns no blobs, in this case we just get None:
print(sampler.blobs)
None
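For reference, a hypothetical variation of the Model class that returns blobs would simply add a dictionary to its return value, something like:
# hypothetical example: a model that also returns "blobs" (extra per-sample
# data stored by the sampler), here the log prior pdf of each parameter
class ModelWithBlobs(Model):
    def __call__(self, **kwargs):
        logl, logp = super(ModelWithBlobs, self).__call__(**kwargs)
        blobs = {p: self.prior_dist[p].logpdf(kwargs[p]) for p in self.params}
        return logl, logp, blobs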
The individual chains can be accessed using the .chains attribute:
sampler.chains
[<epsie.chain.chain.Chain at 0x1226544d0>,
<epsie.chain.chain.Chain at 0x122654c90>,
<epsie.chain.chain.Chain at 0x122654690>,
<epsie.chain.chain.Chain at 0x12265a210>,
<epsie.chain.chain.Chain at 0x122654ed0>,
<epsie.chain.chain.Chain at 0x12265a950>,
<epsie.chain.chain.Chain at 0x12265c850>,
<epsie.chain.chain.Chain at 0x12266fcd0>,
<epsie.chain.chain.Chain at 0x12265c890>,
<epsie.chain.chain.Chain at 0x12265c110>,
<epsie.chain.chain.Chain at 0x1226549d0>,
<epsie.chain.chain.Chain at 0x12267cd50>]
Resume from a state
The sampler can be checkpointed by getting its current state with sampler.state. To demonstrate this, we’ll get the current state of the sampler, then run it for another set of iterations. We’ll then create a new sampler and set its state to the state we obtained from the first sampler. Running the new sampler for the same number of iterations should produce the same results.
# get the current state
state = sampler.state
# now advance the sampler for another 256 iterations
sampler.run(256)
# create a new sampler, but set its state to what the original sampler's was after the first 256 iterations
sampler2 = MetropolisHastingsSampler(model.params, model, nchains, pool=pool)
sampler2.set_state(state)
# now advance the new sampler for 256 iterations
# note that we don't have to set the start position first, since it was set by set_state
sampler2.run(256)
# compare the current results; they should be the same between sampler2 and sampler
print('x:', (sampler.current_positions['x'] == sampler2.current_positions['x']).all())
print('y:', (sampler.current_positions['y'] == sampler2.current_positions['y']).all())
print('logl:', (sampler.current_stats['logl'] == sampler2.current_stats['logl']).all())
print('logp:', (sampler.current_stats['logp'] == sampler2.current_stats['logp']).all())
print('acceptance ratio:',
      (sampler.acceptance['acceptance_ratio'][:,-1] == sampler2.acceptance['acceptance_ratio'][:,-1]).all())
print('accepted:',
      (sampler.acceptance['accepted'][:,-1] == sampler2.acceptance['accepted'][:,-1]).all())
x: True
y: True
logl: True
logp: True
acceptance ratio: True
accepted: True
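Since the state is an ordinary Python object, one way to checkpoint to disk is to pickle it; a minimal sketch (assuming the state object is picklable):
import pickle

# save the current state to a checkpoint file
with open('checkpoint.pkl', 'wb') as fp:
    pickle.dump(sampler.state, fp)

# ... later, restore it into a freshly constructed sampler
with open('checkpoint.pkl', 'rb') as fp:
    saved_state = pickle.load(fp)
sampler3 = MetropolisHastingsSampler(model.params, model, nchains, pool=pool)
sampler3.set_state(saved_state)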
Clearing memory and continuing
The history of results in memory can be cleared using .clear(). Running the sampler after a clear yields the same results as if no clear had been done. This is useful for keeping memory usage down: you can dump results to a file after some number of iterations, clear, then continue.
To demonstrate this, we’ll clear sampler2, then run both sampler and sampler2 for another 512 iterations. We’ll then compare the current results; they should be the same.
sampler2.clear()
sampler.run(512)
sampler2.run(512)
# compare the current results; they should be the same between sampler2 and sampler
print('x:', (sampler.current_positions['x'] == sampler2.current_positions['x']).all())
print('y:', (sampler.current_positions['y'] == sampler2.current_positions['y']).all())
print('logl:', (sampler.current_stats['logl'] == sampler2.current_stats['logl']).all())
print('logp:', (sampler.current_stats['logp'] == sampler2.current_stats['logp']).all())
print('acceptance ratio:',
      (sampler.acceptance['acceptance_ratio'][:,-1] == sampler2.acceptance['acceptance_ratio'][:,-1]).all())
print('accepted:',
      (sampler.acceptance['accepted'][:,-1] == sampler2.acceptance['accepted'][:,-1]).all())
x: True
y: True
logl: True
logp: True
acceptance ratio: True
accepted: True
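A sketch of the dump/clear/continue pattern described above, here writing the positions of sampler2 to plain .npz files (the file names and checkpoint interval are arbitrary choices for illustration):
# periodically write the in-memory history to disk, then clear it
checkpoint_interval = 256
for kk in range(4):
    sampler2.run(checkpoint_interval)
    positions = epsie.array2dict(sampler2.positions)
    numpy.savez('checkpoint_{}.npz'.format(kk),
                x=positions['x'], y=positions['y'])
    sampler2.clear()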
Plot the posterior
Let’s plot the estimated posterior. For this, we’ll need to throw out some earlier samples for the burn-in period; we’ll just assume the first half of each chain was burn-in. We also need to calculate the autocorrelation length (ACL) of the chains in order to get independent samples.
def calculate_acf(data):
    """Calculates the autocorrelation of some data"""
    # zero the mean
    data = data - data.mean()
    # zero-pad to 2 * nearest power of 2
    newlen = int(2**(1+numpy.ceil(numpy.log2(len(data)))))
    x = numpy.zeros(newlen)
    x[:len(data)] = data[:]
    # correlate
    acf = numpy.correlate(x, x, mode='full')
    # drop corrupted region
    acf = acf[len(acf)//2:]
    # normalize
    acf /= acf[0]
    return acf

def calculate_acl(data):
    """Calculates the autocorrelation length of some data.

    Algorithm used is from:
    N. Madras and A.D. Sokal, J. Stat. Phys. 50, 109 (1988).
    """
    # calculate the acf
    acf = calculate_acf(data)
    # now the ACL: Following from Sokal, this is estimated
    # as the first point where M*tau[k] <= k, where
    # tau = 2*cumsum(acf) - 1, and M is a tuneable parameter,
    # generally chosen to be = 5 (which we use here)
    m = 5
    cacf = 2.*numpy.cumsum(acf) - 1.
    win = m * cacf <= numpy.arange(len(cacf))
    if win.any():
        acl = int(numpy.ceil(cacf[numpy.where(win)[0][0]]))
    else:
        # data is too short to estimate the ACL, just choose
        # the length of the data
        acl = len(data)
    return acl
Since the chains are completely independent of each other, we can calculate the ACL separately for each chain. However, if you’d like to be more conservative, you can also just take the max over all of the chains.
# get the samples; recall that this is a dictionary of
# nchains x niterations arrays for each parameter
samples = sampler.positions
# as we said above, we'll assume the first half
# of the chain was burn in
burnin_iter = sampler.niterations // 2
# set up arrays to store the ACL of each chain and
# the thinned chains
acls = numpy.zeros(nchains, dtype=int)
thinned_arrays = {'x': [], 'y': []}
# cycle over the chains, calculating the ACLs and thinning them
for ii in range(nchains):
    # get the second half of the chains
    sx = samples['x'][ii, burnin_iter:]
    sy = samples['y'][ii, burnin_iter:]
    # compute the acl for each parameter
    aclx = calculate_acl(sx)
    acly = calculate_acl(sy)
    acl = max(aclx, acly)
    acls[ii] = acl
    # note that we'll thin the arrays starting from the
    # end to get the latest results
    thinned_arrays['x'].append(sx[::-1][::acl][::-1])
    thinned_arrays['y'].append(sy[::-1][::acl][::-1])
# the ACL of each chain:
print(acls)
[23 10 6 9 13 12 19 10 18 24 24 19]
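As an aside, the more conservative approach mentioned above (thinning every chain by the single largest ACL) might look like the following sketch, reusing acls, samples, and burnin_iter from the cell above:
# thin every chain by the largest ACL over all of the chains
max_acl = int(acls.max())
conservative_x = numpy.concatenate(
    [samples['x'][ii, burnin_iter:][::-1][::max_acl][::-1]
     for ii in range(nchains)])
print('most conservative ACL:', max_acl,
      'giving', conservative_x.size, 'independent samples of x')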
# create a flattened posterior array
posterior = {'x': numpy.concatenate(thinned_arrays['x']),
             'y': numpy.concatenate(thinned_arrays['y'])}
print("Number of independent samples:", posterior['x'].size)
Number of independent samples: 480
# histogram them
fig, ax = pyplot.subplots()
ax.hist(posterior['x'], bins=10, histtype='step', label='x')
ax.hist(posterior['y'], bins=10, histtype='step', label='y')
ax.legend()
fig.show()
Let’s check the mean and variance of our estimated posterior. These should be \(\bar{x} \approx 2, \sigma^2_{x} \approx 1\) and \(\bar{y} \approx 5, \sigma^2_{y} \approx 2\):
for param in posterior:
    s = posterior[param]
    print(param, 'mean: {}'.format(s.mean()), 'var: {}'.format(s.var()))
x mean: 1.99478140402679 var: 0.8584428515335814
y mean: 5.178616585856629 var: 1.9987894875419652
The values are close, but not exact. This isn’t too surprising since we only have \(\mathcal{O}(100)\) samples. To get more samples, the sampler can be run longer.
Create an animation of the results
To visualize the results, we’ll create an animation showing how the chains evolved. We’ll do this by plotting one point for each chain, with each frame in the animation representing a single iteration.
Note: To keep file size down, the animation has not been created for the version of this notebook uploaded to the repository.
from matplotlib import animation
# Prepare an array to create a density map showing the shape of the model posterior
npts = 100
xmean, ymean = model.likelihood_dist.mean
xsig = model.likelihood_dist.cov[0,0]**0.5
ysig = model.likelihood_dist.cov[1,1]**0.5
X, Y = numpy.mgrid[xmean-3*xsig:xmean+3*xsig:complex(0, npts),
                   ymean-3*ysig:ymean+3*ysig:complex(0, npts)]
Z = numpy.zeros(X.shape)
for ii in range(Z.shape[0]):
    for jj in range(Z.shape[1]):
        logl, logp = model(x=X[ii,jj], y=Y[ii,jj])
        Z[ii, jj] = numpy.exp(logl+logp)
# we'll just animate the first 200 iterations; change this to
# nframes = xdata.shape[1]
# if you want to see all iterations
nframes = 200
fig, ax = pyplot.subplots()
positions = sampler.positions
xdata = positions['x']
ydata = positions['y']
# Plot density map showing the shape of the true posterior density
ax.imshow(numpy.rot90(Z), extent=[X.min(), X.max(), Y.min(), Y.max()],
          aspect='auto', cmap='binary', zorder=-3)
# Put an x at the maximum posterior point
ax.scatter(model.mean[0], model.mean[1], marker='x', color='w', s=10, zorder=-2)
ax.set_xlabel('x')
ax.set_ylabel('y')
# create the scatter points
ptsize = 60
# we'll include the last bufferlen number of steps a chain visited, having the size and transparency
# exponentially damped with each new frame
bufferlen = 16
alphas = numpy.exp(-4*(numpy.arange(bufferlen))/float(bufferlen))
sizes = ptsize * alphas
#colors = numpy.array(['C{}'.format(ii) for ii in range(nchains)])
colors = numpy.arange(nchains)
plts = [ax.scatter(xdata[:, bufferlen-ii-1], ydata[:, bufferlen-ii-1], c=colors, s=sizes[ii],
                   edgecolors='w', linewidths=0.5,
                   alpha=alphas[ii], zorder=bufferlen-ii, marker='s' if ii==0 else 'o', cmap='jet')
        for ii in range(bufferlen)]
# put a + showing the average of the chain positions at the current iteration
meanplt = ax.scatter(xdata[:,0].mean(), ydata[:,0].mean(), marker='P', c='w', edgecolors='k', linewidths=0.5,
                     zorder=bufferlen+1)
# add some text giving the iteration
itertxt = 'Iteration {}'
txt = ax.annotate(itertxt.format(1), (0.03, 0.94), xycoords='axes fraction')
def animate(ii):
    txt.set_text(itertxt.format(ii+1))
    for jj, plt in enumerate(plts):
        plt.set_offsets(numpy.array([xdata[:, max(ii-jj, 0)], ydata[:, max(ii-jj, 0)]]).T)
    meanplt.set_offsets([xdata[:,ii].mean(), ydata[:,ii].mean()])
    # zoom in as it narrows on the result
    istart = max(ii-bufferlen, 0)
    # smooth it out a bit
    xmin = numpy.array([xdata[:, max(istart-kk, 0):].min() for kk in range(50)]).mean()
    xmax = numpy.array([xdata[:, max(istart-kk, 0):].max() for kk in range(50)]).mean()
    ymin = numpy.array([ydata[:, max(istart-kk, 0):].min() for kk in range(50)]).mean()
    ymax = numpy.array([ydata[:, max(istart-kk, 0):].max() for kk in range(50)]).mean()
    ax.set_xlim((1.1 if xmin < 1 else 0.9)*xmin, (0.9 if xmax < 1 else 1.1)*xmax)
    ax.set_ylim((1.1 if ymin < 1 else 0.9)*ymin, (0.9 if ymax < 1 else 1.1)*ymax)
# blitting is disabled because animate() rescales the axes and does not return the artists
ani = animation.FuncAnimation(fig, animate, frames=nframes, interval=160, blit=False)
Save the animation:
ani.save('chain_animation.mp4')
The result:
%%HTML
<video width="640" height="480" controls>
<source src="chain_animation.mp4" type="video/mp4">
</video>