Note¶

You can run this notebook and obtain the MCMC results once the necessary packages are installed and the required variables are correctly defined.

Test with mock data¶

The full procedure for a parameter recovery test is as follows.

  1. Make mock data. See B0, B1, B2.
  2. Run IZI on the mock data.
    A. Find the starting points for MCMC sampling around the maximum a posteriori (MAP). See B3, B4, B5.
      a. If you set SNR = 100, finding the starting points can take a long time because the posterior peak around the MAP is sharp and narrow. In that case, you can reuse the starting points obtained with a lower SNR, say SNR = 10 or 20 (i.e., just copy the .glopt file and rename it appropriately).
    B. Run MCMC sampling using the package emcee.
      a. Run the first iteration using the starting points obtained in step 2-A. See S1.
      b. Check the sampling convergence in the corner plot. See S2, S3.
        i. If you want to keep the newly generated samples, keep them and run another iteration.
        ii. If not, dump them and run another iteration.
      c. Repeat step 2-B-b until you obtain a converged sample set (say, ESS > 2000 for all parameters).
      d. Check for any local extrema of the posterior missed during the MCMC sampling by increasing the stretch width tenfold. See S4.
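Here ESS denotes the effective sample size: for an ensemble of $n_{\rm walkers}$ walkers run for $N$ iterations, ESS $\approx n_{\rm walkers}\,N/\tau_{int}$, where $\tau_{int}$ is the integrated autocorrelation time. This is the quantity compared against the "ESS = 1000/2000" reference lines in the S2 convergence plot.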

B0) Set-up for the test¶

Users should set the following variables.¶

  • setup_file: the file containing the specifications for the parameter recovery test
  • dir_out: the output directory
  • mock_name: the name of the mock data configuration defined in setup_file
In [1]:
import os
from glob import glob
from multiprocessing.dummy import Pool
from os import mkdir, remove, stat
from os.path import basename
from os.path import dirname
from os.path import isdir, isfile
from shutil import copyfile

import numpy as np
from astropy.io import ascii
from corner import corner
from emcee import EnsembleSampler
from emcee.autocorr import integrated_time
from emcee.backends import HDFBackend
from emcee.moves import StretchMove
from izi import fb_extinction as ext
from izi.izi_MCMC_mod import grid, izi_MCMC
from matplotlib import rc
from matplotlib import rcParams
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.pyplot import subplots
from matplotlib.transforms import Bbox
from numpy import arange, ndarray, sort
from numpy import argmax, empty, exp, histogram, linspace, mean
from numpy import concatenate, where
from numpy import genfromtxt, unique, zeros
from numpy import log
from scipy import special
from scipy.optimize import Bounds, differential_evolution

# -- output directory
dir_out = 'mock_data/'
if not isdir(dir_out):
    print('### Create a directory ... : ' + dir_out)
    mkdir(dir_out)
# -- mock data input
setup_file = 'lst.mock_data_setup_example'
mock_name = 'mock_temp1'

ebv_min = 0.0
ebv_max = 1.0

rcParams['font.size'] = 13
os.environ["OMP_NUM_THREADS"] = "1"

B1) Read setups¶

Users should set the following variables.¶

  • templ_dir: the path to the photoionisation model grids of IZI
  • line_dir: the directory containing the line_names.txt file, which lists the emission line names used in IZI
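read_mock_data_setup below expects setup_file to be a '|'-delimited table with one header row and the columns name, grid_file, snr, dex_epsilon, logz, logq, ebv, and lines. A hypothetical row is sketched here (the grid name and all values are illustrative assumptions, not the contents of the shipped example file):

name       | grid_file                | snr | dex_epsilon | logz | logq | ebv | lines
mock_temp1 | l09_high_csf_n1e2_6.0Myr | 20. | -1.0        | 8.8  | 7.5  | 0.1 | hbeta, oiii5007, halpha, nii6584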
In [ ]:
def read_mock_data_setup(file_setup, setup_name):
    col0 = genfromtxt(file_setup, delimiter='|', skip_header=1,
                      dtype=['|U30', '|U40', 'f', 'f', 'f', 'f', 'f', '|U1000'],
                      names=['name', 'grid_file', 'snr', 'dex_epsilon', 'logz', 'logq', 'ebv', 'lines'], autostrip=True)
    pick = col0['name'] == setup_name
    dummy, grid_file, snr, dex_epsilon, logz, logq, ebv, lines_tmp = col0[pick][0]
    lines = lines_tmp.split(', ')
    return grid_file, snr, dex_epsilon, logz, logq, ebv, lines


def read_grid_setup(grid_file, line_dir, templ_dir):
    grid_read = grid(grid_file, templ_dir=templ_dir)
    intergrid = grid_read.grid0
    nlines0 = len(intergrid['ID'][0])
    nz = len(unique(intergrid['LOGZ']))
    nq = len(unique(intergrid['LOGQ']))
    logohsun = (grid_read.logohsun)
    line_params = ascii.read(line_dir + '/line_names.txt')
    line_wav = zeros(nlines0)
    for ii in range(nlines0):
        line_name = intergrid['ID'][0][ii]
        ww = (line_params['line_name'] == line_name)
        assert ww.sum() == 1, 'ERROR: ===== Line ID ' + \
                              intergrid['ID'][0][ii] + ' not included in wavelength list ====='

        line_wav[ii] = line_params['wav'][ww]
    grid_read.grid0['WAV'] = [line_wav] * nq * nz
    intergrid = grid_read.grid0

    zrange = [np.min(intergrid['LOGZ']), np.max(intergrid['LOGZ'])]
    qrange = [np.min(intergrid['LOGQ']), np.max(intergrid['LOGQ'])]

    return intergrid, logohsun, nlines0, zrange, qrange


#-- mock data setup
set_grid, set_snr, set_dex_epsilon, set_logz, set_logq, set_ebv, set_lines = read_mock_data_setup(setup_file, mock_name)

#-- grid setup
idnorm = 'hbeta'
templ_dir = '[PATH_TO_IZI]/python_izi/grids/'
line_dir = '[PATH_TO_IZI]/python_izi/izi/'
intergrid, logohsun, nlines0, zrange, qrange = read_grid_setup(set_grid, line_dir, templ_dir)

B2) Make the mock data¶

  • output: '*.dat' file
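mk_mockdata writes a four-line header (the input S/N, logz, logq, and E(B-V)) followed by a '|'-delimited line table; read_mock_data in B3 skips these first five lines (header plus column names) when parsing. The output looks roughly like this (values illustrative; lines excluded from the mock configuration keep flux = nan):

S/N  =    20
logz =   8.8
logq =   7.5
ebv  =   0.1
line       |         flux |        error |      epsilon
hbeta      |    1.000e+00 |    5.000e-02 |    1.000e-01
oiii5007   |    3.500e-01 |    1.750e-02 |    1.000e-01
...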
In [ ]:
def mk_mockdata(grid_file, snr, logz, logq, ebv, lines, setup_name):
    # grid_file is currently unused here; the module-level intergrid from B1 is used below
    #-- output file setup
    outfile = dir_out + setup_name + '.dat'
    print('### Save... ' + outfile)
    f_in = open(outfile, 'w')
    print(f'S/N  = {snr:5.0f}')
    f_in.write(f'S/N  = {snr:5.0f}\n')
    print(f'logz = {logz:5.1f}')
    f_in.write(f'logz = {logz:5.1f}\n')
    print(f'logq = {logq:5.1f}')
    f_in.write(f'logq = {logq:5.1f}\n')
    print(f'ebv  = {ebv:5.1f}')
    f_in.write(f'ebv  = {ebv:5.1f}\n')

    flux_redd = izi_MCMC.flux_grid_ext(intergrid, logz - logohsun, logq, ebv, idnorm)
    error = flux_redd / snr

    #; CREATE DATA STRUCTURE CONTAINING LINE FLUXES AND ESTIMATED PARAMETERS
    dd = {'id': intergrid['ID'][0],  # line id
          'flux': np.zeros(nlines0) + np.nan,  # line flux      
          'error': np.zeros(nlines0) + np.nan,
          'epsilon': np.zeros(nlines0) + 10 ** set_dex_epsilon}

    #FILL STRUCTURE WITH LINE FLUXES
    nlines_in = len(lines)
    for i in range(nlines_in):
        auxind = (dd['id'] == lines[i])
        nmatch = auxind.sum()
        assert nmatch == 1, 'ERROR: ===== Line ID ' + lines[i] + ' not recognized ====='
        dd['flux'][auxind] = flux_redd[auxind]
        dd['error'][auxind] = error[auxind]

    in_idnorm = where(dd['id'] == 'halpha')
    if set_dex_epsilon == -np.inf:
        dd['epsilon'][in_idnorm] = 10 ** (set_dex_epsilon)
    else:
        dd['epsilon'][in_idnorm] = 10 ** 0.01

    print('{0:10s} | {1:>12s} | {2:>12s} | {3:>12s}'.format('line', 'flux', 'error', 'epsilon'))
    f_in.write('{0:10s} | {1:>12s} | {2:>12s} | {3:>12s}\n'.format('line', 'flux', 'error', 'epsilon'))
    for i in range(nlines0):
        print('{0:10s} | {1:12.3e} | {2:12.3e} | {3:12.3e}'.format(dd['id'][i], dd['flux'][i], dd['error'][i],
                                                                   dd['epsilon'][i]))
        f_in.write('{0:10s} | {1:12.3e} | {2:12.3e} | {3:12.3e}\n'.format(dd['id'][i], dd['flux'][i], dd['error'][i],
                                                                          dd['epsilon'][i]))
    f_in.close()


mk_mockdata(set_grid, set_snr, set_logz, set_logq, set_ebv, set_lines, mock_name)

B3) Functions for the analysis¶

In [ ]:
def read_mock_data(in0):
    with open(in0, 'r') as f_in:
        for i in range(4):
            line = f_in.readline()
            if 'S/N' in line:
                snr = float(line.split()[-1])
            elif 'logz' in line:
                logz = float(line.split()[-1])
            elif 'logq' in line:
                logq = float(line.split()[-1])
            elif 'ebv' in line:
                ebv = float(line.split()[-1])
    id = genfromtxt(in0, dtype=['U20'], delimiter='|', skip_header=5, names=['id'], usecols=(0), autostrip=True)
    data = genfromtxt(in0, dtype=['f', 'f', 'f'], delimiter='|', skip_header=5, names=['flux', 'error', 'epsilon'],
                      usecols=(1, 2, 3), autostrip=True)

    #  INDEX LINES WITH MEASUREMENTS
    good = (np.isfinite(data['error']))
    # ngood=good.sum()
    measured = (np.isfinite(data['flux']))
    upperlim = ((np.isfinite(data['error'])) & (data['flux'] == -666))

    flag0 = np.zeros(len(id))
    flag0[measured] = 1  #measured flux
    flag0[upperlim] = 2  #upper limit on flux
    #This array has length ngood, which is the number of lines with 
    #given error measurements. If error is given but no flux this is treated
    #as an upper limit
    flag = flag0[good]

    # NORMALIZE LINE FLUXES TO HBETA OR
    # IF ABSENT NORMALIZE TO BRIGHTEST LINE

    idnorm = 'hbeta'
    in_idnorm = (id['id'] == idnorm)

    if (np.isnan(data['flux'][in_idnorm]) | (data['flux'][in_idnorm] == -666)):
        a = id[measured]
        idnorm = a[np.argmax(data['flux'][measured])][0]
        in_idnorm = (id['id'] == idnorm)

    norm = data['flux'][in_idnorm]

    #NORMALISE INPUT FLUXES

    data['flux'][measured] = data['flux'][measured] / norm[0]
    data['error'][good] = data['error'][good] / norm[0]

    data['flux'][upperlim] = -666

    print(f'line for normalization = {idnorm}')
    return snr, logz, logq, ebv, id, data, good, flag, idnorm


def flux_grid_ext(grid_, zz, qq, ebv, idnorm):
    flux_at = grid.value_grid_at_no_interp(grid_, zz, qq)

    red_vec = ext.reddening_vector_calzetti(grid_['WAV'][0], ebv)
    flux_redd = flux_at / red_vec

    #NORMALISE GRID           
    norm = flux_redd[grid_['ID'][0] == idnorm]

    flux_redd = flux_redd / norm

    return flux_redd


# DEFINE STATISTICAL FUNCTIONS
# LIKELIHOOD of one specific point of parameter space (theta)
def lnlike(theta, dd, good, flag, grid_, idnorm, logzprior=None, logqprior=None):
    zz, qq, ebv = theta
    fff = np.array(flux_grid_ext(grid_, zz, qq, ebv, idnorm))
    ngood = good.sum()
    like = 1.

    if logqprior is not None:
        gauss_q = 1. / (logqprior[1] * np.sqrt(2. * np.pi)) * np.exp(- ((qq - logqprior[0]) / logqprior[1]) ** 2 / 2.)
    else:
        gauss_q = 1.

    if logzprior is not None:
        gauss_z = 1. / (logzprior[1] * np.sqrt(2. * np.pi)) * np.exp(- ((zz - logzprior[0]) / logzprior[1]) ** 2 / 2.)
    else:
        gauss_z = 1.

    for j in range(ngood):

        if (flag[j] == 1):
            e2 = dd['error'][good][j] ** 2.0 + (dd['epsilon'][good][j] * fff[good][j]) ** 2.0
            fdf2 = (dd['flux'][good][j] - fff[good][j]) ** 2.0
            like = like * 1 / np.sqrt(2 * np.pi) * np.exp(-0.5 * fdf2 / e2) / np.sqrt(e2) * gauss_q * gauss_z

        if (flag[j] == 2):
            edf = (dd['error'][good][j] - fff[good][j])
            e2 = dd['error'][good][j] ** 2.0 + (dd['epsilon'][good][j] * fff[good][j]) ** 2.0
            like = like * 0.5 * (1 + special.erf(edf / np.sqrt(e2 * 2))) * gauss_q * gauss_z

    return np.log(like)


#PRIOR
def lnprior(theta, zrange, qrange, max_ebv):
    zz, qq, ebv = theta

    if zrange[0] < zz < zrange[1] and qrange[0] < qq < qrange[1] and \
            0. < ebv < max_ebv:
        return 0.0
    else:
        return -np.inf


#POSTERIOR
def lnprob(theta, dd, good, flag, grid_, zrange, qrange, max_ebv, idnorm, glopt=True, chk_prior=False):
    lp = lnprior(theta, zrange, qrange, max_ebv)

    # tot=lp + lnlike(theta, dd, good, flag, grid_, idnorm)
    tot = lp + lnlike(theta, dd, good, flag, grid_, idnorm) if not glopt \
        else -(lp + lnlike(theta, dd, good, flag, grid_, idnorm))

    if not np.isfinite(tot):
        return -np.inf if not glopt else np.inf
        # return -np.inf

    return tot


def find_map_par(ln_target, data, good, flag, grid_, zrange, qrange, max_ebv, idnorm, output_num, out_glopt):
    #  a reasonable choice for NP is between 5*D and 10*D (Storn_1997_J.GlobalOptim._11_341)
    print('### number of optimization runs: {} ...'.format(output_num))
    print('### Save... ' + out_glopt)
    f_glopt = open(out_glopt, 'w')

    lb = [zrange[0], qrange[0], ebv_min]
    ub = [zrange[1], qrange[1], ebv_max]

    for i in range(output_num):
        glopt = differential_evolution(ln_target, Bounds(lb, ub),
                                       args=(data, good, flag, grid_, zrange, qrange, max_ebv, idnorm),
                                       workers=-1, tol=0.5)
        if glopt.success:
            prtstr = ' |'.join(['{:>15.8e}'.format(v) for v in glopt.x])
            print(prtstr)
            f_glopt.write(prtstr + '\n')
    print('### Done...')
    f_glopt.close()
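
For reference, the likelihood evaluated by lnlike above is, per line $j$ with model flux $m_j$, observed flux $f_j$, reported error $\sigma_j$, and fractional grid uncertainty $\epsilon_j$, writing $e_j^2 = \sigma_j^2 + (\epsilon_j m_j)^2$:

$$\mathcal{L}_j = \frac{1}{\sqrt{2\pi e_j^2}}\exp\left(-\frac{(f_j - m_j)^2}{2\,e_j^2}\right) \quad \text{(measured line, flag = 1)}$$

$$\mathcal{L}_j = \frac{1}{2}\left[1 + \mathrm{erf}\left(\frac{\sigma_j - m_j}{\sqrt{2\,e_j^2}}\right)\right] \quad \text{(upper limit, flag = 2)}$$

The optional Gaussian priors on $\log Z$ (logzprior) and $\log q$ (logqprior) enter as further multiplicative factors, and the function returns the logarithm of the product over lines.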

B4) Find the maximum a posteriori (MAP) values using a global optimization method (differential evolution)¶

  • output: '*.glopt' file
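Note that find_map_par is given lnprob, whose default glopt=True makes it return the negative log posterior, $-(\ln p_{\rm prior} + \ln \mathcal{L})$; since scipy's differential_evolution is a minimizer, each successful run therefore lands on a (local or global) maximum of the posterior.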
In [ ]:
mock_file = dir_out + mock_name + '.dat'
input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)

output_num = 32
out_glopt = mock_file.replace('.dat', '.glopt')

find_map_par(lnprob, data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm, output_num, out_glopt)

B5) Plot the MAP values¶

  • output: '*.map_par.png' file
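When reading the resulting histograms: if the 32 independent optimization runs pile up at a single location near the red dashed truth lines, differential evolution is consistently finding one mode; widely scattered values hint at multiple local extrema or a poorly constrained parameter, which makes the S4 check below all the more important.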
In [ ]:
def plot_map_par(glopt_file):
    col0 = genfromtxt(glopt_file, delimiter='|', names=['logz', 'logq', 'ebv'])
    fig, ax = subplots(2, 2, figsize=(7, 7))
    ax[0, 0].hist(col0['logz'] + logohsun)
    ax[0, 0].vlines(input_logz, 0, 1, transform=ax[0, 0].get_xaxis_transform(), linestyles='--', colors='red')
    ax[0, 0].set_xlabel(r'12 + log$\,$(O/H)')
    ax[0, 1].hist(col0['logq'])
    ax[0, 1].vlines(input_logq, 0, 1, transform=ax[0, 1].get_xaxis_transform(), linestyles='--', colors='red')
    ax[0, 1].set_xlabel(r'log$\,$(q)')
    ax[1, 0].hist(col0['ebv'])
    ax[1, 0].vlines(input_ebv, 0, 1, transform=ax[1, 0].get_xaxis_transform(), linestyles='--', colors='red')
    ax[1, 0].set_xlabel(r'E$\,$(B-V)')
    ax[1, 1].axis('off')  # only three parameters; hide the unused panel
    outfile = glopt_file.replace('.glopt', '.map_par.png')
    print('### Save... ' + outfile)
    suptitle = basename(glopt_file.replace('.glopt', ''))
    fig.suptitle(suptitle)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(outfile)


mock_file = dir_out + mock_name + '.dat'
out_glopt = mock_file.replace('.dat', '.glopt')
input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
glopt_files = sort(glob(out_glopt))
for glopt_file in glopt_files:
    print('### Read... {}'.format(glopt_file))
    plot_map_par(glopt_file)

S1) Run the MCMC sampling for the first time & make a backup backend & plot the samples to see the burn-in process¶

Users should set the following variable: iteration¶

  • input
    • iteration: the number of MCMC iterations to run
  • output
    • '*.h5': the file that contains the generated MCMC samples
    • '*.backup.h5': a copy of the above file
    • '*.last_coords.dat': the file that contains the last coordinates of the emcee walkers
    • '*.plot_chains.png': the plot that shows the evolution of each parameter in the '*.h5' samples
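Note that the number of walkers is never set explicitly: mcmc_sampling takes the initial walker positions from the '*.glopt' file, so the output_num = 32 MAP solutions found in B4 become 32 walkers.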
In [ ]:
def plot_chains(backend_file, logohsun, truths=None, set_ymm=None):
    # -- extract chains from backend
    reader = HDFBackend(backend_file)
    chain: ndarray = reader.get_chain()
    iteration, nwalkers, ndim = chain.shape
    print('### (iter, nwalkers, ndim) = {}'.format(chain.shape))
    chain_log_prob: ndarray = reader.get_log_prob()
    # -- chain plots
    outfile0 = backend_file.replace('.h5', '.plot_chains.png')
    title = basename(backend_file.replace('.h5', ''))
    fig, axes = subplots(1, 4, figsize=(10, 7), sharey=True)
    fig.suptitle(title)
    labels = [r'12 + log$\,$(O/H)', r'log$\,$(q)', r'E$\,$(B-V)', 'ln(posterior)']
    for i in range(ndim + 1):
        ax: Axes = axes[i]
        if i == ndim:
            if 'line3' in backend_file:
                ax.set_xlabel(labels[i])
                ax = axes[i + 1]
                ax.set_xlabel(labels[i + 1])
            else:
                ax.set_xlabel(labels[i])
            ax.axvline(x=max(chain_log_prob[0, :]))
            ax.plot(chain_log_prob, arange(iteration), "k", alpha=0.1)
        else:
            if i == 0:
                ax.plot(chain[:, :, i] + logohsun, arange(iteration), "k", alpha=0.1)
            else:
                ax.plot(chain[:, :, i], arange(iteration), "k", alpha=0.1)
            if truths: ax.axvline(truths[i], color='red')
            ax.set_xlabel(labels[i])
        ax.set_ylim(0, len(chain))
    axes[0].set_ylabel("iteration");
    if set_ymm: axes[0].set_ylim(set_ymm)
    print('### Save... ' + outfile0)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(outfile0)


def mcmc_sampling(backend_file, backend_reset, glopt_file, iteration,
                  ln_target, data, good, flag, grid_, zrange, qrange, max_ebv, idnorm, mcmc_pass_kwargs,
                  random_state=None, stretch_a=2.):
    # -- result file
    backend = HDFBackend(backend_file)
    print('### Save backend file... : {}'.format(backend_file))
    # -- the initial values
    if backend_reset:
        init_values = genfromtxt(glopt_file, delimiter='|', dtype='f8')
        nwalkers, ndim = init_values.shape
    else:
        init_values = backend.get_last_sample()
        init_values.random_state = random_state
        nwalkers, ndim = backend.shape
    # -- backend reset
    if backend_reset: backend.reset(nwalkers, ndim)
    # -- multiprocessing
    with Pool() as pool:
        sampler = EnsembleSampler(nwalkers, ndim, ln_target, moves=StretchMove(a=stretch_a),
                                  args=[data, good, flag, grid_, zrange, qrange, max_ebv, idnorm],
                                  kwargs=mcmc_pass_kwargs, pool=pool,
                                  backend=backend)
        sampler.run_mcmc(init_values, iteration, progress=True)
    print('### Total iteration: {}'.format(sampler.iteration))
    random_state = sampler.random_state
    # -- save last coords
    outfile = backend_file.replace('.h5', '.last_coords.dat')
    print('### Save ... ' + outfile)
    with open(outfile, 'w') as f_out:
        co_arr = backend.get_last_sample().coords
        nrow = co_arr.shape[0]
        for row_i in range(nrow):
            # write full precision so a restart from this file resumes exactly
            prtstr = ' |'.join(['{:>15.8e}'.format(v) for v in co_arr[row_i, :]])
            f_out.write(prtstr + '\n')
    return random_state


def sample_new(mock_input, logohsun):
    backend_reset = True
    input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
    random_state_last = mcmc_sampling(backend_file, backend_reset, glopt_file, iteration,
                                      lnprob, data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm,
                                      mcmc_pass_kwargs)
    print('### Copy the backend file to the backup : {}'.format(backend_backup_file))
    copyfile(backend_file, backend_backup_file)
    plot_chains(backend_file, logohsun, truths=mock_input)
    return random_state_last


# mcmc_pass_kwargs = {'glopt': False, 'chk_prior': True}
mcmc_pass_kwargs = {'glopt': False, 'chk_prior': False}
iteration = 1000
mock_file = dir_out + mock_name + '.dat'
glopt_file = mock_file.replace('.dat', '.glopt')

mock_input = [input_logz, input_logq, input_ebv]
# -- file setting
backend_file = glopt_file.replace('.glopt', '.h5')
if mcmc_pass_kwargs['chk_prior']: backend_file = backend_file.replace('.h5', '_chkprior.h5')
backend_backup_file = backend_file.replace('.h5', '.backup.h5')
# -- backend existence check
if isfile(backend_file):
    print(backend_file)
    file_size_MB = stat(backend_file).st_size >> 20
    chk = input('file size = {} MB.'.format(file_size_MB) + ' ### Delete it? (y or else) :')
    if chk == 'y':
        print('### Delete the existing backend file and create a new one...')
        remove(backend_file)
        random_state_last = sample_new(mock_input, logohsun)
    else:
        print('### backend_file is unchanged.')
        backend = HDFBackend(backend_file)
        random_state_last = backend.get_last_sample().random_state
else:
    random_state_last = sample_new(mock_input, logohsun)

S2) Plot corner & integrated autocorrelation time¶

When the convergence is good enough (e.g., each $\tau_{int}$ curve crosses the "ESS = 2000" line), go to S4 and check whether any other local extrema exist.
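As a minimal sketch of this criterion (our own helper, not part of IZI; it estimates $\tau_{int}$ on the walker-averaged chain exactly as plot_corner below does):

import numpy as np
from emcee.autocorr import integrated_time
from emcee.backends import HDFBackend

def effective_sample_size(backend_file):
    """ESS per parameter: total number of samples divided by the
    integrated autocorrelation time of the walker-averaged chain."""
    chain = HDFBackend(backend_file).get_chain()  # (iteration, nwalkers, ndim)
    chain_mean = np.mean(chain, axis=1)           # average over walkers
    tau = np.array([integrated_time(chain_mean[:, d], tol=0)[0]
                    for d in range(chain.shape[2])])
    return chain.shape[0] * chain.shape[1] / tau  # ESS = n_samples / tau

Sampling can stop once every entry returned here exceeds 2000.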

Users should set the following variable: burn_in¶

  • input
    • burn_in: the number of burn-in iterations (samples before this iteration are discarded when making the corner plot)
  • output
    • '*.plot_corner.png': the corner plot
In [ ]:
def plot_corner(backend_file, burn_in, logohsun,
                truths=None, pdf=False, chain_end=None, quantiles=None, plotrange=None):
    # -- extract chains from backend
    reader = HDFBackend(backend_file)
    chain: ndarray = reader.get_chain()
    if chain_end:
        chain = chain[burn_in:burn_in + chain_end, :, :]
    else:
        chain = chain[burn_in:, :, :]
    #-- shift the logz values
    chain[:, :, 0] = chain[:, :, 0] + logohsun
    iteration, nwalkers, ndim = chain.shape
    if plotrange is None: plotrange = [1] * ndim
    print('### (iter, nwalkers, ndim) = {}'.format(chain.shape))
    # -- integrated autocorrelation time calculation
    chk_num = 21
    chain_mean = mean(chain, axis=1)
    iat_arr = empty((chk_num, ndim))
    chks = exp(linspace(log(100), log(iteration), chk_num)).astype(int)
    for chk in chks:
        chk_ind = where(chks == chk)[0]
        for dim in range(ndim):
            iat_arr[chk_ind, dim] = integrated_time(chain_mean[:chk, dim], tol=0)
    # -- corner plots
    labels = [r'12 + log$\,$(O/H)', r'log$\,$(q)', r'E$\,$(B-V)']
    burnin_tag = '_BurnIn{:04d}'.format(burn_in)
    outfile1 = backend_file.replace('.h5', burnin_tag + '.plot_corner.png')
    if pdf: outfile1 = outfile1.replace('.png', '.pdf')
    chain_rsh = chain.reshape([-1, ndim])
    title = basename(backend_file.replace('.h5', burnin_tag))
    ftsize = 15
    rc('axes', labelsize=ftsize)
    rc('xtick', labelsize=ftsize * 0.8)
    rc('ytick', labelsize=ftsize * 0.8)
    if not quantiles: quantiles = [0.025, 0.16, 0.5, 0.84, 0.975]
    fig: Figure = corner(chain_rsh, bins=40, truth_color='red', truths=truths, labels=labels[0:ndim],
                         quantiles=quantiles,
                         show_titles=True, title_kwargs={'size': ftsize * 0.8}, range=plotrange)
    fig.suptitle(title, fontsize=ftsize * 1.2, y=1.03)
    # -- check for the inverse-gamma prior
    if 'chkprior' in backend_file and 'ONpriorI' in backend_file:
        oiii_n, oiii_bins = histogram(chain_rsh[:, 2], bins=40, range=plotrange[2])
        oiii_in = argmax(oiii_n)
        print('[O III] peak = {:10.3e} {:10.3e}'.format(oiii_bins[oiii_in], oiii_bins[oiii_in + 1]))
        nii_n, nii_bins = histogram(chain_rsh[:, 3], bins=40, range=plotrange[3])
        nii_in = argmax(nii_n)
        print('[N II] peak  = {:10.3e} {:10.3e}'.format(nii_bins[nii_in], nii_bins[nii_in + 1]))
    # -- tau_iat evolution as an overplot
    if 'line3' in backend_file:
        enlarge = 1.28
        f_w, f_h = fig.get_size_inches()
        fig.set_size_inches(f_w * enlarge, f_h * enlarge)
    ax0: Axes = fig.add_subplot(111)
    for par in range(ndim): ax0.plot(chks, iat_arr[:, par], '-o', markersize=4, label=labels[par], alpha=0.5)
    chks_ext = concatenate([[chks[0] / 2.], chks, [chks[-1] * 2.]])
    ess_chk = 1000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, '-.k', label='ESS = 1000')
    ess_chk = 2000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, ':k', label='ESS = 2000')
    ax0.semilogx()
    ax0.semilogy()
    ax0.tick_params(which='both', top=True, right=True,
                    labelbottom=False, labelleft=False, labeltop=True, labelright=True)
    ax0.xaxis.set_label_position('top')
    ax0.yaxis.set_label_position('right')
    ax0.set_xlim(chks[0] * 0.8, chks[-1] * 1.2)
    ax0.set_ylim(iat_arr.min() * 0.5, iat_arr.max() * 2)
    ax0.set_xlabel('iteration')
    ax0.set_ylabel(r'$\tau_{int}$')
    ax0.legend(fontsize=ftsize * 0.8, loc='upper right', bbox_to_anchor=(1.03, 0))
    pos2 = [0.57, 0.72, 0.38, 0.20]
    ax0.set_position(pos2)
    ## -- save test-run
    print('### Save... ' + outfile1)
    f_enlarge = 1.05
    shift = -0.1
    bb = fig.bbox_inches.bounds
    bb_new = (bb[0] + shift, bb[1] + shift, bb[2] * f_enlarge, bb[3] * f_enlarge)
    fig.savefig(outfile1, bbox_inches=Bbox.from_bounds(*bb_new))


burn_in = 200
# plotrange = [(8.7529, 8.8533), (7.0195, 7.216), (0.05, 0.15)]
plotrange = [1, 1, 1]
plot_corner(backend_file, burn_in, logohsun, truths=mock_input, plotrange=plotrange)
plot_corner(backend_backup_file, burn_in, logohsun, truths=mock_input, plotrange=plotrange)

S3) Decide whether to keep the newly added samples; after running this cell, go back to S2, run it again, and check the convergence plot.¶

Users should set the following variables: iteration, dump¶

  • input
    • iteration: the number of additional iterations used to extend the previous MCMC run
    • dump: whether to discard the samples appended by the previous run (see the note after this list)
  • output
    • '*.h5': the file to which the newly generated MCMC samples are appended
    • '*.backup.h5': the backup containing the samples as of the previous run
    • '*.last_coords.dat': the file that contains the last coordinates of the emcee walkers
    • '*.plot_chains.png': the plot that shows the evolution of each parameter in the '*.h5' samples
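The dump switch works through the backup copy: with dump = True, '*.backup.h5' (saved before the previous run) overwrites '*.h5', discarding the samples that run appended; with dump = False, the current '*.h5' is first copied to '*.backup.h5', so a later dump can only roll back the run that is about to start.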
In [ ]:
iteration = 500
# -- choose one (True/False)
# dump = True
dump = False

backend_reset = False  # Do not change this value
if dump:
    copyfile(backend_backup_file, backend_file)
    print('### Copy the backup to the backend...')
else:
    copyfile(backend_file, backend_backup_file)
    print('### Copy the backend to the backup...')

input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)

random_state_last = mcmc_sampling(backend_file, backend_reset, glopt_file, iteration,
                                  lnprob, data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm,
                                  mcmc_pass_kwargs, random_state=random_state_last)
plot_chains(backend_file, logohsun, truths=mock_input)

S4) Search for local extrema with an increased stretch width.¶

Users can adjust the following variables: iteration, stretch_width¶

  • input
    • iteration: the number of MCMC iterations to run
    • stretch_width: the limit of the range over which new parameter values are proposed (see the note after this list, and J. Goodman and J. Weare, 'Ensemble samplers with affine invariance,' Communications in Applied Mathematics and Computational Science, vol. 5, no. 1, pp. 65–80, 2010)
  • output
    • '*.extsrch.h5': the file that contains the MCMC samples newly generated with the enlarged stretch_width
    • '*.extsrch.plot_chains.png': the plot that shows the evolution of each parameter in the '*.extsrch.h5' samples
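For context (our summary of the Goodman & Weare stretch move, as implemented by emcee's StretchMove): a walker $X_k$ is updated by pairing it with another walker $X_j$ and proposing $Y = X_j + Z\,(X_k - X_j)$, where $Z$ is drawn from $g(z) \propto 1/\sqrt{z}$ on $[1/a, a]$ and $a$ is stretch_width. The emcee default is $a = 2$; setting $a = 20$ below allows proposals up to 20 times the walker-pair separation, so walkers can reach, and thereby reveal, posterior modes far from the converged solution.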
In [ ]:
iteration = 1000
stretch_width = 2. * 10

# -- sampler setting
lastco_file = backend_file.replace('.h5', '.last_coords.dat')
init_values = genfromtxt(lastco_file, delimiter='|', dtype='f8')
nwalkers, ndim = init_values.shape
# -- backend setting
backend_extsrch_file = backend_file.replace('.h5', '.extsrch.h5')
backend_extsrch = HDFBackend(backend_extsrch_file)
backend_extsrch.reset(nwalkers, ndim)
print('### Save backend file... : {}'.format(backend_extsrch_file))

input_snr, input_logz, input_logq, input_ebv, id, data, good, flag, idnorm = read_mock_data(mock_file)
# -- multiprocessing
with Pool() as pool:
    sampler = EnsembleSampler(nwalkers, ndim, lnprob, moves=StretchMove(a=stretch_width),
                              args=[data, good, flag, intergrid, zrange, qrange, ebv_max, idnorm],
                              kwargs=mcmc_pass_kwargs, pool=pool, backend=backend_extsrch)
    sampler.run_mcmc(init_values, iteration, progress=True)
print('### Total iteration: {}'.format(sampler.iteration))
plot_chains(backend_extsrch_file, logohsun, truths=mock_input)