Source code for eazy.hdf5

"""
Tools for saving/recovering state from HDF5 files
"""
import numpy as np

try:
    import h5py
except ImportError:
    pass
    
from . import photoz
from . import utils
from . import templates as templates_code

def write_hdf5(pzobj, h5file='test.hdf5', include_fit_coeffs=False, include_templates=True, verbose=True):
    """
    Write self-contained HDF5 file

    Parameters
    ----------
    pzobj : `~eazy.photoz.PhotoZ`
        Original code object with computed redshifts, coeffs, etc.

    h5file : str
        HDF5 filename

    include_fit_coeffs : bool
        Include full `fit_coeffs` array with ``(NOBJ, NZ, NTEMP)`` fit
        coefficients.  This can make the file very large, and it's really
        only needed if you want to use `~eazy.photoz.PhotoZ.prior_beta`.

    include_templates : bool
        Include template arrays

    """
    if verbose:
        print(f'h5: create file {h5file}')

    with h5py.File(h5file, 'w') as f:
        grp = f.create_group("cat")
        dset = grp.create_dataset('id', data=pzobj.OBJID)
        dset = grp.create_dataset('ra', data=pzobj.RA)
        dset = grp.create_dataset('dec', data=pzobj.DEC)
        dset = grp.create_dataset('z_spec', data=pzobj.ZSPEC)

        for k in pzobj.cat.meta:
            if verbose:
                print('h5: cat meta: ', k, pzobj.cat.meta[k])
            grp.attrs[k] = pzobj.cat.meta[k]

        for name in ['flux_columns', 'err_columns']:
            if verbose:
                print(f'h5: cat/{name}')
            attr = getattr(pzobj, name)
            dset = grp.create_dataset(name, data=[a.encode('utf8') for a in attr])

        for name in ['f_numbers', 'fnu', 'efnu_orig', 'ok_data', 'zp',
                     'ext_corr', 'ext_redden', 'pivot']:
            if verbose:
                print(f'h5: cat/{name}')
            attr = getattr(pzobj, name)
            dset = grp.create_dataset(name, data=attr)

        grp.attrs['MW_EBV'] = pzobj.param['MW_EBV']

        grp = f.create_group("fit")
        for name in ['zml', 'zbest', 'chi2_fit', 'coeffs_best']:
            if verbose:
                print(f'h5: fit/{name}')
            attr = getattr(pzobj, name)
            dset = grp.create_dataset(name, data=attr)

        dset = grp.create_dataset('tef_x', data=pzobj.TEF.te_x)
        dset = grp.create_dataset('tef_y', data=pzobj.TEF.te_y)

        f['fit/zml'].attrs['ZML_WITH_PRIOR'] = pzobj.ZML_WITH_PRIOR
        f['fit/zml'].attrs['ZML_WITH_BETA_PRIOR'] = pzobj.ZML_WITH_BETA_PRIOR
        f['fit/zbest'].attrs['ZPHOT_USER'] = pzobj.ZPHOT_USER

        if include_fit_coeffs | pzobj.ZML_WITH_BETA_PRIOR:
            name = 'fit_coeffs'
            if verbose:
                print(f'h5: fit/{name}')
            attr = getattr(pzobj, name)
            dset = grp.create_dataset(name, data=attr)

        # Parameters
        for k in pzobj.param.params:
            grp.attrs[k] = pzobj.param.params[k]

        dset = grp.create_dataset('tempfilt', data=pzobj.tempfilt.tempfilt)
        dset = grp.create_dataset('tempfilt_scale', data=pzobj.tempfilt.scale)

        func = pzobj.tempfilt.interpolator_function
        dset.attrs['interpolator_function'] = func.__name__

        # Templates
        if include_templates:
            grp = f.create_group("templates")
            grp.attrs['NTEMP'] = pzobj.NTEMP

            for i, templ in enumerate(pzobj.templates):
                grp.attrs[f'TEMPL{i:03d}'] = templ.name
                if verbose:
                    print(f'h5: templates/{templ.name}')

                dset = grp.create_dataset(f'wave {templ.name}',
                                          data=templ.wave.astype(pzobj.ARRAY_DTYPE))
                dset = grp.create_dataset(f'flux {templ.name}',
                                          data=templ.flux.astype(pzobj.ARRAY_DTYPE))
                dset = grp.create_dataset(f'z {templ.name}',
                                          data=templ.redshifts)
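
# Usage sketch (not part of the module; the object name `pz` and the output
# filename are hypothetical): given a `~eazy.photoz.PhotoZ` object with
# redshifts already computed, save its full state so it can be restored later
# without rerunning the fits.
#
#   >>> from eazy import hdf5
#   >>> hdf5.write_hdf5(pz, h5file='photz_state.hdf5',
#   ...                 include_fit_coeffs=False, include_templates=True)
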
def cat_from_hdf5(h5file):
    """
    Parameters
    ----------
    h5file : str
        HDF5 filename

    Returns
    -------
    cat : `~astropy.table.Table`
        Catalog table generated from HDF5 data

    trans : `eazy.param.TranslateFile`
        TranslateFile object

    """
    with h5py.File(h5file, 'r') as f:
        cat, trans = photoz.PhotoZ._csv_from_arrays(f['cat/id'][:],
                                         f['cat/ra'][:],
                                         f['cat/dec'][:],
                                         f['cat/z_spec'][:],
                                         f['cat/fnu'],
                                         f['cat/efnu_orig'],
                                         f['cat/ok_data'],
                                         f['cat/flux_columns'].asstr(),
                                         f['cat/err_columns'].asstr(),
                                         f['cat/zp'][:]**0,
                                         f['cat/f_numbers'][:])

    return cat, trans
def params_from_hdf5(h5file):
    """
    Read full parameters from HDF5 file

    Parameters
    ----------
    h5file : str
        HDF5 filename

    Returns
    -------
    params : dict

    """
    from collections import OrderedDict

    params = OrderedDict()
    with h5py.File(h5file, 'r') as f:
        dset = f['fit']
        for k in dset.attrs:
            params[k] = dset.attrs[k]

    return params
def templates_from_hdf5(h5file, verbose=False):
    """
    Read list of templates

    Parameters
    ----------
    h5file : str
        HDF5 filename

    Returns
    -------
    templates : list
        List of `eazy.templates.Template` objects read from ``h5file``.

    """
    templates = []
    with h5py.File(h5file, 'r') as f:
        NTEMP = f['templates'].attrs['NTEMP']
        for i in range(NTEMP):
            name = f['templates'].attrs[f'TEMPL{i:03d}']
            if verbose:
                print(f'h5: read {name}')

            wave = f[f'templates/wave {name}'][:]
            flux = f[f'templates/flux {name}'][:]
            redshifts = f[f'templates/z {name}'][:]

            templ_i = templates_code.Template(arrays=(wave, flux),
                                              name=name,
                                              redshifts=redshifts)
            templates.append(templ_i)

    return templates
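
# Usage sketch (hypothetical filename): the reader functions above can be
# used independently to pull pieces of a saved state back out of the file.
#
#   >>> from eazy import hdf5
#   >>> params = hdf5.params_from_hdf5('photz_state.hdf5')    # parameter dict
#   >>> cat, trans = hdf5.cat_from_hdf5('photz_state.hdf5')   # catalog + translate
#   >>> templ = hdf5.templates_from_hdf5('photz_state.hdf5')  # list of Template
#   >>> print(len(templ), templ[0].name)
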
def show_info(h5file):
    """
    Print summary info of contents of ``h5file`` to the screen
    """
    with h5py.File(h5file, 'r') as f:
        def show_size(name):
            if hasattr(f[name], 'shape'):
                print(name, f[name].shape)
            else:
                comment = '-'*(len(name)+3)
                print(f'\n{comment}\n {name}/\n{comment}')

        f.visit(show_size)
def get_dataset_shape(h5file, dataset):
    """
    Return the shape of a dataset in ``h5file``
    """
    with h5py.File(h5file, 'r') as f:
        if dataset not in f:
            raise ValueError(f'Dataset {dataset} not found in {h5file}')

        return f[dataset].shape
def get_dataset_slice(h5file, dataset, sl=None):
    """
    Get a slice of a dataset in ``h5file``.  If no ``sl`` slice is provided,
    return the full array.
    """
    with h5py.File(h5file, 'r') as f:
        if dataset not in f:
            raise ValueError(f'Dataset {dataset} not found in {h5file}')

        if sl is None:
            return f[dataset][:]
        else:
            return f[dataset][sl]
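
# Usage sketch (hypothetical filename): inspect the file layout and read a
# subset of rows without loading the full arrays into memory.
#
#   >>> from eazy import hdf5
#   >>> hdf5.show_info('photz_state.hdf5')                      # list groups/datasets
#   >>> hdf5.get_dataset_shape('photz_state.hdf5', 'fit/chi2_fit')
#   >>> sub = hdf5.get_dataset_slice('photz_state.hdf5', 'cat/fnu',
#   ...                              sl=slice(0, 100))           # first 100 objects
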
def initialize_from_hdf5(h5file='test.hdf5', verbose=True):
    """
    Initialize a `~eazy.photoz.PhotoZ` object from HDF5 data

    Parameters
    ----------
    h5file : str
        HDF5 filename written by `~eazy.hdf5.write_hdf5`

    Returns
    -------
    pzobj : `~eazy.photoz.PhotoZ`

    """
    # Parameter dictionary
    params = params_from_hdf5(h5file)

    # Generate a catalog table from H5 data
    cat, trans = cat_from_hdf5(h5file)

    # Put catalog in CATALOG_FILE parameter
    params['CATALOG_FILE'] = cat

    with h5py.File(h5file, 'r') as f:
        pzobj = photoz.PhotoZ(param_file=None,
                              translate_file=trans,
                              zeropoint_file=None,
                              params=params,
                              load_prior=True,
                              load_products=False,
                              tempfilt_data=f['fit/tempfilt'][:])

        pzobj.tempfilt.scale = f['fit/tempfilt_scale'][:]
        pzobj.chi2_fit = f['fit/chi2_fit'][:]
        pzobj.zp = f['cat/zp'][:]

        if 'templates' in f:
            pzobj.templates = templates_from_hdf5(h5file, verbose=verbose)

        if 'fit/fit_coeffs' in f:
            pzobj.fit_coeffs = f['fit/fit_coeffs'][:]

        pzobj.compute_lnp(prior=f['fit/zml'].attrs['ZML_WITH_PRIOR'],
                          beta_prior=f['fit/zml'].attrs['ZML_WITH_BETA_PRIOR'])

        pzobj.evaluate_zml(prior=f['fit/zml'].attrs['ZML_WITH_PRIOR'],
                           beta_prior=f['fit/zml'].attrs['ZML_WITH_BETA_PRIOR'])

        if f['fit/zbest'].attrs['ZPHOT_USER']:
            pzobj.fit_at_zbest(zbest=f['fit/zbest'],
                               prior=f['fit/zml'].attrs['ZML_WITH_PRIOR'],
                               beta_prior=f['fit/zml'].attrs['ZML_WITH_BETA_PRIOR'])
        else:
            pzobj.fit_at_zbest(zbest=None,
                               prior=f['fit/zml'].attrs['ZML_WITH_PRIOR'],
                               beta_prior=f['fit/zml'].attrs['ZML_WITH_BETA_PRIOR'])

    return pzobj
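
# Usage sketch (hypothetical filename): rebuild a full PhotoZ object from a
# saved file, e.g. to regenerate output catalogs or refit individual objects.
# The `standard_output` call is assumed from the usual PhotoZ API.
#
#   >>> from eazy import hdf5
#   >>> pz = hdf5.initialize_from_hdf5('photz_state.hdf5')
#   >>> zout, hdu = pz.standard_output()
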
class Viewer(object):
    def __init__(self, h5file, verbose=True):
        """
        Tool to replicate functionality of `PhotoZ.show_fit` but with data
        read from a stored HDF5 file rather than a "live" object
        """
        from astropy.cosmology import LambdaCDM

        self.h5file = h5file

        self.param = params_from_hdf5(h5file)
        photoz.PhotoZ.set_zgrid(self)
        self.NZ = len(self.zgrid)

        self.templates = templates_from_hdf5(h5file, verbose=verbose)
        self.NTEMP = len(self.templates)

        self.set_attrs_from_hdf5()
        self.set_template_error()

        self.cosmology = LambdaCDM(H0=self.param['H0'],
                                   Om0=self.param['OMEGA_M'],
                                   Ode0=self.param['OMEGA_L'],
                                   Tcmb0=2.725, Ob0=0.048)

        self.set_tempfilt()
    def info(self):
        """
        Print file contents
        """
        show_info(self.h5file)
    def get_catalog(self):
        """
        Create a full catalog table from the photometry data in the
        HDF5 file
        """
        cat, trans = cat_from_hdf5(self.h5file)
        return cat
    def get_table(self, names, sl=None, verbose=True, strip_slash=True, as_table=True):
        """
        Build a table from multiple dataset names

        Parameters
        ----------
        names : list of str
            Dataset names in `h5file`

        sl : slice-like
            Array slice

        verbose : bool
            Print some status info

        Returns
        -------
        tab : `~astropy.table.Table` or dict
            Sliced datasets, returned as a table if ``as_table=True`` or as
            a dict otherwise

        """
        from astropy.table import Table

        tab = {}
        for name in names:
            if strip_slash:
                key = name.split('/')[-1]
            else:
                key = name

            tab[key] = self.get_dataset(name, sl=sl, verbose=verbose)

        if as_table:
            tab = Table(tab)
            tab.meta['to_ujy'] = self.to_uJy
            return tab
        else:
            return tab
    def get_dataset(self, name, sl=None, verbose=True):
        """
        Get a single column of data from the HDF5 file

        If ``name`` is in the `self.flux_columns` or `self.err_columns`
        lists, then compute the appropriate index of the `cat/fnu` or
        `cat/efnu_orig` datasets, respectively.

        Parameters
        ----------
        name : str
            Dataset name in `h5file`

        sl : slice-like
            Array slice

        verbose : bool
            Print some status info

        Returns
        -------
        data : array-like
            Sliced dataset data

        """
        if name in self.flux_columns:
            c_i = list(self.flux_columns).index(name)
            if verbose:
                print(f'{name} = cat/fnu[:,{c_i}]')

            data = get_dataset_slice(self.h5file, 'cat/fnu', sl)
            data = np.atleast_2d(data)[:,c_i]

        elif name in self.err_columns:
            c_i = list(self.err_columns).index(name)
            if verbose:
                print(f'{name} = cat/efnu_orig[:,{c_i}]')

            data = get_dataset_slice(self.h5file, 'cat/efnu_orig', sl)
            data = np.atleast_2d(data)[:,c_i]

        else:
            data = get_dataset_slice(self.h5file, name, sl)

        if data.size == 1:
            return data[0]
        else:
            return data
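
    # Usage sketch (hypothetical filename): individual columns are addressed
    # either by dataset path or by their original flux/error column names.
    #
    #   >>> from eazy import hdf5
    #   >>> v = hdf5.Viewer('photz_state.hdf5')
    #   >>> zbest = v.get_dataset('fit/zbest')
    #   >>> tab = v.get_table(['fit/zbest', 'cat/z_spec'], sl=slice(0, 10))
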
    def set_tempfilt(self):
        """
        Generate the `eazy.photoz.TemplateGrid` object needed for plotting
        SEDs and refitting at other redshifts.
        """
        with h5py.File(self.h5file, 'r') as f:
            self.tempfilt = photoz.TemplateGrid(self.zgrid, self.templates,
                                    RES=self.param['FILTERS_RES'],
                                    f_numbers=self.f_numbers,
                                    add_igm=self.param['IGM_SCALE_TAU'],
                                    galactic_ebv=self.MW_EBV,
                                    Eb=self.param['SCALE_2175_BUMP'],
                                    n_proc=1,
                                    cosmology=self.cosmology,
                                    array_dtype=self.ARRAY_DTYPE,
                                    tempfilt_data=f['fit/tempfilt'][:],
                                    verbose=False)

            self.tempfilt.scale = f['fit/tempfilt_scale'][:]
    def set_template_error(self):
        """
        Set the Template Error Function from `fit/tef_x` and `fit/tef_y`
        data in the HDF5 file

        Returns
        -------
        Sets `TEF` attributes
        """
        with h5py.File(self.h5file, 'r') as f:
            arrays = (f['fit/tef_x'][:], f['fit/tef_y'][:])

            self.TEF = templates_code.TemplateError(arrays=arrays,
                                        filter_wavelengths=self.pivot,
                                        scale=self.param['TEMP_ERR_A2'])
    def set_attrs_from_hdf5(self):
        """
        Set fixed attributes needed for `show_fit`
        """
        with h5py.File(self.h5file, 'r') as f:
            self.NOBJ, self.NFILT = f['cat/fnu'].shape
            self.pivot = f['cat/pivot'][:]
            self.OBJID = f['cat/id'][:]
            self.zp = f['cat/zp'][:]
            self.ext_corr = f['cat/ext_corr'][:]
            self.ext_redden = f['cat/ext_redden'][:]
            self.f_numbers = f['cat/f_numbers'][:]
            self.flux_columns = f['cat/flux_columns'].asstr()[:]
            self.err_columns = f['cat/err_columns'].asstr()[:]

        self.idx = np.arange(self.NOBJ, dtype=int)
    def set_attrs_for_dash(self):
        """
        Set additional attributes needed for the Dash/Plotly tool.  They're
        not set in `set_attrs_from_hdf5` as they may be memory-intensive for
        very large catalogs and aren't explicitly necessary for just
        `show_fit` showing SEDs.
        """
        with h5py.File(self.h5file, 'r') as f:
            self.RA = f['cat/ra'][:]
            self.DEC = f['cat/dec'][:]
            self.ZSPEC = f['cat/z_spec'][:]
            self.nusefilt = f['cat/ok_data'][:].sum(axis=1)

        if not hasattr(self, 'cat'):
            self.cat = self.get_catalog()
    def get_object_data(self, ix):
        """
        Pull out data for a given array index (not id!) corresponding to a
        particular object

        Returns
        -------
        z_i, fnu_i, efnu_i, ra_i, dec_i, chi2_i, zspec_i, ok_i
        """
        with h5py.File(self.h5file, 'r') as f:
            fnu_i = f['cat/fnu'][ix,:]
            efnu_orig = f['cat/efnu_orig'][ix,:]
            ra_i = f['cat/ra'][ix]
            dec_i = f['cat/dec'][ix]
            z_i = f['fit/zbest'][ix]
            chi2_i = f['fit/chi2_fit'][ix,:]
            zspec_i = f['cat/z_spec'][ix]
            ok_i = f['cat/ok_data'][ix,:]

        efnu_i = np.sqrt(efnu_orig**2 +
                         (self.param['SYS_ERR']*np.maximum(fnu_i, 0.))**2)

        return z_i, fnu_i, efnu_i, ra_i, dec_i, chi2_i, zspec_i, ok_i
    def get_lnp(self, ix):
        """
        Compute ln P(z) directly from chi2.  N.B. No prior implemented yet!
        """
        with h5py.File(self.h5file, 'r') as f:
            chi2_i = f['fit/chi2_fit'][ix,:]

        return -0.5*(chi2_i - np.nanmin(chi2_i))
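
    # Usage sketch (hypothetical filename; ``ix`` is an array index, not an
    # object id): pull the photometry and chi-squared(z) for one object and
    # convert the latter to a normalized ln P(z) on the redshift grid.
    #
    #   >>> v = hdf5.Viewer('photz_state.hdf5')
    #   >>> data = v.get_object_data(ix=100)
    #   >>> z_i, fnu_i, efnu_i, ra_i, dec_i, chi2_i, zspec_i, ok_i = data
    #   >>> lnp = v.get_lnp(ix=100)          # same shape as v.zgrid
    #   >>> zmap = v.zgrid[np.argmax(lnp)]   # redshift at the lnp peak
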
    def show_fit(self, id, **kwargs):
        """
        Wrapper around `eazy.photoz.PhotoZ.show_fit`
        """
        _ = photoz.PhotoZ.show_fit(self, id, **kwargs)
        return _
    def show_fit_plotly(self, id, **kwargs):
        """
        Wrapper around `eazy.photoz.PhotoZ.show_fit_plotly`
        """
        _ = photoz.PhotoZ.show_fit_plotly(self, id, **kwargs)
        return _
    @property
    def to_flam(self):
        """
        Conversion factor to :math:`10^{-19} erg/s/cm^2/Å`
        """
        to_flam = 10**(-0.4*(self.param['PRIOR_ABZP']+48.6))
        to_flam *= utils.CLIGHT*1.e10/1.e-19/self.pivot**2/self.ext_corr
        return to_flam


    @property
    def to_uJy(self):
        """
        Conversion of observed fluxes to `~astropy.units.microJansky`
        """
        return 10**(-0.4*(self.param['PRIOR_ABZP']-23.9))


    @property
    def lc(self):
        """
        Filter pivot wavelengths (deprecated, use `pivot`)
        """
        return self.pivot


    @property
    def ARRAY_DTYPE(self):
        """
        Array data type from `ARRAY_NBITS` parameter
        """
        if 'ARRAY_NBITS' in self.param:
            if self.param['ARRAY_NBITS'] == 64:
                ARRAY_DTYPE = np.float64
            else:
                ARRAY_DTYPE = np.float32
        else:
            ARRAY_DTYPE = np.float32

        return ARRAY_DTYPE


    @property
    def MW_EBV(self):
        """
        Galactic extinction E(B-V)
        """
        if 'MW_EBV' not in self.param:
            return 0. # 0.0354 # MACS0416
        else:
            return self.param['MW_EBV']
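
# Usage sketch (hypothetical filename and object id): the Viewer reproduces
# the PhotoZ plotting interface directly from the saved HDF5 file, without
# reloading the original catalog or redoing the fits.
#
#   >>> from eazy import hdf5
#   >>> v = hdf5.Viewer('photz_state.hdf5')
#   >>> v.info()             # summarize file contents
#   >>> _ = v.show_fit(1234) # object id from the catalog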