You can run this notebook and obtain the MCMC results once the necessary packages are installed and the required variables are correctly defined.
The whole procedure for parameter recovery tests is as follows:

1. Set up and create the mock data (see B0, B1, B2).
2. Find the maximum a posteriori (MAP) parameters with the global optimizer and inspect them (see B3, B4, B5).
3. Set up the files for the MCMC sampling. (If you want to supply your own initial walker positions, write them into a `.glopt` file and rename it appropriately.) (see S1)
4. Run the MCMC sampling and check the convergence of the chains (see S2, S3).
5. Draw the corner plot and check whether the input parameters are recovered (see S4).
import os
from glob import glob
from multiprocessing.dummy import Pool
from os import mkdir, remove, stat
from os.path import basename, isdir, isfile
from shutil import copyfile
import numpy as np
from astropy.io import ascii
from corner import corner
from emcee import EnsembleSampler
from emcee.autocorr import integrated_time
from emcee.backends import HDFBackend
from emcee.moves import StretchMove
from izi import fb_extinction as ext
from izi.izi_MCMC_mod import grid, izi_MCMC
from matplotlib import rc
from matplotlib import rcParams
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.pyplot import subplots
from matplotlib.transforms import Bbox
from numpy import (arange, argmax, concatenate, empty, exp, genfromtxt, histogram,
                   linspace, log, mean, ndarray, sort, unique, where, zeros)
from scipy import special
from scipy.optimize import Bounds, differential_evolution
# -- output directory
dir_out = 'mock_data/'
if not isdir(dir_out):
    print('### Create a directory ... : ' + dir_out)
    mkdir(dir_out)
# -- mock data input
setup_file = 'lst.mock_data_setup_example'
mock_name = 'mock_temp1'
ebv_min = 0.0
ebv_max = 1.0
rcParams['font.size'] = 13
os.environ["OMP_NUM_THREADS"] = "1"  # avoid OpenMP thread oversubscription during parallel sampling
def read_mock_data_setup(file_setup, setup_name):
    col0 = genfromtxt(file_setup, delimiter='|', skip_header=1,
                      dtype=['|U30', '|U40', 'f', 'f', 'f', 'f', 'f', '|U1000'],
                      names=['name', 'grid_file', 'snr', 'dex_epsilon', 'logz', 'logq', 'ebv', 'lines'],
                      autostrip=True)
    pick = col0['name'] == setup_name
    dummy, grid_file, snr, dex_epsilon, logz, logq, ebv, lines_tmp = col0[pick][0]
    lines = lines_tmp.split(', ')
    return grid_file, snr, dex_epsilon, logz, logq, ebv, lines
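Given the `genfromtxt` call above, a row of the `|`-delimited setup file should look like the following sketch (the first header row is skipped; the grid file name and all values are illustrative placeholders, not necessarily ones shipped with IZI; the line IDs must match those in the chosen grid and are separated by a comma plus a space):

name       | grid_file        | snr  | dex_epsilon | logz | logq | ebv | lines
mock_temp1 | [GRID_FILE].fits | 30.0 | -1.0        | 8.8  | 7.1  | 0.1 | hbeta, oiii5007, halpha, nii6584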
def read_grid_setup(grid_file, line_dir, templ_dir):
    grid_read = grid(grid_file, templ_dir=templ_dir)
    intergrid = grid_read.grid0
    nlines0 = len(intergrid['ID'][0])
    nz = len(unique(intergrid['LOGZ']))
    nq = len(unique(intergrid['LOGQ']))
    logohsun = grid_read.logohsun
    line_params = ascii.read(line_dir + '/line_names.txt')
    line_wav = zeros(nlines0)
    for ii in range(nlines0):
        line_name = intergrid['ID'][0][ii]
        ww = (line_params['line_name'] == line_name)
        assert ww.sum() == 1, 'ERROR: ===== Line ID ' + \
            intergrid['ID'][0][ii] + ' not included in wavelength list ====='
        line_wav[ii] = line_params['wav'][ww]
    grid_read.grid0['WAV'] = [line_wav] * nq * nz
    intergrid = grid_read.grid0
    zrange = [np.min(intergrid['LOGZ']), np.max(intergrid['LOGZ'])]
    qrange = [np.min(intergrid['LOGQ']), np.max(intergrid['LOGQ'])]
    return intergrid, logohsun, nlines0, zrange, qrange
# -- mock data setup
set_grid, set_snr, set_dex_epsilon, set_logz, set_logq, set_ebv, set_lines = read_mock_data_setup(setup_file, mock_name)
# -- grid setup
idnorm = 'hbeta'
templ_dir = '[PATH_TO_IZI]/python_izi/grids/'
line_dir = '[PATH_TO_IZI]/python_izi/izi/'
intergrid, logohsun, nlines0, zrange, qrange = read_grid_setup(set_grid, line_dir, templ_dir)
def mk_mockdata(grid_file, snr, logz, logq, ebv, lines, setup_name):
    # -- output file setup
    outfile = dir_out + setup_name + '.dat'
    print('### Save... ' + outfile)
    f_out = open(outfile, 'w')
    print(f'S/N = {snr:5.0f}')
    f_out.write(f'S/N = {snr:5.0f}\n')
    print(f'logz = {logz:5.1f}')
    f_out.write(f'logz = {logz:5.1f}\n')
    print(f'logq = {logq:5.1f}')
    f_out.write(f'logq = {logq:5.1f}\n')
    print(f'ebv = {ebv:5.1f}')
    f_out.write(f'ebv = {ebv:5.1f}\n')
    flux_redd = izi_MCMC.flux_grid_ext(intergrid, logz - logohsun, logq, ebv, idnorm)
    error = flux_redd / snr
    # CREATE DATA STRUCTURE CONTAINING LINE FLUXES AND ESTIMATED PARAMETERS
    dd = {'id': intergrid['ID'][0],  # line id
          'flux': np.zeros(nlines0) + np.nan,  # line flux
          'error': np.zeros(nlines0) + np.nan,
          'epsilon': np.zeros(nlines0) + 10 ** set_dex_epsilon}
    # FILL STRUCTURE WITH LINE FLUXES
    nlines_in = len(lines)
    for i in range(nlines_in):
        auxind = (dd['id'] == lines[i])
        nmatch = auxind.sum()
        assert nmatch == 1, 'ERROR: ===== Line ID ' + lines[i] + ' not recognized ====='
        dd['flux'][auxind] = flux_redd[auxind]
        dd['error'][auxind] = error[auxind]
    in_idnorm = where(dd['id'] == 'halpha')
    if set_dex_epsilon == -np.inf:
        dd['epsilon'][in_idnorm] = 10 ** set_dex_epsilon
    else:
        dd['epsilon'][in_idnorm] = 10 ** 0.01
    print('{0:10s} | {1:>12s} | {2:>12s} | {3:>12s}'.format('line', 'flux', 'error', 'epsilon'))
    f_out.write('{0:10s} | {1:>12s} | {2:>12s} | {3:>12s}\n'.format('line', 'flux', 'error', 'epsilon'))
    for i in range(nlines0):
        print('{0:10s} | {1:12.3e} | {2:12.3e} | {3:12.3e}'.format(dd['id'][i], dd['flux'][i], dd['error'][i],
                                                                   dd['epsilon'][i]))
        f_out.write('{0:10s} | {1:12.3e} | {2:12.3e} | {3:12.3e}\n'.format(dd['id'][i], dd['flux'][i], dd['error'][i],
                                                                           dd['epsilon'][i]))
    f_out.close()
mk_mockdata(set_grid, set_snr, set_logz, set_logq, set_ebv, set_lines, mock_name)
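If everything runs, the generated `.dat` file should look like the following sketch (all values are illustrative; every line ID in the grid is listed, and lines absent from the setup keep NaN flux and error):

S/N =    30
logz =   8.8
logq =   7.1
ebv =    0.1
line       |         flux |        error |      epsilon
hbeta      |    1.000e+00 |    3.333e-02 |    1.000e-01
oiii5007   |    4.213e-01 |    1.404e-02 |    1.000e-01
halpha     |    2.981e+00 |    9.937e-02 |    1.023e+00
nii6584    |    3.547e-01 |    1.182e-02 |    1.000e-01
oii3726    |          nan |          nan |    1.000e-01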
def read_mock_data(in0):
    f_in = open(in0, 'r')
    for i in range(4):
        line = f_in.readline()
        if 'S/N' in line:
            snr = float(line.split()[-1])
        elif 'logz' in line:
            logz = float(line.split()[-1])
        elif 'logq' in line:
            logq = float(line.split()[-1])
        elif 'ebv' in line:
            ebv = float(line.split()[-1])
    f_in.close()
    id = genfromtxt(in0, dtype=['U20'], delimiter='|', skip_header=5, names=['id'], usecols=(0), autostrip=True)
    data = genfromtxt(in0, dtype=['f', 'f', 'f'], delimiter='|', skip_header=5, names=['flux', 'error', 'epsilon'],
                      usecols=(1, 2, 3), autostrip=True)
    # INDEX LINES WITH MEASUREMENTS
    good = (np.isfinite(data['error']))
    measured = (np.isfinite(data['flux']))
    upperlim = ((np.isfinite(data['error'])) & (data['flux'] == -666))
    flag0 = np.zeros(len(id))
    flag0[measured] = 1  # measured flux
    flag0[upperlim] = 2  # upper limit on flux
    # This array has length ngood, the number of lines with given error
    # measurements. If an error is given but no flux, the line is treated
    # as an upper limit.
    flag = flag0[good]
    # NORMALIZE LINE FLUXES TO HBETA OR, IF ABSENT, TO THE BRIGHTEST LINE
    idnorm = 'hbeta'
    in_idnorm = (id['id'] == idnorm)
    if (np.isnan(data['flux'][in_idnorm]) | (data['flux'][in_idnorm] == -666)).any():
        a = id[measured]
        idnorm = a[np.argmax(data['flux'][measured])][0]
        in_idnorm = (id['id'] == idnorm)
    norm = data['flux'][in_idnorm]
    # NORMALISE INPUT FLUXES
    data['flux'][measured] = data['flux'][measured] / norm[0]
    data['error'][good] = data['error'][good] / norm[0]
    data['flux'][upperlim] = -666
    print(f'line for normalization = {idnorm}')
    return snr, logz, logq, ebv, id, data, good, flag, idnorm
def flux_grid_ext(grid_, zz, qq, ebv, idnorm):
    flux_at = grid.value_grid_at_no_interp(grid_, zz, qq)
    red_vec = ext.reddening_vector_calzetti(grid_['WAV'][0], ebv)
    flux_redd = flux_at / red_vec
    # NORMALISE GRID
    norm = flux_redd[grid_['ID'][0] == idnorm]
    flux_redd = flux_redd / norm
    return flux_redd
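In other words, the model flux of line $j$ at grid point $(z, q)$ is attenuated with the Calzetti reddening vector and renormalized to the reference line:

$$f_j^{\rm mod}(z, q, E(B-V)) = \frac{f_j^{\rm grid}(z, q) \, / \, R_j(E(B-V))}{f_{\rm norm}^{\rm grid}(z, q) \, / \, R_{\rm norm}(E(B-V))},$$

where $R_j$ is the value returned by `reddening_vector_calzetti` at the wavelength of line $j$.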
# DEFINE STATISTICAL FUNCTIONS
# LIKELIHOOD of one specific point of parameter space (theta)
def lnlike(theta, dd, good, flag, grid_, idnorm, logzprior=None, logqprior=None):
    zz, qq, ebv = theta
    fff = np.array(flux_grid_ext(grid_, zz, qq, ebv, idnorm))
    ngood = good.sum()
    like = 1.
    if logqprior is not None:
        gauss_q = 1. / (logqprior[1] * np.sqrt(2. * np.pi)) * np.exp(- ((qq - logqprior[0]) / logqprior[1]) ** 2 / 2.)
    else:
        gauss_q = 1.
    if logzprior is not None:
        gauss_z = 1. / (logzprior[1] * np.sqrt(2. * np.pi)) * np.exp(- ((zz - logzprior[0]) / logzprior[1]) ** 2 / 2.)
    else:
        gauss_z = 1.
    for j in range(ngood):
        if flag[j] == 1:  # measured flux: Gaussian likelihood
            e2 = dd['error'][good][j] ** 2.0 + (dd['epsilon'][good][j] * fff[good][j]) ** 2.0
            fdf2 = (dd['flux'][good][j] - fff[good][j]) ** 2.0
            like = like * 1 / np.sqrt(2 * np.pi) * np.exp(-0.5 * fdf2 / e2) / np.sqrt(e2) * gauss_q * gauss_z
        if flag[j] == 2:  # upper limit: integrated (erf) likelihood
            edf = (dd['error'][good][j] - fff[good][j])
            e2 = dd['error'][good][j] ** 2.0 + (dd['epsilon'][good][j] * fff[good][j]) ** 2.0
            like = like * 0.5 * (1 + special.erf(edf / np.sqrt(e2 * 2))) * gauss_q * gauss_z
    return np.log(like)
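Written out, with $f_j$ the measured flux, $\sigma_j$ its error, $\epsilon_j$ the systematic term, and $f_j^{\rm mod}$ the model flux from `flux_grid_ext`, the likelihood implemented above is

$$\mathcal{L} = \prod_{\rm measured} \frac{1}{\sqrt{2 \pi e_j^2}} \exp\left[-\frac{(f_j - f_j^{\rm mod})^2}{2 e_j^2}\right] \prod_{\rm upper\ limits} \frac{1}{2}\left[1 + {\rm erf}\left(\frac{\sigma_j - f_j^{\rm mod}}{\sqrt{2 e_j^2}}\right)\right], \qquad e_j^2 = \sigma_j^2 + (\epsilon_j f_j^{\rm mod})^2,$$

with the optional Gaussian priors on $z$ and $q$ entering as extra multiplicative factors (applied once per line in the loop above).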
# PRIOR
def lnprior(theta, zrange, qrange, max_ebv):
    zz, qq, ebv = theta
    if zrange[0] < zz < zrange[1] and qrange[0] < qq < qrange[1] and \
            0. < ebv < max_ebv:
        return 0.0
    else:
        return -np.inf
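That is, the prior is a flat box: $\ln p(\theta) = 0$ inside the grid ranges of $z$ and $q$ and for $0 < E(B-V) <$ `max_ebv`, and $-\infty$ outside, so within the box the posterior reduces to the likelihood.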
# POSTERIOR
def lnprob(theta, dd, good, flag, grid_, zrange, qrange, max_ebv, idnorm, glopt=True, chk_prior=False):
    # With glopt=True the sign is flipped so that differential_evolution below
    # can minimize the negative log-posterior; chk_prior is accepted for
    # compatibility with the pass-through kwargs but is unused here.
    lp = lnprior(theta, zrange, qrange, max_ebv)
    tot = lp + lnlike(theta, dd, good, flag, grid_, idnorm)
    if glopt:
        tot = -tot
    if not np.isfinite(tot):
        return -np.inf if not glopt else np.inf
    return tot
def find_map_par(ln_target, data, good, flag, grid_, zrange, qrange, max_ebv, idnorm, output_num, out_glopt):
    # A reasonable choice for the population size NP is between 5*D and 10*D
    # (Storn & Price 1997, J. Global Optim., 11, 341); scipy's default
    # popsize of 15 is used here.
    print('### number of outputs is {} ...'.format(output_num))
    print('### Save... ' + out_glopt)
    f_glopt = open(out_glopt, 'w')
    lb = [zrange[0], qrange[0], ebv_min]
    ub = [zrange[1], qrange[1], ebv_max]
    for i in range(output_num):
        glopt = differential_evolution(ln_target, Bounds(lb, ub),
                                       args=(data, good, flag, grid_, zrange, qrange, max_ebv, idnorm),
                                       workers=-1, tol=0.5)
        if glopt.success:
            prtstr = ' |'.join(['{:>15.8e}'.format(v) for v in glopt.x])
            print(prtstr)
            f_glopt.write(prtstr + '\n')
    print('### Done...')
    f_glopt.close()
mock_file = dir_out + mock_name + '.dat'
input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
output_num = 32
out_glopt = mock_file.replace('.dat', '.glopt')
find_map_par(lnprob, data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm, output_num, out_glopt)
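As an optional sanity check (a sketch, not part of the original pipeline), the solutions written to the `.glopt` file can be re-scored with the non-negated posterior (`glopt=False`) to identify the best MAP candidate:

# -- a sketch: rank the .glopt solutions by log-posterior
sols = np.atleast_2d(genfromtxt(out_glopt, delimiter='|'))
lnp = [lnprob(s, data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm, glopt=False)
       for s in sols]
print('best MAP candidate (logz - logohsun, logq, ebv):', sols[np.argmax(lnp)])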
def plot_map_par(glopt_file):
    col0 = genfromtxt(glopt_file, delimiter='|', names=['logz', 'logq', 'ebv'])
    fig, ax = subplots(2, 2, figsize=(7, 7))
    ax[0, 0].hist(col0['logz'] + logohsun)
    ax[0, 0].vlines(input_logz, 0, 1, transform=ax[0, 0].get_xaxis_transform(), linestyles='--', colors='red')
    ax[0, 0].set_xlabel(r'12 + log$\,$(O/H)')
    ax[0, 1].hist(col0['logq'])
    ax[0, 1].vlines(input_logq, 0, 1, transform=ax[0, 1].get_xaxis_transform(), linestyles='--', colors='red')
    ax[0, 1].set_xlabel(r'log$\,$(q)')
    ax[1, 0].hist(col0['ebv'])
    ax[1, 0].vlines(input_ebv, 0, 1, transform=ax[1, 0].get_xaxis_transform(), linestyles='--', colors='red')
    ax[1, 0].set_xlabel(r'E$\,$(B-V)')
    outfile = glopt_file.replace('.glopt', '.map_par.png')
    print('### Save... ' + outfile)
    suptitle = basename(glopt_file.replace('.glopt', ''))
    fig.suptitle(suptitle)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(outfile)
mock_file = dir_out + mock_name + '.dat'
out_glopt = mock_file.replace('.dat', '.glopt')
input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
glopt_files = sort(glob(out_glopt))
for glopt_file in glopt_files:
    print('### Read... {}'.format(glopt_file))
    plot_map_par(glopt_file)
`iteration`: the number of iterations for which you run the MCMC sampling. (The number of emcee walkers is set by the number of initial walker positions read from the `.glopt` file.)

def plot_chains(backend_file, logohsun, truths=None, set_ymm=None):
    # -- extract chains from backend
    reader = HDFBackend(backend_file)
    chain: ndarray = reader.get_chain()
    iteration, nwalkers, ndim = chain.shape
    print('### (iter, nwalkers, ndim) = {}'.format(chain.shape))
    chain_log_prob: ndarray = reader.get_log_prob()
    # -- chain plots
    outfile0 = backend_file.replace('.h5', '.plot_chains.png')
    title = basename(backend_file.replace('.h5', ''))
    fig, axes = subplots(1, 4, figsize=(10, 7), sharey=True)
    fig.suptitle(title)
    labels = [r'12 + log$\,$(O/H)', r'log$\,$(q)', r'E$\,$(B-V)', 'ln(posterior)']
    for i in range(ndim + 1):
        ax: Axes = axes[i]
        if i == ndim:
            if 'line3' in backend_file:
                ax.set_xlabel(labels[i])
                ax = axes[i + 1]
                ax.set_xlabel(labels[i + 1])
            else:
                ax.set_xlabel(labels[i])
            ax.axvline(x=max(chain_log_prob[0, :]))
            ax.plot(chain_log_prob, arange(iteration), "k", alpha=0.1)
        else:
            if i == 0:
                ax.plot(chain[:, :, i] + logohsun, arange(iteration), "k", alpha=0.1)
            else:
                ax.plot(chain[:, :, i], arange(iteration), "k", alpha=0.1)
            if truths:
                ax.axvline(truths[i], color='red')
            ax.set_xlabel(labels[i])
        ax.set_ylim(0, len(chain))
    axes[0].set_ylabel("iteration")
    if set_ymm:
        axes[0].set_ylim(set_ymm)
    print('### Save... ' + outfile0)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(outfile0)
def mcmc_sampling(backend_file, backend_reset, glopt_file, iteration,
                  ln_target, data, good, flag, grid_, zrange, qrange, max_ebv, idnorm, mcmc_pass_kwargs,
                  random_state=None, stretch_a=2.):
    # -- result file
    backend = HDFBackend(backend_file)
    print('### Save backend file... : {}'.format(backend_file))
    # -- the initial values
    if backend_reset:
        init_values = genfromtxt(glopt_file, delimiter='|', dtype='f8')
        nwalkers, ndim = init_values.shape
    else:
        init_values = backend.get_last_sample()
        init_values.random_state = random_state
        nwalkers, ndim = backend.shape
    # -- backend reset
    if backend_reset:
        backend.reset(nwalkers, ndim)
    # -- multiprocessing
    with Pool() as pool:
        sampler = EnsembleSampler(nwalkers, ndim, ln_target, moves=StretchMove(a=stretch_a),
                                  args=[data, good, flag, grid_, zrange, qrange, max_ebv, idnorm],
                                  kwargs=mcmc_pass_kwargs, pool=pool,
                                  backend=backend)
        sampler.run_mcmc(init_values, iteration, progress=True)
    print('### Total iteration: {}'.format(sampler.iteration))
    random_state = sampler.random_state
    # -- save last coords
    outfile = backend_file.replace('.h5', '.last_coords.dat')
    print('### Save ... ' + outfile)
    f_out = open(outfile, 'w')
    co_arr = backend.get_last_sample().coords
    row, col = co_arr.shape
    for row_i in range(row):
        prtstr = ' |'.join(['{:.8e}'.format(v) for v in co_arr[row_i, :]])
        f_out.write(prtstr + '\n')
    f_out.close()
    return random_state
def sample_new(mock_input, logohsun):
    backend_reset = True
    input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
    random_state_last = mcmc_sampling(backend_file, backend_reset, glopt_file, iteration,
                                      lnprob, data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm,
                                      mcmc_pass_kwargs)
    print('### Copy the backend file to the backup : {}'.format(backend_backup_file))
    copyfile(backend_file, backend_backup_file)
    plot_chains(backend_file, logohsun, truths=mock_input)
    return random_state_last
# mcmc_pass_kwargs = {'glopt': False, 'chk_prior': True}
mcmc_pass_kwargs = {'glopt': False, 'chk_prior': False}
iteration = 1000
mock_file = dir_out + mock_name + '.dat'
glopt_file = mock_file.replace('.dat', '.glopt')
mock_input = [input_logz, input_logq, input_ebv]
# -- file setting
backend_file = glopt_file.replace('.glopt', '.h5')
if mcmc_pass_kwargs['chk_prior']: backend_file = backend_file.replace('.h5', '_chkprior.h5')
backend_backup_file = backend_file.replace('.h5', '.backup.h5')
# -- backend existence check
if isfile(backend_file):
    print(backend_file)
    file_size_MB = stat(backend_file).st_size >> 20
    chk = input('file size = {} MB.'.format(file_size_MB) + ' ### Delete it? (y or else) :')
    if chk == 'y':
        print('### Delete the existing backend file and create a new one...')
        remove(backend_file)
        random_state_last = sample_new(mock_input, logohsun)
    else:
        print('### backend_file is unchanged.')
        backend = HDFBackend(backend_file)
        random_state_last = backend.get_last_sample().random_state
else:
    random_state_last = sample_new(mock_input, logohsun)
When the convergence is good enough (e.g., each $\tau_{int}$ curve crosses the "ESS = 2000" line), go to S4 and check whether there are any other local extrema.
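For reference, the effective sample size in that plot is $\mathrm{ESS} = N_{\rm walkers} N_{\rm iter} / \tau_{int}$, so a $\tau_{int}$ curve crosses the "ESS = 2000" line once $\tau_{int} \leq N_{\rm walkers} N_{\rm iter} / 2000$.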
`burn_in`: the number of burn-in iterations. (The samples before the burn-in iterations are removed when plotting the corner plot.)

def plot_corner(backend_file, burn_in, logohsun,
                truths=None, pdf=False, chain_end=None, quantiles=None, plotrange=[1, 1, 1, 1]):
    # -- extract chains from backend
    reader = HDFBackend(backend_file)
    chain: ndarray = reader.get_chain()
    if chain_end:
        chain = chain[burn_in:burn_in + chain_end, :, :]
    else:
        chain = chain[burn_in:, :, :]
    # -- shift the logz values
    chain[:, :, 0] = chain[:, :, 0] + logohsun
    iteration, nwalkers, ndim = chain.shape
    print('### (iter, nwalkers, ndim) = {}'.format(chain.shape))
    # -- integrated autocorrelation time calculation
    chk_num = 21
    chain_mean = mean(chain, axis=1)
    iat_arr = empty((chk_num, ndim))
    chks = exp(linspace(log(100), log(iteration), chk_num)).astype(int)
    for chk in chks:
        chk_ind = where(chks == chk)[0]
        for dim in range(ndim):
            iat_arr[chk_ind, dim] = integrated_time(chain_mean[:chk, dim], tol=0)
    # -- corner plots
    labels = [r'12 + log$\,$(O/H)', r'log$\,$(q)', r'E$\,$(B-V)']
    burnin_tag = '_BurnIn{:04d}'.format(burn_in)
    outfile1 = backend_file.replace('.h5', burnin_tag + '.plot_corner.png')
    if pdf:
        outfile1 = outfile1.replace('.png', '.pdf')
    chain_rsh = chain.reshape([-1, ndim])
    title = basename(backend_file.replace('.h5', burnin_tag))
    ftsize = 15
    rc('axes', labelsize=ftsize)
    rc('xtick', labelsize=ftsize * 0.8)
    rc('ytick', labelsize=ftsize * 0.8)
    if not quantiles:
        quantiles = [0.025, 0.16, 0.5, 0.84, 0.975]
    fig: Figure = corner(chain_rsh, bins=40, truth_color='red', truths=truths, labels=labels,
                         quantiles=quantiles,
                         show_titles=True, title_kwargs={'size': ftsize * 0.8}, range=plotrange)
    fig.suptitle(title, fontsize=ftsize * 1.2, y=1.03)
    # -- check for the inverse-gamma prior
    if 'chkprior' in backend_file and 'ONpriorI' in backend_file:
        oiii_n, oiii_bins = histogram(chain_rsh[:, 2], bins=40, range=plotrange[2])
        oiii_in = argmax(oiii_n)
        print('[O III] peak = {:10.3e} {:10.3e}'.format(oiii_bins[oiii_in], oiii_bins[oiii_in + 1]))
        nii_n, nii_bins = histogram(chain_rsh[:, 3], bins=40, range=plotrange[3])
        nii_in = argmax(nii_n)
        print('[N II] peak = {:10.3e} {:10.3e}'.format(nii_bins[nii_in], nii_bins[nii_in + 1]))
    # -- tau_iat evolution as an overplot
    if 'line3' in backend_file:
        enlarge = 1.28
        f_w, f_h = fig.get_size_inches()
        fig.set_size_inches(f_w * enlarge, f_h * enlarge)
    ax0: Axes = fig.add_subplot(111)
    for par in range(ndim):
        ax0.plot(chks, iat_arr[:, par], '-o', markersize=4, label=labels[par], alpha=0.5)
    chks_ext = concatenate([[chks[0] / 2.], chks])
    chks_ext = concatenate([chks_ext, [chks[-1] * 2.]])
    ess_chk = 1000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, '-.k', label='ESS = 1000')
    ess_chk = 2000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, ':k', label='ESS = 2000')
    ax0.semilogx()
    ax0.semilogy()
    ax0.tick_params(which='both', top=True, right=True,
                    labelbottom=False, labelleft=False, labeltop=True, labelright=True)
    ax0.xaxis.set_label_position('top')
    ax0.yaxis.set_label_position('right')
    ax0.set_xlim(chks[0] * 0.8, chks[-1] * 1.2)
    ax0.set_ylim(iat_arr.min() * 0.5, iat_arr.max() * 2)
    ax0.set_xlabel('iteration')
    ax0.set_ylabel(r'$\tau_{int}$')
    ax0.legend(fontsize=ftsize * 0.8, loc='upper right', bbox_to_anchor=(1.03, 0))
    pos2 = [0.57, 0.72, 0.38, 0.20]
    ax0.set_position(pos2)
    # -- save the figure
    print('### Save... ' + outfile1)
    f_enlarge = 1.05
    shift = -0.1
    bb = fig.bbox_inches.bounds
    bb_new = (bb[0] + shift, bb[1] + shift, bb[2] * f_enlarge, bb[3] * f_enlarge)
    fig.savefig(outfile1, bbox_inches=Bbox.from_bounds(*bb_new))
burn_in = 200
# plotrange = [(8.7529, 8.8533), (7.0195, 7.216), (0.05, 0.15)]
plotrange = [1, 1, 1]
plot_corner(backend_file, burn_in, logohsun, truths=mock_input, plotrange=plotrange)
plot_corner(backend_backup_file, burn_in, logohsun, truths=mock_input, plotrange=plotrange)
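How to choose `burn_in` is not prescribed here; a common heuristic (our assumption, not from the original) is to discard a few integrated autocorrelation times:

# -- a sketch: burn-in from the integrated autocorrelation time
reader = HDFBackend(backend_file)
tau = integrated_time(reader.get_chain(), tol=0)  # one tau per parameter
print('suggested burn_in ~', int(3 * max(tau)))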
After running the extension cell below, go back to S2 above, run there, and check the convergence plot.

`iteration`: the number of additional iterations run to extend the previous MCMC samples.
`dump`: the choice of whether to remove the samples appended in the previous extension or not. (The emcee walkers resume from the last sample stored in the backend.)

iteration = 500
# -- choose one (True/False)
# dump = True
dump = False
backend_reset = False # Do not change this value
if dump:
    copyfile(backend_backup_file, backend_file)
    print('### Copy the backup to the backend...')
else:
    copyfile(backend_file, backend_backup_file)
    print('### Copy the backend to the backup...')
input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
random_state_last = mcmc_sampling(backend_file, backend_reset, glopt_file, iteration,
                                  lnprob, data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm,
                                  mcmc_pass_kwargs, random_state=random_state_last)
plot_chains(backend_file, logohsun, truths=mock_input)
`iteration`: the number of iterations for which you run the MCMC sampling.
`stretch_width`: the range limit within which new parameter values are proposed (see J. Goodman and J. Weare, "Ensemble samplers with affine invariance," Communications in Applied Mathematics and Computational Science, vol. 5, no. 1, pp. 65-80, 2010). To search for other local extrema over a wider region, use a larger `stretch_width` value; a sketch of the proposal rule is given below.
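For context (the proposal rule of the emcee stretch move from Goodman & Weare 2010; the notation is ours, not spelled out in the original), walker $X_k$ is updated with the help of another walker $X_j$:

$$X_k \rightarrow X_j + Z (X_k - X_j), \qquad g(z) \propto \frac{1}{\sqrt{z}} \ \ \mathrm{for}\ z \in \left[\frac{1}{a},\, a\right],$$

where $a$ is the `a` argument of `StretchMove` (here `stretch_width`); a larger $a$ allows larger proposal steps.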
iteration = 1000
stretch_width = 2. * 10
# -- sampler setting
lastco_file = backend_file.replace('.h5', '.last_coords.dat')
init_values = genfromtxt(lastco_file, delimiter='|', dtype='f8')
nwalkers, ndim = init_values.shape
# -- backend setting
backend_extsrch_file = backend_file.replace('.h5', '.extsrch.h5')
backend_extsrch = HDFBackend(backend_extsrch_file)
backend_extsrch.reset(nwalkers, ndim)
print('### Save backend file... : {}'.format(backend_extsrch_file))
input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
# -- multiprocessing
with Pool() as pool:
    sampler = EnsembleSampler(nwalkers, ndim, lnprob, moves=StretchMove(a=stretch_width),
                              args=[data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm],
                              kwargs=mcmc_pass_kwargs, pool=pool, backend=backend_extsrch)
    sampler.run_mcmc(init_values, iteration, progress=True)
print('### Total iteration: {}'.format(sampler.iteration))
plot_chains(backend_extsrch_file, logohsun, truths=mock_input)