# Show the prior

import os
import numpy as np
import matplotlib.pyplot as plt
from eazy import utils, photoz

# Redshift grid spanning 0.001 < z < 7, built by the eazy grid helper with a
# step argument of 0.01 (presumably the step in log(1+z) -- confirm in eazy).
zgrid = utils.log_zgrid((0.001, 7), 0.01)

# The prior table lives under the eazy data directory.
path = utils.path_to_eazy_data()
prior_file = os.path.join(path, 'templates/prior_F160W_TAO.dat')

# Read the prior for this redshift grid.  The call returns the magnitudes
# and the prior values; the plotting code below indexes the latter as
# ``prior_data[:, i]`` for the i-th entry of ``prior_mags``.
read_prior_kwargs = dict(
    zgrid=zgrid,
    prior_file=prior_file,
    prior_floor=1.e-2,
)
prior_mags, prior_data = photoz.PhotoZ.read_prior(**read_prior_kwargs)

fig, ax = plt.subplots(1,1,figsize=(6,4))

# Every curve shares the same abscissa, log(1+z), so compute it once.
log_1pz = np.log(1+zgrid)

# Tabulated prior curves: keep only magnitudes that are no fainter than 28.1
# and sit within 0.1 of the integer below them; skip everything else.
for col, mag in enumerate(prior_mags):
    too_faint = mag > 28.1
    off_integer = mag - np.floor(mag) > 0.1
    if too_faint | off_integer:
        continue

    # Colour encodes magnitude: 15..28 is mapped onto 0..1 of the colormap.
    curve_color = plt.cm.rainbow((mag-15)/13)
    ax.plot(log_1pz, prior_data[:,col],
            label=f'm = {mag:.1f}', color=curve_color)

# Faint black overlays: the prior evaluated by ``_get_prior_mag`` at
# magnitudes 26.2, 26.4, 26.6 and 26.8 (presumably interpolated between the
# tabulated columns -- confirm against the eazy implementation).
for mag_eval in np.arange(26.2, 26.9, 0.2):
    prior_eval = photoz.PhotoZ._get_prior_mag(mag_eval, prior_mags,
                                              prior_data)
    ax.plot(log_1pz, prior_eval, color='k', linewidth=1,
            label=f'{mag_eval:.1f}', alpha=0.2)

# The x axis is drawn in log(1+z); place ticks at integer redshifts 0..7 and
# label them with the redshift itself rather than the transformed value.
z_ticks = np.arange(0,7.1,1)
tick_positions = np.log(1+z_ticks)
ax.set_xticks(tick_positions)
ax.set_xticklabels(z_ticks.astype(int))
ax.set_xlim(0, np.log(8))

# Axis labels and grid.
ax.set(xlabel='redshift', ylabel='Mag prior')
ax.grid()

# Legend is titled with the prior file name so the plot is self-describing.
legend_title = os.path.basename(prior_file)
ax.legend(ncol=3, fontsize=8, title=legend_title)

# Logarithmic y axis, then tighten the layout.
ax.semilogy()
fig.tight_layout(pad=0.1)