Plotting the posterior distribution and the evolution of the integrated autocorrelation times¶

  1. pick one HDF5 file which contains the MCMC sampling results.
  2. set the burn-in iteration number (burn-in). Only the samples after this iteration number are used for plotting.
  3. run the script.
In [1]:
from os.path import basename

from corner import corner
from emcee.autocorr import integrated_time
from emcee.backends import HDFBackend
from matplotlib import rc
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.transforms import Bbox
from numpy import concatenate, empty, exp, linspace, log, mean, ndarray, where


def plot_corner(backend_file, burn_in, truths=None, pdf=False, chain_end=None, quantiles=None):
    # -- extract chains from backend
    reader = HDFBackend(backend_file)
    chain: ndarray = reader.get_chain()
    if chain_end:
        chain = chain[burn_in:burn_in + chain_end, :, :]
    else:
        chain = chain[burn_in:, :, :]
    iteration, nwalkers, ndim = chain.shape
    print('### (iter, nwalkers, ndim) = {}'.format(chain.shape))
    
    # -- integrated autocorrelation time calculation
    chk_num = 21
    chain_mean = mean(chain, axis=1)
    iat_arr = empty((chk_num, ndim))
    chks = exp(linspace(log(100), log(iteration), chk_num)).astype(int)
    for chk in chks:
        chk_ind = where(chks == chk)[0]
        for dim in range(ndim):
            iat_arr[chk_ind, dim] = integrated_time(chain_mean[:chk, dim], tol=0)
    
    #---- plotting
    ## -- corner plots
    labels = ['12 + log$\,$(O/H)', 'E$\,$(B-V)', 'intrinsic \n [O III] $\lambda$5007',
              'intrinsic \n [N II] $\lambda$6584', 'ln(posterior)']
    burnin_tag = '_BurnIn{:04d}'.format(burn_in)
    outfile1 = backend_file.replace('.h5', burnin_tag + '.plot_corner.png')
    if pdf: outfile1 = outfile1.replace('.png', '.pdf')
    chain_rsh = chain.reshape([-1, ndim])
    title = basename(backend_file.replace('.h5', burnin_tag))
    ftsize = 15
    rc('axes', labelsize=ftsize)
    rc('xtick', labelsize=ftsize * 0.8)
    rc('ytick', labelsize=ftsize * 0.8)
    if not quantiles: quantiles = [0.025, 0.16, 0.5, 0.84, 0.975]
    fig: Figure = corner(chain_rsh, bins=40, truth_color='red', truths=truths, labels=labels[0:4],
                         quantiles=quantiles, show_titles=True)
    fig.suptitle(title, fontsize=ftsize * 1.2, y=1.03)
    ## -- tau_iat evolution as an overplot
    if 'line3' in backend_file:
        enlarge = 1.28
        f_w, f_h = fig.get_size_inches()
        fig.set_size_inches(f_w * enlarge, f_h * enlarge)
    ax0: Axes = fig.add_subplot(111)
    for par in range(ndim): ax0.plot(chks, iat_arr[:, par], '-o', markersize=4, label=labels[par], alpha=0.5)
    chks_ext = concatenate([[chks[0] / 2.], chks])
    chks_ext = concatenate([chks_ext, [chks[-1] * 2.]])
    ess_chk = 1000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, '-.k', label='ESS = 1000')
    ess_chk = 2000. / nwalkers
    ax0.plot(chks_ext, chks_ext / ess_chk, '--k', label='ESS = 2000')
    ax0.semilogx()
    ax0.semilogy()
    ax0.tick_params(which='both', top=True, right=True,
                    labelbottom=False, labelleft=False, labeltop=True, labelright=True)
    ax0.xaxis.set_label_position('top')
    ax0.yaxis.set_label_position('right')
    ax0.set_xlim(chks[0] * 0.8, chks[-1] * 1.2)
    ax0.set_ylim(iat_arr.min() * 0.5, iat_arr.max() * 2)
    ax0.set_xlabel('iteration')
    ax0.set_ylabel(r'$\tau_{int}$')
    ax0.legend(fontsize=ftsize * 0.8, loc='upper right', bbox_to_anchor=(1.03, 0))
    if 'line3' in backend_file:
        pos2 = [0.56, 0.72, 0.38, 0.20]
    else:
        pos2 = [0.57, 0.72, 0.38, 0.20]
    ax0.set_position(pos2)
    
    # -- save
    print('### Save... ' + outfile1)
    f_enlarge = 1.05
    shift = -0.1
    bb = fig.bbox_inches.bounds
    bb_new = (bb[0] + shift, bb[1] + shift, bb[2] * f_enlarge, bb[3] * f_enlarge)
    fig.savefig(outfile1, bbox_inches=Bbox.from_bounds(*bb_new))

dir = 'HPS_results/'
backend_file = dir+'HPS035_line3_mderrY_EBVpriorG.h5'
burn_in = 200
plot_corner(backend_file, burn_in)
### (iter, nwalkers, ndim) = (16800, 32, 3)
### Save... HPS_results/HPS035_line3_mderrY_EBVpriorG_BurnIn0200.plot_corner.png