from glob import glob
from multiprocessing import Pool
from os import mkdir, remove, stat
from os.path import basename, isdir, isfile
from shutil import copyfile
from corner import corner
from emcee import EnsembleSampler
from emcee.autocorr import integrated_time
from emcee.backends import HDFBackend
from emcee.moves import StretchMove
from matplotlib import rc, rcParams
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.pyplot import subplots
from matplotlib.transforms import Bbox
from numpy import arange, concatenate, empty, exp, genfromtxt, linspace, log, mean, ndarray, sort, where
from scipy.optimize import Bounds, differential_evolution
from metallicity_mcmc import mcmc_sampling, plot_chains
# -- base font size applied to every matplotlib figure drawn by this script
rcParams['font.size'] = 13

# -- output directory (created on the first run)
# NOTE: this name shadows the built-in dir(); it is kept because other
# top-level statements in this file build their paths from it.
dir = 'MCMC_sampling/'
if not isdir(dir):
    print('### Create a directory ... : ' + dir)
    mkdir(dir)

# -- keyword arguments handed to ln_target by the MCMC sampler
#    (glopt=False selects the log-target itself rather than its negative)
mcmc_pass_kwargs = dict(glopt=False)

# -- results from the global optimization (finding the mode of the target distribution)
glopt_file = dir + 'Rosenbrock.glopt'
def ln_target(param, glopt=True):
    """Evaluate the scaled Rosenbrock target at ``param``.

    The log-target is ``-(100*(p1 - p0**2)**2 + (1 - p0)**2) / 20`` with
    ``p0 = param[0]`` and ``p1 = param[1]``.

    Parameters
    ----------
    param : sequence
        Point to evaluate; only the first two entries are read.
    glopt : bool, optional
        When true (the default) the sign is flipped so that a minimiser can
        be used to locate the mode; when false the log-target is returned.

    Returns
    -------
    The log-target, or its negative when ``glopt`` is true.
    """
    p0, p1 = param[0], param[1]
    ln_prob = -(100 * (p1 - p0**2)**2 + (1 - p0)**2) / 20.
    if glopt:
        return -ln_prob
    return ln_prob
# The tolerance is deliberately loosened to 0.9 so that the results scatter around the global minimum (1, 1) rather than falling exactly on it. For your own task you may need a tighter tolerance, e.g. 0.01.
# The results are used as the initial walker positions for the MCMC sampling [see S1) below].
def find_bestfitpar(output_num, out_glopt):
    """Run the global optimisation repeatedly and save each best-fit point.

    ``differential_evolution`` minimises ``ln_target`` (called with its
    default ``glopt=True``) ``output_num`` times inside the box
    ``-10 <= p0 <= 10``, ``-5 <= p1 <= 50``.  Every successful solution is
    printed and written to ``out_glopt`` as one ``' |'``-separated row.

    Parameters
    ----------
    output_num : int
        Number of optimisation runs to perform.
    out_glopt : str
        Path of the text file to create or overwrite.
    """
    # a reasonable choice for NP is between 5*D and 10*D (Storn_1997_J.GlobalOptim._11_341)
    print('### number of output is {} ...'.format(output_num))
    print('### Save... ' + out_glopt)
    bounds = Bounds([-10, -5], [10, 50])
    n_saved = 0
    # The context manager closes the file even if an optimisation run raises.
    with open(out_glopt, 'w') as f_glopt:
        for _ in range(output_num):
            # workers != 1 forces updating='deferred'; requesting it explicitly
            # gives the same behaviour without the per-call UserWarning.
            glopt = differential_evolution(ln_target, bounds, workers=-1,
                                           updating='deferred', tol=0.9)
            if not glopt.success:
                continue
            prtstr = ' |'.join(['{:>15.8e}'.format(val) for val in glopt.x])
            print(prtstr)
            f_glopt.write(prtstr + '\n')
            n_saved += 1
    if n_saved < output_num:
        # Fewer rows than requested means fewer initial positions downstream.
        print('### Warning: only {} of {} runs succeeded.'.format(n_saved, output_num))
    print('### Done...')
# One optimisation result is produced per MCMC walker, so this number is
# also the number of walkers used by the sampling below.
output_num = 32
find_bestfitpar(output_num=output_num, out_glopt=glopt_file)
def plot_bestfitpar(glopt_file):
    """Draw histograms of the two best-fit parameter columns of a .glopt file.

    The file is read as two ``'|'``-delimited columns.  One histogram per
    column is drawn side by side and the figure is saved next to the input
    with the ``.glopt`` suffix replaced by ``.bestfitpar.png``.

    Parameters
    ----------
    glopt_file : str
        Path of the ``.glopt`` text file to read.
    """
    col0 = genfromtxt(glopt_file, delimiter='|', names=['metal', 'E_BV'])
    fig, ax = subplots(1, 2, figsize=(10, 5))
    # Raw strings keep the mathtext thin-space command intact without relying
    # on Python passing an unrecognised escape sequence through unchanged.
    ax[0].hist(col0['metal'])
    ax[0].set_xlabel(r'12 + log$\,$(O/H)')
    ax[1].hist(col0['E_BV'])
    ax[1].set_xlabel(r'E$\,$(B-V)')
    outfile = glopt_file.replace('.glopt', '.bestfitpar.png')
    print('### Save... ' + outfile)
    suptitle = basename(glopt_file.replace('.glopt', ''))
    fig.suptitle(suptitle)
    # leave the top 5 % of the canvas free for the suptitle
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(outfile)
# -- plot the best-fit parameter histograms of every .glopt file in the output directory
glopt_files = sorted(glob(dir + '*.glopt'))
# A dedicated loop variable is used so that the module-level ``glopt_file``
# (from which the backend file names are derived later) is not overwritten
# by whichever file happens to sort last.
for glopt_path in glopt_files:
    print('### Read... {}'.format(glopt_path))
    plot_bestfitpar(glopt_path)
def plot_chains(backend_file, truths=None, set_ymm=None):
    """Plot the walker traces stored in an emcee HDF5 backend file.

    One panel is drawn per parameter plus a final panel for the stored
    log-probability.  The iteration number runs along the shared y axis.
    The figure is saved next to the backend file with the ``.h5`` suffix
    replaced by ``.plot_chains.png``.

    Parameters
    ----------
    backend_file : str
        Path of the ``.h5`` backend file.
    truths : sequence, optional
        Reference value per parameter, drawn as a red vertical line.
    set_ymm : sequence, optional
        Explicit y-axis (iteration) limits passed to ``set_ylim``.
    """
    # -- extract chains from backend
    reader = HDFBackend(backend_file)
    chain: ndarray = reader.get_chain()
    iteration, nwalkers, ndim = chain.shape
    print('### (iter, nwalkers, ndim) = {}'.format(chain.shape))
    chain_log_prob: ndarray = reader.get_log_prob()
    # -- chain plots
    outfile0 = backend_file.replace('.h5', '.plot_chains.png')
    title = basename(backend_file.replace('.h5', ''))
    # The number of panels and labels follows ndim rather than being fixed at 3.
    fig, axes = subplots(1, ndim + 1, figsize=(7, 7), sharey=True)
    fig.suptitle(title)
    labels = ['p{}'.format(k) for k in range(ndim)] + ['ln(posterior)']
    steps = arange(iteration)
    # Explicit None/length checks: plain truthiness raises on NumPy arrays.
    show_truths = truths is not None and len(truths) > 0
    for i in range(ndim + 1):
        ax: Axes = axes[i]
        if i == ndim:
            ax.set_xlabel(labels[i])
            # marks the largest log-probability among the walkers at the first stored step
            ax.axvline(x=max(chain_log_prob[0, :]))
            ax.plot(chain_log_prob, steps, "k", alpha=0.1)
        else:
            ax.plot(chain[:, :, i], steps, "k", alpha=0.1)
            if show_truths:
                ax.axvline(truths[i], color='red')
            ax.set_xlabel(labels[i])
        ax.set_ylim(0, len(chain))
    axes[0].set_ylabel("iteration")
    if set_ymm is not None and len(set_ymm) > 0:
        axes[0].set_ylim(set_ymm)
    print('### Save... ' + outfile0)
    # leave the top 4 % of the canvas free for the suptitle
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(outfile0)
def mcmc_sampling(backend_file, backend_reset, glopt_file, iteration, mcmc_pass_kwargs, random_state=None, stretch_a=2.):
    """Run (or continue) an emcee ensemble sampling of ``ln_target``.

    Parameters
    ----------
    backend_file : str
        Path of the HDF5 backend file the chain is written to.
    backend_reset : bool
        If true, start a new chain: initial positions are read from
        ``glopt_file`` and the backend is reset.  If false, continue from the
        last sample stored in the backend.
    glopt_file : str
        ``'|'``-delimited text file with one initial position per row; only
        read when ``backend_reset`` is true.
    iteration : int
        Number of steps to run.
    mcmc_pass_kwargs : dict
        Keyword arguments forwarded to ``ln_target`` by the sampler.
    random_state : optional
        Random state to resume from when continuing.  ``None`` keeps the
        state stored with the last sample in the backend.
    stretch_a : float, optional
        Scale parameter ``a`` of the stretch move.

    Returns
    -------
    The sampler's random state after the run.

    Side effects
    ------------
    Writes the final walker coordinates, one ``' |'``-separated row per
    walker, to ``backend_file`` with ``.h5`` replaced by ``.last_coords.dat``.
    """
    # -- result file
    backend = HDFBackend(backend_file)
    print('### Save backend file... : {}'.format(backend_file))
    # -- the initial values
    if backend_reset:
        init_values = genfromtxt(glopt_file, delimiter='|', dtype='f8')
        nwalkers, ndim = init_values.shape
        backend.reset(nwalkers, ndim)
    else:
        init_values = backend.get_last_sample()
        # Only override the stored random state when one was supplied, so
        # that the default of None does not discard it.
        if random_state is not None:
            init_values.random_state = random_state
        nwalkers, ndim = backend.shape
    # -- multiprocessing
    with Pool() as pool:
        sampler = EnsembleSampler(nwalkers, ndim, ln_target, moves=StretchMove(a=stretch_a),
                                  kwargs=mcmc_pass_kwargs, pool=pool, backend=backend)
        sampler.run_mcmc(init_values, iteration, progress=True)
        print('### Total iteration: {}'.format(sampler.iteration))
        random_state = sampler.random_state
    # -- save last coords
    outfile = backend_file.replace('.h5', '.last_coords.dat')
    print('### Save ... ' + outfile)
    co_arr = backend.get_last_sample().coords
    # The context manager closes the file even if formatting or writing fails.
    # Scientific notation with 9 significant digits matches the .glopt row
    # format and avoids the precision loss of a fixed 6-decimal format.
    with open(outfile, 'w') as f_out:
        for walker_coords in co_arr:
            prtstr = ' |'.join(['{:>15.8e}'.format(val) for val in walker_coords])
            f_out.write(prtstr + '\n')
    return random_state
def sample_new():
    """Start a fresh sampling run, back it up and plot its chains.

    Uses the module-level ``backend_file``, ``backend_backup_file``,
    ``glopt_file``, ``iteration`` and ``mcmc_pass_kwargs``.

    Returns
    -------
    The random state returned by ``mcmc_sampling``.
    """
    # True -> reset the backend and start from the positions in glopt_file
    final_state = mcmc_sampling(backend_file, True, glopt_file, iteration, mcmc_pass_kwargs)
    print('### Copy the backend file to the backup : {}'.format(backend_backup_file))
    copyfile(backend_file, backend_backup_file)
    plot_chains(backend_file)
    return final_state
# -- number of steps of the first sampling run
iteration = 2000

# -- file setting: backend and backup names are derived from the .glopt name
backend_file = glopt_file.replace('.glopt', '.h5')
backend_backup_file = backend_file.replace('.h5', '.backup.h5')

# -- backend existence check
if not isfile(backend_file):
    # nothing on disk yet: sample from scratch
    random_state_last = sample_new()
else:
    print(backend_file)
    file_size_MB = stat(backend_file).st_size >> 20  # bytes -> whole MiB
    chk = input('file size = {} MB.'.format(file_size_MB) + ' ### Delete it? (y or else) :')
    if chk != 'y':
        # keep the existing chain and pick up its stored random state
        print('### backend_file is unchanged.')
        backend = HDFBackend(backend_file)
        random_state_last = backend.get_last_sample().random_state
    else:
        print('### Delete the existing backend file and create a new one...')
        remove(backend_file)
        random_state_last = sample_new()
# Once the convergence is good enough (e.g., each $\tau_{int}$ curve crosses the "ESS = 2000" line), go to S4) and check whether there are any other local extrema.
def plot_corner(backend_file, burn_in, truths=None, pdf=False, chain_end=None, quantiles=None):
    """Draw a corner plot of the chain with an inset of the autocorrelation time.

    The chain is read from ``backend_file`` and the first ``burn_in`` steps
    are dropped.  A corner plot of the remaining samples is drawn, and an
    extra axes in the upper right shows how the integrated autocorrelation
    time of the walker-averaged chain evolves with the number of steps,
    together with reference lines for an effective sample size of 1000 and
    2000.  The figure is saved next to the backend file.

    Parameters
    ----------
    backend_file : str
        Path of the ``.h5`` backend file.
    burn_in : int
        Number of leading steps to discard.
    truths : sequence, optional
        Reference values forwarded to ``corner``.
    pdf : bool, optional
        Save as ``.pdf`` instead of ``.png``.
    chain_end : int, optional
        If given (and non-zero), keep only this many steps after the burn-in.
    quantiles : sequence, optional
        Quantiles forwarded to ``corner``; defaults to
        ``[0.025, 0.16, 0.5, 0.84, 0.975]``.

    Raises
    ------
    ValueError
        If no samples are left after the burn-in cut.
    """
    # -- extract chains from backend
    reader = HDFBackend(backend_file)
    chain: ndarray = reader.get_chain()
    if chain_end:
        chain = chain[burn_in:burn_in + chain_end, :, :]
    else:
        chain = chain[burn_in:, :, :]
    iteration, nwalkers, ndim = chain.shape
    print('### (iter, nwalkers, ndim) = {}'.format(chain.shape))
    if iteration == 0:
        raise ValueError('no samples left after discarding burn_in = {} steps'.format(burn_in))
    # -- integrated autocorrelation time calculation
    chk_num = 21
    chain_mean = mean(chain, axis=1)  # average over walkers
    iat_arr = empty((chk_num, ndim))
    # log-spaced checkpoints from 100 steps (or fewer for a short chain) to the full length
    first_chk = min(100, iteration)
    chks = exp(linspace(log(first_chk), log(iteration), chk_num)).astype(int)
    for chk_ind, chk in enumerate(chks):
        for dim in range(ndim):
            # slice assignment accepts both a scalar and a length-1 array result
            iat_arr[chk_ind:chk_ind + 1, dim] = integrated_time(chain_mean[:chk, dim], tol=0)
    #---- plotting
    ## -- corner plots
    labels = ['p{}'.format(k) for k in range(ndim)]  # one label per sampled dimension
    burnin_tag = '_BurnIn{:04d}'.format(burn_in)
    outfile1 = backend_file.replace('.h5', burnin_tag + '.plot_corner.png')
    if pdf:
        outfile1 = outfile1.replace('.png', '.pdf')
    chain_rsh = chain.reshape([-1, ndim])
    title = basename(backend_file.replace('.h5', burnin_tag))
    ftsize = 15
    # NOTE: rc() changes matplotlib's global settings, not only this figure
    rc('axes', labelsize=ftsize)
    rc('xtick', labelsize=ftsize * 0.8)
    rc('ytick', labelsize=ftsize * 0.8)
    if not quantiles:
        quantiles = [0.025, 0.16, 0.5, 0.84, 0.975]
    fig: Figure = corner(chain_rsh, bins=40, truth_color='red', truths=truths, labels=labels,
                         quantiles=quantiles, show_titles=True)
    fig.suptitle(title, fontsize=ftsize * 1.2, y=1.03)
    ## -- tau_iat evolution as an overplot
    enlarge = 1.8
    f_w, f_h = fig.get_size_inches()
    fig.set_size_inches(f_w * enlarge, f_h * enlarge)
    ax0: Axes = fig.add_subplot(111)
    for par in range(ndim):
        ax0.plot(chks, iat_arr[:, par], '-o', markersize=4, label=labels[par], alpha=0.5)
    # extend the reference lines beyond the first and last checkpoints
    chks_ext = concatenate([[chks[0] / 2.], chks])
    chks_ext = concatenate([chks_ext, [chks[-1] * 2.]])
    # reference lines tau = steps * nwalkers / ESS
    ess_chk = 1000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, '-.k', label='ESS = 1000')
    ess_chk = 2000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, '--k', label='ESS = 2000')
    ax0.semilogx()
    ax0.semilogy()
    ax0.tick_params(which='both', top=True, right=True,
                    labelbottom=False, labelleft=False, labeltop=True, labelright=True)
    ax0.xaxis.set_label_position('top')
    ax0.yaxis.set_label_position('right')
    ax0.set_xlim(chks[0] * 0.8, chks[-1] * 1.2)
    ax0.set_ylim(iat_arr.min() * 0.5, iat_arr.max() * 2)
    ax0.set_xlabel('iteration')
    ax0.set_ylabel(r'$\tau_{int}$')
    ax0.legend(fontsize=ftsize * 0.8, loc='upper right', bbox_to_anchor=(1.03, 0))
    # place the inset in the upper-right part of the figure
    pos2 = [0.57, 0.72, 0.38, 0.20]
    ax0.set_position(pos2)
    # -- save
    print('### Save... ' + outfile1)
    # slightly enlarged and shifted bounding box for the saved figure
    f_enlarge = 1.05
    shift = -0.1
    bb = fig.bbox_inches.bounds
    bb_new = (bb[0] + shift, bb[1] + shift, bb[2] * f_enlarge, bb[3] * f_enlarge)
    fig.savefig(outfile1, bbox_inches=Bbox.from_bounds(*bb_new))
# -- corner plots of the current chain and of its backup, with the same burn-in cut
burn_in = 1500
for _h5_file in (backend_file, backend_backup_file):
    plot_corner(_h5_file, burn_in)
# -- number of additional steps for the continuation run
iteration = 20000

# -- choose one
# abandon = True   # restore the backup over the backend before continuing
abandon = False    # keep the backend and refresh the backup from it
backend_reset = False  # Do not change this value

if not abandon:
    copyfile(backend_file, backend_backup_file)
    print('### Copy the backend to the backup...')
else:
    copyfile(backend_backup_file, backend_file)
    print('### Copy the backup to the backend...')

# -- continue sampling from the last stored sample
random_state_last = mcmc_sampling(
    backend_file, backend_reset, glopt_file, iteration, mcmc_pass_kwargs,
    random_state=random_state_last,
)
plot_chains(backend_file)
# -- extended search: a short run with a much wider stretch move
iteration = 1000
stretch_width = 2. * 10

# -- sampler setting: start from the coordinates saved in the last-coords file
lastco_file = backend_file.replace('.h5', '.last_coords.dat')
init_values = genfromtxt(lastco_file, delimiter='|', dtype='f8')
nwalkers, ndim = init_values.shape

# -- backend setting: a separate, freshly reset backend file for this run
backend_extsrch_file = backend_file.replace('.h5', '.extsrch.h5')
backend_extsrch = HDFBackend(backend_extsrch_file)
backend_extsrch.reset(nwalkers, ndim)
print('### Save backend file... : {}'.format(backend_extsrch_file))

# -- multiprocessing
_wide_move = StretchMove(a=stretch_width)
with Pool() as pool:
    sampler = EnsembleSampler(
        nwalkers, ndim, ln_target,
        moves=_wide_move,
        kwargs=mcmc_pass_kwargs,
        pool=pool,
        backend=backend_extsrch,
    )
    sampler.run_mcmc(init_values, iteration, progress=True)
    print('### Total iteration: {}'.format(sampler.iteration))

plot_chains(backend_extsrch_file)