Source code for eazy.param

import os
import collections
import numpy as np

__all__ = ["EazyParam", "TranslateFile", "read_param_file"]

def read_param_file(param_file=None, verbose=True):
    """
    Load a param file and add default parameters for any that are missing
    """
    param = EazyParam(param_file, verbose=verbose)

    if param_file is not None:
        # Read defaults
        defaults = EazyParam(None, verbose=False)
        for k in defaults.param_names:
            if k not in param.param_names:
                param[k] = defaults[k]
                if verbose:
                    print(f'Parameter default: {k} = {defaults[k]}')

    return param
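
# Illustrative usage of `read_param_file` (a minimal sketch; with
# ``param_file=None`` only the packaged defaults are loaded, so no local
# file is required):
#
#   >>> from eazy.param import read_param_file
#   >>> params = read_param_file(param_file=None, verbose=False)
#   >>> 'Z_STEP' in params.param_names
#   True
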
class EazyParam(object):
    def __init__(self, PARAM_FILE=None, verbose=True):
        """
        Read an Eazy zphot.param file.

        Example:

        >>> if os.path.exists('zphot.param'):
        ...     params = EazyParam(PARAM_FILE='zphot.param')
        ...     print(params['Z_STEP'])

        Defaults are in `eazy/data/zphot.param.default
        <https://github.com/gbrammer/eazy-py/blob/master/eazy/data/zphot.param.default>`_

        Parameters
        ----------
        PARAM_FILE : str
            Name of parameter file.  If None, then will get
            `data/zphot.param.default` from within the module.

        Attributes
        ----------
        params : `collections.OrderedDict`
            Parameter dictionary.  Don't modify this directly but rather use
            the `__getitem__` and `__setitem__` methods.

        param_names : list
            Keywords of the `params` dictionary.

        formats : `collections.OrderedDict`
            Dictionary indicating if parameters are interpreted as string
            ('s') or scalar float ('f') values.

        """
        if PARAM_FILE is None:
            PARAM_FILE = os.path.join(os.path.dirname(__file__),
                                      'data/zphot.param.default')
            if verbose:
                print('Read default param file: '+PARAM_FILE)

        self.filename = PARAM_FILE
        self.param_path = os.path.dirname(PARAM_FILE)

        f = open(PARAM_FILE, 'r')
        self.lines = f.readlines()
        f.close()

        self.params = collections.OrderedDict()
        self.formats = collections.OrderedDict()

        self._process_params()

    @property
    def param_names(self):
        """
        Keywords of the `params` dictionary
        """
        return list(self.params.keys())

    def _process_params(self):
        """
        Process parameter dictionary
        """
        params = collections.OrderedDict()
        formats = collections.OrderedDict()
        #self.param_names = []
        for line in self.lines:
            if not line.strip().startswith('#'):
                lsplit = line.split()
                if len(lsplit) >= 2:
                    params[lsplit[0]] = lsplit[1]
                    #self.param_names.append(lsplit[0])
                    try:
                        flt = float(lsplit[1])
                        formats[lsplit[0]] = 'f'
                        params[lsplit[0]] = flt
                    except ValueError:
                        formats[lsplit[0]] = 's'

        self.params = params
        self.formats = formats

    @property
    def to_mJy(self):
        """
        Return catalog conversion factor to mJy based on ``PRIOR_ABZP``.
        """
        return 10**(-0.4*(self.params['PRIOR_ABZP']-23.9))/1000.
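
    # Example of the ``to_mJy`` conversion (a sketch, not part of the
    # source): with ``PRIOR_ABZP = 23.9`` the catalog fluxes are in
    # microJansky, so the factor to mJy is 10**(-0.4*(23.9-23.9))/1000 = 1e-3.
    #
    #   >>> p = EazyParam(None, verbose=False)
    #   >>> p['PRIOR_ABZP'] = 23.9
    #   >>> p.to_mJy
    #   0.001
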
    def write(self, file=None):
        """
        Write to an ascii file
        """
        if file is None:
            print('No output file specified...')
        else:
            fp = open(file, 'w')
            for param in self.param_names:
                fp.write('{0:25s} {1}\n'.format(param, self.params[param]))

            fp.close()
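
    # Round-trip sketch for ``write`` (the output filename is illustrative
    # and assumes the working directory is writable):
    #
    #   >>> p = EazyParam(None, verbose=False)
    #   >>> p.write('zphot.param.copy')
    #   >>> p2 = EazyParam('zphot.param.copy', verbose=False)
    #   >>> p2['Z_MIN'] == p['Z_MIN']
    #   True
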
    def __getitem__(self, param_name):
        """
        Get item from ``params`` dict and return None if parameter not found.
        """
        if param_name.upper() not in self.param_names:
            print(f'Parameter {param_name} not found. Check `param_names`'
                  ' attribute.')
            return None
        else:
            return self.params[param_name.upper()]

    def __setitem__(self, param_name, value):
        """
        Set item in ``params`` dict.
        """
        self.params[param_name.upper()] = value
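
    # Item access is case-insensitive because keys are upper-cased in both
    # ``__getitem__`` and ``__setitem__``; a missing key prints a message and
    # returns None rather than raising KeyError (illustrative sketch):
    #
    #   >>> p = EazyParam(None, verbose=False)
    #   >>> p['z_step'] == p['Z_STEP']
    #   True
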
    def verify_params(self):
        """
        Some checks on the parameters
        """
        assert(self['Z_MAX'] > self['Z_MIN'])

        for k in ['TEMPLATES_FILE', 'TEMP_ERR_FILE', 'CATALOG_FILE',
                  'FILTERS_RES']:
            if isinstance(self[k], str):
                if not os.path.exists(self[k]):
                    raise FileNotFoundError(f'{k} ({self[k]}) not found')

        assert(int(self['ARRAY_NBITS']) in [32, 64])

        # Non-negative
        for k in ['TEMP_ERR_A2', 'SYS_ERR', 'IGM_SCALE_TAU', 'MW_EBV',
                  'OMEGA_M', 'OMEGA_L']:
            if self[k] < 0:
                raise ValueError(f'{k} ({self[k]}) must be >= 0')

        # Positive nonzero
        for k in ['Z_STEP', 'H0', 'RF_PADDING']:
            if self[k] <= 0:
                raise ValueError(f'{k} ({self[k]}) must be > 0')
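
    # Sketch of how ``verify_params`` reports a bad value (assumes the
    # TEMPLATES_FILE / CATALOG_FILE / etc. paths resolve, since the
    # file-existence checks run first):
    #
    #   >>> p['SYS_ERR'] = -0.1
    #   >>> p.verify_params()   # raises ValueError: SYS_ERR (-0.1) must be >= 0
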
    @property
    def kwargs(self):
        """
        Dictionary with lower-case parameter names for passing as ``**kwargs``
        """
        kws = collections.OrderedDict()
        for k in self.param_names:
            kws[k.lower()] = self.params[k]

        return kws
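
    # The lower-cased ``kwargs`` view is convenient for passing parameters to
    # functions that accept keyword arguments (sketch; ``my_function`` is a
    # hypothetical consumer, not part of eazy):
    #
    #   >>> p = EazyParam(None, verbose=False)
    #   >>> kws = p.kwargs
    #   >>> kws['z_step'] == p['Z_STEP']
    #   True
    #   >>> # my_function(**kws)
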
class TranslateFile():
    def __init__(self, file='zphot.translate'):
        """
        File for translating catalog columns and associating bandpasses with
        them.

        The `file` has the format

        .. code-block::

            flux_irac_ch1 F18
            err_irac_ch1 E18
            ...

        or is a CSV table with the format

        .. code-block::

            column, trans, error
            flux_irac_ch1, F18
            err_irac_ch1, E18, 1.0
            ...

        where `flux_irac_ch1` is a column in the catalog table corresponding
        to the IRAC 3.6 µm flux density.  ``F18`` indicates that this is a
        *flux density* column and is associated with filter number 18 in the
        `~eazy.filters.FilterFile`.  ``E18`` indicates an uncertainty column,
        and filters must have both flux density and uncertainty columns to be
        considered.

        The original catalog could have used the column names ``F18`` and
        ``E18`` directly and not needed a translate file, but it is generally
        preferable to have more descriptive column names that aren't
        necessarily tied to a particular `eazy` filter file.

        Note, similarly, that columns named like `F{N}` and `E{N}` are
        treated as these types of flux density and uncertainty columns.  If
        they correspond to something else, they should be "translated" to
        avoid confusion.
        """
        from astropy.table import Table

        self.file = file
        self.trans = collections.OrderedDict()
        self.error = collections.OrderedDict()

        if hasattr(file, 'colnames'):
            tr = file
            self.file = 'input_table.translate'
            if 'error' not in tr.colnames:
                tr['error'] = 1.0

            if tr.colnames != ['column', 'trans', 'error']:
                msg = f"table translate_file file must have columns"
                msg += f" 'column', 'trans' [, 'error']. The file {file}"
                msg += f' has columns {tr.colnames}.'
                raise ValueError(msg)

            for i, k in enumerate(tr['column']):
                self.trans[k] = tr['trans'][i]
                self.error[k] = tr['error'][i]

        elif file.endswith('csv'):
            tr = Table.read(file)
            if 'error' not in tr.colnames:
                tr['error'] = 1.0

            if tr.colnames != ['column', 'trans', 'error']:
                msg = f"csv translate_file file must have columns"
                msg += f" 'column', 'trans' [, 'error']. The file {file}"
                msg += f' has columns {tr.colnames}.'
                raise ValueError(msg)

            for i, k in enumerate(tr['column']):
                self.trans[k] = tr['trans'][i]
                self.error[k] = tr['error'][i]

        else:
            lines = open(file).readlines()

            for line in lines:
                spl = line.split()
                if (line.strip() == '') | (len(spl) < 2):
                    continue

                key = spl[0]
                self.trans[key] = spl[1]
                if len(spl) == 3:
                    self.error[key] = float(spl[2])
                else:
                    self.error[key] = 1.
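
    # Minimal sketch of building a translate file from an `astropy` table in
    # memory (the column names and filter numbers are illustrative):
    #
    #   >>> from astropy.table import Table
    #   >>> tab = Table()
    #   >>> tab['column'] = ['flux_irac_ch1', 'err_irac_ch1']
    #   >>> tab['trans'] = ['F18', 'E18']
    #   >>> tf = TranslateFile(tab)
    #   >>> tf.trans['flux_irac_ch1'] == 'F18'
    #   True
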
    def change_error(self, filter=88, value=1.e8):
        """
        Modify uncertainties based on error scaling factors in translate file
        """
        if isinstance(filter, str):
            if 'f_' in filter:
                err_filt = filter.replace('f_', 'e_')
            else:
                err_filt = 'e'+filter

            if err_filt in self.error:
                self.error[err_filt] = value
                return True

        if isinstance(filter, int):
            for key in self.trans.keys():
                if self.trans[key] == 'E{0:0d}'.format(filter):
                    self.error[key] = value
                    return True

        print('Filter {0} not found in list.'.format(str(filter)))
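
    # Sketch, continuing the table example shown after ``__init__``: inflate
    # the error scaling for one filter so it is effectively down-weighted
    # (filter number 18 is illustrative):
    #
    #   >>> tf.change_error(filter=18, value=1.e8)
    #   True
    #   >>> tf.error['err_irac_ch1']
    #   100000000.0
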
    def write(self, file=None, show_ones=False):
        """
        Write to an ascii file
        """
        lines = []
        for key in self.error:
            line = '{0} {1}'.format(key, self.trans[key])
            if self.trans[key].startswith('E') & ((self.error[key] != 1.0) | show_ones):
                line += ' {0:.1f}'.format(self.error[key])

            lines.append(line+'\n')

        if file is None:
            file = self.file

        if file:
            fp = open(file, 'w')
            fp.writelines(lines)
            fp.close()
        else:
            for line in lines:
                print(line[:-1])
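
    # ``write`` with ``file=None`` falls back to ``self.file``; passing an
    # empty string prints the translate lines instead of writing them
    # (sketch, continuing the table example above):
    #
    #   >>> tf.write(file='')
    #   flux_irac_ch1 F18
    #   err_irac_ch1 E18 100000000.0
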
    def to_csv(self):
        """
        Generate CSV string
        """
        rows = 'column,trans,error\n'
        for k in self.error:
            rows += f'{k},{self.trans[k]},{self.error[k]}\n'

        return rows
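
# ``to_csv`` returns the CSV text rather than writing a file; a sketch for
# saving it (the output filename is illustrative):
#
#   >>> csv_text = tf.to_csv()
#   >>> csv_text.splitlines()[0]
#   'column,trans,error'
#   >>> # open('zphot.translate.csv', 'w').write(csv_text)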