Use the Veitch et al. Adaptive Normal proposal with the Parallel Tempered Sampler
In this notebook we will test the adaptive normal proposal on a 2D Gaussian distribution with mean \(\bar{x} = 2, \bar{y} = 5\) and variance \(\sigma_x^2 = 1\), \(\sigma_y^2 = 2\); \(\sigma^2_{xy} = \sigma^2_{yx} = 0\), using a prior that is uniform over \(x \in [-20, 20)\) and \(y \in [-40, 40)\). For this we will use the parallel tempered sampler.
%matplotlib notebook
from matplotlib import pyplot
import numpy
import epsie
from epsie import make_betas_ladder
from epsie.samplers import ParallelTemperedSampler
from epsie.proposals import AdaptiveNormal
import multiprocessing
Create the model to sample
Note: Below we create a class with several functions to draw samples from the prior and to evaluate the log posterior. This isn’t strictly necessary. The only thing the Sampler really requires is a function that it can pass keyword arguments to and get back a tuple of (log likelihood, log prior). However, setting things up as a class will make it convenient to, e.g., draw random samples from the prior for the starting positions, as well as plot the model later on.
from scipy import stats
class Model(object):
    """Toy posterior: a 2D Gaussian likelihood with a uniform prior.

    Calling an instance with keyword arguments ``x`` and ``y`` returns a
    ``(log likelihood, log prior)`` tuple, which is all the sampler needs.
    The other methods are conveniences for drawing starting positions and
    for evaluating the prior and likelihood separately.
    """

    def __init__(self):
        # we'll use a 2D Gaussian for the likelihood distribution
        self.params = ['x', 'y']
        self.mean = [2., 5.]
        self.cov = [[1., 0.], [0., 2.]]
        self.likelihood_dist = stats.multivariate_normal(mean=self.mean,
                                                         cov=self.cov)
        # we'll just use a uniform prior
        self.prior_bounds = {'x': (-20., 20.),
                             'y': (-40., 40.)}
        # scipy's uniform is parameterized by (loc, scale) = (min, width)
        self.prior_dist = {p: stats.uniform(lo, hi - lo)
                           for p, (lo, hi) in self.prior_bounds.items()}

    def prior_rvs(self, size=None, shape=None):
        """Draw random samples from the prior.

        Parameters
        ----------
        size : int, optional
            Number of samples to draw for each parameter. If None, a single
            sample is drawn per parameter.
        shape : tuple, optional
            If given, each parameter's draws are reshaped to this shape (its
            total size must equal ``size``). If None, the draws are returned
            as scipy produced them.

        Returns
        -------
        dict
            Maps each parameter name to its sample(s).
        """
        # draw every parameter first (x, then y), then optionally reshape
        draws = {p: self.prior_dist[p].rvs(size=size) for p in self.params}
        if shape is None:
            # reshape(None) is a deprecated alias for reshape(()), which
            # fails for more than one draw, so skip reshaping entirely
            return draws
        # numpy.reshape also handles a scalar (size=None) draw
        return {p: numpy.reshape(vals, shape) for p, vals in draws.items()}

    def logprior(self, **kwargs):
        """Return the log prior: the sum of each parameter's uniform log pdf."""
        return sum(self.prior_dist[p].logpdf(kwargs[p]) for p in self.params)

    def loglikelihood(self, **kwargs):
        """Return the log of the 2D Gaussian likelihood at the given point."""
        return self.likelihood_dist.logpdf([kwargs[p] for p in self.params])

    def __call__(self, **kwargs):
        """Return ``(logl, logp)`` for the point given as keyword arguments.

        ``logl`` is None when the point is outside the prior support, so the
        likelihood is not evaluated there.
        """
        logp = self.logprior(**kwargs)
        if logp == -numpy.inf:
            logl = None
        else:
            logl = self.loglikelihood(**kwargs)
        return logl, logp
# Instantiate the model. The instance is passed to the sampler below as the
# function that returns (log likelihood, log prior).
model = Model()
Set up the proposal
We’ll set up the adaptive normal proposal to adapt over the first 512 iterations.
# adapt the proposal over the first 512 iterations
adaptation_duration = 512
# the width of each parameter's prior range, handed to AdaptiveNormal
prior_widths = {name: abs(hi - lo)
                for name, (lo, hi) in model.prior_bounds.items()}
proposal = AdaptiveNormal(model.params,
                          prior_widths,
                          adaptation_duration=adaptation_duration)
Set up and run the sampler
Create a pool of 4 parallel processes, then initialize the sampler using the model we created above.
# number of independent parallel-tempered chains
nchains = 12
# number of temperatures per chain
ntemps = 3
# passed to the sampler as swap_interval (presumably iterations between
# temperature-swap attempts; confirm in the epsie docs)
swap_interval = 4
# worker processes used to evolve the chains
nprocs = 4
pool = multiprocessing.Pool(nprocs)
# inverse-temperature ladder with ntemps rungs (1e5 is presumably the
# hottest temperature; confirm in the epsie docs)
betas = make_betas_ladder(ntemps, 1e5)
sampler = ParallelTemperedSampler(
    model.params, model, nchains,
    proposals=[proposal],
    betas=betas,
    swap_interval=swap_interval,
    pool=pool,
)
Now set the starting positions of the chains by drawing random variates from the model’s prior.
# Draw one prior sample for every (temperature, chain) pair. prior_rvs
# reshapes each parameter's draws to (ntemps, nchains).
sampler.start_position = model.prior_rvs(size=nchains*ntemps, shape=(ntemps, nchains))
Let’s run it!
This will evolve each chain in the collection by 256 steps. This is parallelized over the pool of processes.
# advance every chain by 256 iterations, using the pool of processes
sampler.run(256)
Let’s check how the standard deviation has changed. We started with:
# standard deviations of the proposal object we created above (the widths
# we started with, before any adaptation)
print(proposal.std)
[2.7576 5.5152]
Now we have:
# Print the current (adapted) standard deviations of each chain's proposal,
# one line per temperature.
for ci, ptchain in enumerate(sampler.chains):
    print('==== Chain {} ===='.format(ci))
    for tk, c in enumerate(ptchain.chains):
        print("Temp {}:".format(tk), c.proposal_dist.proposals[0].std)
==== Chain 0 ====
Temp 0: [1.95442018 3.90884035]
Temp 1: [24.14693992 48.29387985]
Temp 2: [22.47629237 44.95258473]
==== Chain 1 ====
Temp 0: [1.50320855 3.00641709]
Temp 1: [21.79252365 43.58504731]
Temp 2: [23.0946761 46.1893522]
==== Chain 2 ====
Temp 0: [1.92168189 3.84336378]
Temp 1: [20.02595837 40.05191674]
Temp 2: [25.18863185 50.3772637 ]
==== Chain 3 ====
Temp 0: [1.49592492 2.99184984]
Temp 1: [21.30158006 42.60316012]
Temp 2: [23.09363368 46.18726736]
==== Chain 4 ====
Temp 0: [1.92178972 3.84357944]
Temp 1: [19.57048537 39.14097073]
Temp 2: [23.14193028 46.28386057]
==== Chain 5 ====
Temp 0: [2.14013127 4.28026254]
Temp 1: [24.83317471 49.66634941]
Temp 2: [21.50729733 43.01459466]
==== Chain 6 ====
Temp 0: [2.78252961 5.56505922]
Temp 1: [20.03601225 40.07202449]
Temp 2: [24.30641137 48.61282273]
==== Chain 7 ====
Temp 0: [1.82792179 3.65584358]
Temp 1: [21.75826256 43.51652511]
Temp 2: [23.52957773 47.05915545]
==== Chain 8 ====
Temp 0: [2.03698192 4.07396384]
Temp 1: [22.17226908 44.34453816]
Temp 2: [24.09203654 48.18407308]
==== Chain 9 ====
Temp 0: [2.08347402 4.16694804]
Temp 1: [21.70794273 43.41588547]
Temp 2: [24.93525196 49.87050393]
==== Chain 10 ====
Temp 0: [1.71800748 3.43601497]
Temp 1: [23.4670909 46.9341818]
Temp 2: [22.44236126 44.88472251]
==== Chain 11 ====
Temp 0: [2.12504161 4.25008321]
Temp 1: [21.55691027 43.11382055]
Temp 2: [24.81255185 49.62510371]
Plot acceptance rates
We’ll plot the acceptance rate for each chain, which we define here as the number of times a proposal was accepted divided by the total number of iterations. We expect this to be close to ~0.234 for the coldest chains, as this is the target rate of the `AdaptiveNormal` proposal that we used.
acceptance = sampler.acceptance
# Fraction of iterations on which each (temperature, chain) accepted a jump.
# The result is indexed [temperature, chain].
arate = acceptance['accepted'].mean(axis=2)
# Acceptance ratios above 1 mean the jump is always accepted, so cap them at
# 1 before averaging. numpy.minimum builds a new array instead of
# overwriting the values stored in the acceptance record.
aratio = numpy.minimum(acceptance['acceptance_ratio'], 1.).mean(axis=2)
# plot each chain's rate, with a dashed line at each temperature's mean
fig, ax = pyplot.subplots()
for tk in range(ntemps):
    ax.scatter(range(nchains), arate[tk, :], label='temp {}'.format(tk))
    ax.axhline(arate[tk, :].mean(), color='C{}'.format(tk), linestyle='--')
ax.legend()
ax.set_ylabel('mean acceptance rate')
ax.set_xlabel('chain index')
fig.show()
<IPython.core.display.Javascript object>
# arate has one row per temperature; average each row over its chains
print("Average acceptance rate over all chains:", arate.mean(axis=1))
Average acceptance rate over all chains: [0.2125651 0.34830729 0.35611979]
Indeed, the average acceptance rate over all of the coldest chains is close to 0.23.
The acceptance rate for the higher temperatures is larger than 0.23, however. To understand why, let’s also plot the average acceptance ratio. This should be approximately the same as the acceptance rate.
# plot each chain's mean acceptance ratio, with a dashed line at each
# temperature's mean
fig, ax = pyplot.subplots()
for tk in range(ntemps):
    ax.scatter(range(nchains), aratio[tk, :], label='temp {}'.format(tk))
    ax.axhline(aratio[tk, :].mean(), color='C{}'.format(tk), linestyle='--')
# the points are labelled by temperature, so show the legend
ax.legend()
ax.set_ylabel('mean acceptance ratio')
ax.set_xlabel('chain index')
fig.show()
<IPython.core.display.Javascript object>
# aratio has one row per temperature; average each row over its chains
print("Average acceptance ratio over all chains:", aratio.mean(axis=1))
Average acceptance ratio over all chains: [0.21649604 0.35011284 0.35586311]
The average acceptance ratio is the same as the average acceptance rate, as expected.
So why do the higher temperatures have a larger acceptance ratio than our target rate of 0.234? There is a minimum intrinsic acceptance rate that a posterior may have. To see this, consider a posterior that is a uniform distribution: every point in that case has equal probability. Likewise, for a symmetric jump proposal like the normal proposal, the probability density of jumping from \(x\) to \(x'\) equals that of jumping from \(x'\) to \(x\), so the ratio of the proposal’s pdfs is always 1. Consequently, the acceptance ratio is 1 for every jump that stays within the posterior’s support, no matter what size covariance we use for our jump proposal. As the posterior becomes less flat it becomes possible to have an acceptance ratio < 1. This is why the higher temperatures have higher rates: the higher the temperature, the flatter the likelihood we are sampling.
Resume from a state
The sampler can be checkpointed by getting its current state with `sampler.state`. Let’s check that this still works with the `AdaptiveNormal` proposal. To demonstrate this, we’ll save the current state of the sampler to a checkpoint HDF5 file, then run it for another set of iterations. We’ll then create a new sampler and set its state using the checkpoint. Running the new sampler for the same number of iterations should produce the same results.
Note: Since we will be reading/writing from an HDF5 file, we’ll need `h5py` installed. This is not one of the required packages for epsie, so if you don’t have it, uncomment the next line to install it. If you don’t wish to use this functionality, or prefer to save checkpoints using pickle or some other file format, you do not need to use HDF5.
#!pip install h5py
import h5py

# Dump the checkpoint. By default, the checkpoint will be saved
# to '/sampler_state' in the hdf5 file, but you can change the group
# and dataset name using the path and dsetname arguments, respectively.
# The context manager closes the file even if checkpointing raises.
with h5py.File('checkpoint_test.hdf', 'w') as fp:
    sampler.checkpoint(fp)

# now advance the sampler for another 256 iterations
sampler.run(256)

# create a new sampler, but set its state to what the original sampler's
# was after the first 256 iterations
sampler2 = ParallelTemperedSampler(model.params, model, nchains,
                                   proposals=[proposal],
                                   betas=betas, swap_interval=swap_interval,
                                   pool=pool)
# reopen the checkpoint read-only; the file is closed again right after
with h5py.File('checkpoint_test.hdf', 'r') as fp:
    sampler2.set_state_from_checkpoint(fp)

# Now advance the new sampler for 256 iterations. We don't need to set its
# start position first, since set_state_from_checkpoint restored the
# sampler's state, including where its chains currently are.
sampler2.run(256)
# compare the current results; they should be the same between sampler2 and sampler
# Every check below should print True if sampler2, restored from the
# checkpoint, reproduced the original sampler exactly.
for label, param in (('x:', 'x'), ('y:', 'y')):
    same = sampler.current_positions[param] == sampler2.current_positions[param]
    print(label, same.all())
for label, stat in (('logl:', 'logl'), ('logp:', 'logp')):
    same = sampler.current_stats[stat] == sampler2.current_stats[stat]
    print(label, same.all())
for label, field in (('acceptance ratio:', 'acceptance_ratio'),
                     ('accepted:', 'accepted')):
    # compare only the most recent iteration
    same = (sampler.acceptance[field][:, :, -1]
            == sampler2.acceptance[field][:, :, -1])
    print(label, same.all())
x: True
y: True
logl: True
logp: True
acceptance ratio: True
accepted: True
Clearing memory and continuing
The history of results in memory can be cleared using `.clear()`. Running the sampler after a clear yields the same results as if no clear had been done. This is useful for keeping memory usage down: you can dump results to a file after some number of iterations, clear, then continue.
To demonstrate this, we’ll clear `sampler2`, then run both `sampler` and `sampler2` for another 512 iterations. We’ll then compare the current results; they should be the same.
We’ll also check that the adaptation has ended. Recall that, above, we set the proposal’s adaptation to run for the first 512 iterations. From this point on, no more adaptation should occur. Let’s check that by comparing the diagonal of each proposal’s covariance matrix at the current iteration to what it is after the next set of iterations.
# Record the diagonal of every chain's proposal covariance, indexed
# [temperature, chain, parameter]. numpy.diag makes this work for any
# number of parameters instead of hard-coding two.
cov = numpy.zeros((ntemps, nchains, len(model.params)))
for ci, ptchain in enumerate(sampler.chains):
    for tk, c in enumerate(ptchain.chains):
        cov[tk, ci, :] = numpy.diag(c.proposal_dist.proposals[0].cov)
Now clear and run both samplers:
# Drop sampler2's stored history, then advance both samplers by the same
# number of iterations.
sampler2.clear()
sampler.run(512)
sampler2.run(512)
# Clearing should not have changed sampler2's trajectory, so every check
# below should print True.
for label, param in (('x:', 'x'), ('y:', 'y')):
    same = sampler.current_positions[param] == sampler2.current_positions[param]
    print(label, same.all())
for label, stat in (('logl:', 'logl'), ('logp:', 'logp')):
    same = sampler.current_stats[stat] == sampler2.current_stats[stat]
    print(label, same.all())
for label, field in (('acceptance ratio:', 'acceptance_ratio'),
                     ('accepted:', 'accepted')):
    # compare only the most recent iteration
    same = (sampler.acceptance[field][:, :, -1]
            == sampler2.acceptance[field][:, :, -1])
    print(label, same.all())
x: True
y: True
logl: True
logp: True
acceptance ratio: True
accepted: True
Check that the covariance matrix hasn’t changed:
# Diagonal of every chain's proposal covariance now, indexed
# [temperature, chain, parameter]. numpy.diag makes this work for any
# number of parameters instead of hard-coding two.
current_cov = numpy.zeros((ntemps, nchains, len(model.params)))
for ci, ptchain in enumerate(sampler.chains):
    for tk, c in enumerate(ptchain.chains):
        current_cov[tk, ci, :] = numpy.diag(c.proposal_dist.proposals[0].cov)
# With adaptation over, this should exactly equal the diagonal recorded in
# `cov` before the last 512 iterations.
print("current covariance is the same:", (current_cov == cov).all())
current covariance is the same: True
Plot the posterior
Let’s plot the posterior. For this, we’ll need to throw out some earlier samples for the burn-in period; we’ll just assume the first half of each chain was burn-in. We also need to calculate the autocorrelation length of the chains in order to get independent samples.
def calculate_acf(data):
    """Return the normalized autocorrelation function of a 1D series.

    The mean is removed and the series is zero-padded to twice the smallest
    power of two that is >= its length. The returned array holds the
    autocorrelation at the non-negative lags 0, 1, ..., npad - 1, divided by
    its lag-0 value so that ``acf[0] == 1``. Lags at or beyond ``len(data)``
    are zero. A constant series has a lag-0 value of 0, so the division
    yields NaNs.
    """
    n = len(data)
    npad = int(2 ** (1 + numpy.ceil(numpy.log2(n))))
    padded = numpy.zeros(npad)
    padded[:n] = data - data.mean()
    full = numpy.correlate(padded, padded, mode='full')
    # `full` spans lags -(npad - 1) ... (npad - 1), with lag 0 at index
    # npad - 1; keep only the non-negative lags
    acf = full[len(full) // 2:]
    return acf / acf[0]
def calculate_acl(data):
    """Estimate the autocorrelation length (ACL) of a 1D series.

    Uses the self-consistent window of N. Madras and A.D. Sokal,
    J. Stat. Phys. 50, 109 (1988). With the integrated autocorrelation time
    ``tau[k] = 2 * sum(acf[:k + 1]) - 1``, the ACL is ``ceil(tau[k])`` at the
    first lag ``k`` for which ``M * tau[k] <= k``, where the tunable window
    factor ``M`` is 5. If no lag satisfies this, the series is too short to
    estimate the ACL and ``len(data)`` is returned instead.
    """
    window_factor = 5
    tau = 2. * numpy.cumsum(calculate_acf(data)) - 1.
    (candidates,) = numpy.nonzero(window_factor * tau <= numpy.arange(len(tau)))
    if candidates.size == 0:
        # too short to estimate the ACL; fall back to the data length
        return len(data)
    return int(numpy.ceil(tau[candidates[0]]))
For parallel tempered chains, we take the ACL for a given chain as the maximum ACL over all of the temperatures of that chain. Usually, the longest ACL comes from the coldest chain (from which we get the posterior). Since the chains are completely independent of each other, we can calculate the ACL separately for each chain. However, if you’d like to be more conservative, you can also just take the max over all of the chains.
# get the samples; recall that this is a dictionary of
# nchains x niterations arrays for each parameter
samples = sampler.positions
# as we said above, we'll assume the first half
# of the chain was burn in
burnin_iter = sampler.niterations // 2
# set up arrays to store the ACL of each chain and
# the thinned chains
acls = numpy.zeros(nchains, dtype=int)
# cycle over the chains, calculating the ACLs and thinning them
for ii in range(nchains):
temp_acls = numpy.zeros(ntemps, dtype=int)
for tk in range(ntemps):
# get the second half of the chains
sx = samples['x'][tk, ii, burnin_iter:]
sy = samples['y'][tk, ii, burnin_iter:]
# compute the acl for each parameter
aclx = calculate_acl(sx)
acly = calculate_acl(sy)
acl = max(aclx, acly)
temp_acls[tk] = acl
# take the max over the temps
acl = max(temp_acls)
acls[ii] = acl
# thin the arrays
thinned_arrays = {'x': [], 'y': []}
for tk in range(ntemps):
txarray = []
tyarray = []
for ii in range(nchains):
sx = samples['x'][tk, ii, burnin_iter:]
sy = samples['y'][tk, ii, burnin_iter:]
# we'll thin the arrays starting from the
# end to get the lastest results
sx = sx[::-1][::acls[ii]][::-1]
sy = sy[::-1][::acls[ii]][::-1]
txarray.append(sx)
tyarray.append(sy)
thinned_arrays['x'].append(txarray)
thinned_arrays['y'].append(tyarray)
# the ACL of each chain:
print(acls)
[ 7 7 7 7 14 16 13 8 7 10 6 9]
# Pool the thinned chains of the coldest temperature (index 0) into one flat
# array of posterior samples per parameter.
posterior = {p: numpy.concatenate(thinned_arrays[p][0]) for p in ('x', 'y')}
print("Number of independent samples:", posterior['x'].size)
Number of independent samples: 738
Compare to the Normal proposal used in the test_ptsampler notebook: the ACLs are ~\(5\times\) smaller, yielding ~\(5\times\) more posterior samples for the same number of iterations.
# overlay step histograms of each parameter's posterior samples
fig, ax = pyplot.subplots()
for name in ('x', 'y'):
    ax.hist(posterior[name], bins=10, histtype='step', label=name)
ax.legend()
fig.show()
<IPython.core.display.Javascript object>
Let’s check the mean and variance of our estimated posterior. These should be \(\bar{x} \approx 2, \sigma^2_{x} \approx 1\) and \(\bar{y} \approx 5, \sigma^2_{y} \approx 2\):
# sample mean and variance of each parameter's posterior samples
for param, s in posterior.items():
    print(param, 'mean: {}'.format(s.mean()), 'var: {}'.format(s.var()))
x mean: 2.0703161571151942 var: 1.023541999736801
y mean: 5.008670434569227 var: 1.9882423448316633