Source code for src.data_manager

from __future__ import absolute_import, division, print_function, unicode_literals
import os
from src import six
import glob
import copy
import shutil
from collections import defaultdict, namedtuple
from itertools import chain
from operator import attrgetter
from abc import ABCMeta, abstractmethod
import datetime
if os.name == 'posix' and six.PY2:
    try:
        from subprocess32 import CalledProcessError
    except ImportError:
        from subprocess import CalledProcessError
else:
    from subprocess import CalledProcessError
from src import util
from src import util_mdtf
from src import datelabel
from src import netcdf_helper
from src.shared_diagnostic import PodRequirementFailure


@six.python_2_unicode_compatible
class DataQueryFailure(Exception):
    """Exception signaling that requested data could not be found.

    In this module it is raised by ``query_dataset`` implementations (e.g.
    :meth:`LocalfileDataManager.query_dataset`) and by
    :meth:`DataManager._iter_populated_varlist`; it is caught in
    :meth:`DataManager._query_data` and :meth:`DataManager.fetch_data`.
    """

    def __init__(self, dataset, msg=''):
        """Record the dataset whose query failed and an explanatory message.

        Args:
            dataset: Object the query was made for. If it has a ``name``
                attribute, that name is included in the string form.
            msg: Human-readable description of the failure.
        """
        self.dataset = dataset
        self.msg = msg

    def __str__(self):
        """Return the message, prefixed with the dataset name if available."""
        if not hasattr(self.dataset, 'name'):
            return 'Query failure: {}.'.format(self.msg)
        return 'Query failure for {}: {}.'.format(self.dataset.name, self.msg)
@six.python_2_unicode_compatible
class DataAccessError(Exception):
    """Exception signaling a failure to obtain data from the remote location.

    In this module it is created by :meth:`DataManager.fetch_data` when
    ``fetch_dataset`` raises, and handed to
    :meth:`DataManager._fetch_exception_handler`.
    """

    def __init__(self, dataset, msg=''):
        """Record the object that could not be fetched and a message.

        Args:
            dataset: Object whose transfer failed. If it has a
                ``_remote_data`` attribute, that value is included in the
                string form.
            msg: Human-readable description of the failure.
        """
        self.dataset = dataset
        self.msg = msg

    def __str__(self):
        """Return the message, prefixed with the remote data if available."""
        if not hasattr(self.dataset, '_remote_data'):
            return 'Data access error: {}.'.format(self.msg)
        return 'Data access error for {}: {}.'.format(
            self.dataset._remote_data, self.msg)
class DataSet(util.NameSpace):
    """Class to describe datasets.

    Built on :class:`util.NameSpace`; see
    `<https://stackoverflow.com/a/48806603>`__ for implementation.
    """

    def __init__(self, *args, **kwargs):
        """Initialize default attributes, then apply ``args``/``kwargs``.

        Args:
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded to the parent constructor, except for the
                optional ``DateFreqMixin`` entry, which is removed and stored
                as ``self.DateFreq`` (default:
                :class:`datelabel.DateFrequency`).

        After the parent constructor runs, a ``var_name`` entry is moved to
        ``name`` and a ``freq`` entry is converted with ``self.DateFreq`` and
        moved to ``date_freq``, in each case only if the target is still
        ``None``.
        """
        # Take our own option out of kwargs so the parent never sees it.
        self.DateFreq = kwargs.pop('DateFreqMixin', datelabel.DateFrequency)

        # assign explicitly else linter complains
        self.name = None
        self.date_range = None
        self.date_freq = None
        self._local_data = None
        self._remote_data = []
        self.alternates = []
        self.axes = dict()

        super(DataSet, self).__init__(*args, **kwargs)

        if 'var_name' in self:
            if self.name is None:
                self.name = self.var_name
                del self.var_name
        if 'freq' in self:
            if self.date_freq is None:
                self.date_freq = self.DateFreq(self.freq)
                del self.freq

    def copy(self, new_name=None):
        """Return a copy made by the parent class, optionally renamed.

        Args:
            new_name: If not ``None``, assigned to the copy's ``name``.
        """
        duplicate = super(DataSet, self).copy()
        if new_name is None:
            return duplicate
        duplicate.name = new_name
        return duplicate

    @classmethod
    def from_pod_varlist(cls, pod_convention, var, dm_args):
        """Build a DataSet (and DataSets for its alternates) from a POD entry.

        Args:
            pod_convention: Naming convention passed to the translator's
                ``toCF`` to obtain each variable's ``CF_name``.
            var: Varlist entry; it is copied, updated with ``dm_args`` and
                expanded as keyword arguments to the constructor.
            dm_args: Extra constructor keyword arguments supplied by the
                data manager.

        Returns:
            The new DataSet. Its ``alternates`` attribute, initially a list
            of alternate names, is replaced by a list of DataSet copies, one
            per name, each with an empty ``alternates`` list.
        """
        translator = util_mdtf.VariableTranslator()
        settings = var.copy()
        settings.update(dm_args)

        primary = cls(**settings)
        primary.original_name = primary.name
        primary.CF_name = translator.toCF(pod_convention, primary.name)

        def _as_alternate(alt_name):
            # Copy of the primary dataset under the alternate's name.
            alt = primary.copy(new_name=alt_name)
            alt.original_name = primary.original_name
            alt.CF_name = translator.toCF(pod_convention, alt.name)
            alt.alternates = []
            return alt

        primary.alternates = [
            _as_alternate(alt_name) for alt_name in primary.alternates
        ]
        return primary

    def _freeze(self):
        """Return immutable representation of (current) attributes.

        Attributes whose names start with '_' are left out, so that DataSets
        differing only in such attributes (timestamps, temporary directories,
        etc.) compare as equal. Every retained value is stored as its
        ``repr()`` string in a namedtuple whose fields are the sorted
        attribute names.
        """
        attrs = self.toDict()
        public_keys = sorted(
            key for key in attrs if not key.startswith('_')
        )
        frozen_values = [repr(attrs[key]) for key in public_keys]
        FrozenDataSet = namedtuple('FrozenDataSet', public_keys)
        return FrozenDataSet(*frozen_values)
class DataManager(six.with_metaclass(ABCMeta)):
    """Abstract base class that locates and fetches model data for PODs.

    Subclasses must implement :meth:`query_dataset` and
    :meth:`fetch_dataset`. :meth:`remote_data_list`,
    :meth:`local_data_is_current`, :meth:`plan_data_fetch_hook` and
    :meth:`preprocess_local_data` are hooks with default behavior.
    """
    # analogue of TestFixture in xUnit

    # Optional key function used by remote_data_list() to order the fetch
    # list. Subclasses may override it; None selects the default ordering.
    _fetch_order_function = None

    def __init__(self, case_dict, DateFreqMixin=None):
        """Read case settings and global configuration.

        Args:
            case_dict: Mapping describing the case. Keys ``'CASENAME'``,
                ``'model'``, ``'FIRSTYR'``, ``'LASTYR'`` and ``'pod_list'``
                are required; ``'convention'`` (default ``'CF'``) and
                ``'data_freq'`` are optional.
            DateFreqMixin: Class used to build date-frequency objects;
                :class:`datelabel.DateFrequency` is used when this is falsy.
        """
        if not DateFreqMixin:
            self.DateFreq = datelabel.DateFrequency
        else:
            self.DateFreq = DateFreqMixin

        self.case_name = case_dict['CASENAME']
        self.model_name = case_dict['model']
        self.firstyr = datelabel.Date(case_dict['FIRSTYR'])
        self.lastyr = datelabel.Date(case_dict['LASTYR'])
        self.date_range = datelabel.DateRange(self.firstyr, self.lastyr)
        self.convention = case_dict.get('convention', 'CF')
        if 'data_freq' in case_dict:
            self.data_freq = self.DateFreq(case_dict['data_freq'])
        else:
            self.data_freq = None
        self.pod_list = case_dict['pod_list']
        self.pods = []

        config = util_mdtf.ConfigManager()
        self.envvars = config.global_envvars.copy()  # gets appended to
        # assign explicitly else linter complains
        self.dry_run = config.config.dry_run
        self.file_transfer_timeout = config.config.file_transfer_timeout
        self.make_variab_tar = config.config.make_variab_tar
        self.keep_temp = config.config.keep_temp
        self.overwrite = config.config.overwrite
        self.file_overwrite = self.overwrite  # overwrite config and .tar

        paths = config.paths.model_paths(case_dict, overwrite=self.overwrite)
        self.code_root = config.paths.CODE_ROOT
        self.MODEL_DATA_DIR = paths.MODEL_DATA_DIR
        self.MODEL_WK_DIR = paths.MODEL_WK_DIR
        self.MODEL_OUT_DIR = paths.MODEL_OUT_DIR
        self.TEMP_HTML = os.path.join(
            self.MODEL_WK_DIR, 'pod_output_temp.html')

        # dynamic inheritance to add netcdf manipulation functions
        # source: https://stackoverflow.com/a/8545134
        # The helper class is selected by name through the 'netcdf_helper'
        # config entry. (Previously the module object was passed as the key
        # and the result was discarded, so the setting was never honored.)
        helper_name = (
            config.config.get('netcdf_helper', None) or 'NcoNetcdfHelper'
        )
        mixin = getattr(netcdf_helper, helper_name)
        self.__class__ = type(
            self.__class__.__name__, (self.__class__, mixin), {})
        # make sure we have dependencies; nc_check_environ is presumably
        # supplied by the mixin, and any exception it raises propagates.
        self.nc_check_environ()

    def iter_pods(self):
        """Generator iterating over all pods which haven't been skipped due
        to requirement errors.
        """
        for pod in self.pods:
            if pod.skipped is None:
                yield pod

    def iter_vars(self):
        """Generator iterating over all variables in all pods which haven't
        been skipped due to requirement errors.
        """
        for pod in self.iter_pods():
            for var in pod.varlist:
                yield var

    # -------------------------------------

    def setUp(self, verbose=0):
        """Create working directories, set env vars and configure each pod.

        Args:
            verbose: Verbosity level forwarded to the ``util_mdtf`` helpers.

        Raises:
            AssertionError: If ``self.convention`` is not known to the
                variable translator.
        """
        util_mdtf.check_required_dirs(
            already_exist=[],
            create_if_nec=[self.MODEL_WK_DIR, self.MODEL_DATA_DIR],
            verbose=verbose)
        self.envvars.update({
            "DATADIR": self.MODEL_DATA_DIR,
            "variab_dir": self.MODEL_WK_DIR,
            "CASENAME": self.case_name,
            "model": self.model_name,
            "FIRSTYR": self.firstyr.format(precision=1),
            "LASTYR": self.lastyr.format(precision=1)
        })
        # set env vars for unit conversion factors (TODO: honest unit
        # conversion)
        translate = util_mdtf.VariableTranslator()
        if self.convention not in translate.units:
            raise AssertionError(("Variable name translation doesn't recognize "
                "{}.").format(self.convention))
        # Variable names first, then units, as in the original ordering.
        for table in (translate.variables, translate.units):
            entries = table[self.convention].to_dict()
            for key, val in entries.items():
                util_mdtf.setenv(key, val, self.envvars, verbose=verbose)

        for pod in self.iter_pods():
            self._setup_pod(pod)
        self._build_data_dicts()

    def _setup_pod(self, pod):
        """Transfer case settings to ``pod`` and convert its varlist.

        Replaces ``pod.varlist`` with :class:`DataSet` objects, fills in
        model-specific attributes on every variable and alternate, and marks
        the pod as skipped if the case provides data at a single frequency
        that differs from the frequency of any requested variable.
        """
        config = util_mdtf.ConfigManager()
        translate = util_mdtf.VariableTranslator()

        # transfer DataManager-specific settings
        pod.__dict__.update(config.paths.pod_paths(pod, self))
        pod.TEMP_HTML = self.TEMP_HTML
        pod.pod_env_vars.update(self.envvars)
        pod.dry_run = self.dry_run

        # express varlist as DataSet objects
        pod.varlist = [
            DataSet.from_pod_varlist(
                pod.convention, var, {'DateFreqMixin': self.DateFreq})
            for var in pod.varlist
        ]

        for var in pod.iter_vars_and_alts():
            var.name_in_model = translate.fromCF(self.convention, var.CF_name)
            var.date_range = self.date_range
            var._local_data = self.local_path(self.dataset_key(var))
            var.axes = copy.deepcopy(translate.axes[self.convention])

        if self.data_freq is not None:
            for var in pod.iter_vars_and_alts():
                if var.date_freq != self.data_freq:
                    pod.skipped = PodRequirementFailure(
                        pod,
                        ("{0} requests {1} (= {2}) at {3} frequency, which isn't "
                        "compatible with case {4} providing data at {5} frequency "
                        "only.").format(
                            pod.name, var.name_in_model, var.name,
                            var.date_freq, self.case_name, self.data_freq
                    ))
                    break

    @staticmethod
    def dataset_key(dataset):
        """Return immutable representation of DataSet.

        DataSets whose keys compare equal are grouped together by
        :meth:`_build_data_dicts`, and :meth:`_query_data` issues a single
        query per key.
        """
        return dataset._freeze()

    def local_path(self, data_key):
        """Returns the absolute path of the local copy of the file for
        dataset.

        This determines the local model data directory structure, which is
        ``<MODEL_DATA_DIR>/<freq>/<CASENAME>.<var name>.<freq>.nc``. Files
        not following this convention won't be found.

        Args:
            data_key: Namedtuple returned by :meth:`dataset_key`; it must
                have ``name_in_model`` and ``date_freq`` fields.
        """
        assert 'name_in_model' in data_key._fields
        assert 'date_freq' in data_key._fields
        # values in key are repr strings by default, so need to instantiate
        # the datelabel object to use its formatting method
        try:
            # value in key is from __str__
            freq = self.DateFreq(data_key.date_freq)
        except ValueError:
            # value in key is from __repr__
            # NOTE(review): eval() of a key value; safe only while keys are
            # produced internally by dataset_key() -- confirm no external
            # input can reach here.
            freq = eval('datelabel.'+data_key.date_freq)
        freq = freq.format_local()
        filename = "{}.{}.{}.nc".format(
            self.case_name, data_key.name_in_model, freq)
        return os.path.join(self.MODEL_DATA_DIR, freq, filename)

    def _build_data_dicts(self):
        """(Re)build the lookup tables keyed by :meth:`dataset_key`.

        Sets ``data_keys`` (key -> list of DataSets), ``data_pods``
        (key -> pods requesting it) and ``data_files`` (key -> remote data),
        considering only pods that haven't been skipped.
        """
        self.data_keys = defaultdict(list)
        self.data_pods = util.MultiMap()
        self.data_files = util.MultiMap()
        for pod in self.iter_pods():
            for var in pod.iter_vars_and_alts():
                key = self.dataset_key(var)
                self.data_pods[key].update({pod})
                self.data_keys[key].append(var)
                self.data_files[key].update(var._remote_data)

    # -------------------------------------

    def fetch_data(self):
        """Query for all requested data, then fetch whatever was found.

        Pods whose requests can't be satisfied (even through alternates) are
        marked as skipped with the :class:`DataQueryFailure`. Failures while
        fetching are wrapped in :class:`DataAccessError` and passed to
        :meth:`_fetch_exception_handler`; fetching then continues with the
        next file.
        """
        self._query_data()
        # populate vars with found files
        for data_key, key_vars in self.data_keys.items():
            for var in key_vars:
                var._remote_data.extend(list(self.data_files[data_key]))

        for pod in self.iter_pods():
            try:
                new_varlist = list(
                    self._iter_populated_varlist(pod.varlist, pod.name))
            except DataQueryFailure as exc:
                print("Data query failed on pod {}; skipping.".format(pod.name))
                pod.skipped = exc
                new_varlist = []
            for var in new_varlist:
                var.alternates = []
            pod.varlist = new_varlist

        # revise DataManager's to-do list, now that we've marked some PODs as
        # being skipped due to data unavailability
        self._build_data_dicts()

        self.plan_data_fetch_hook()

        # Messages below omit the final period: DataAccessError.__str__
        # appends one.
        for file_ in self.remote_data_list():
            try:
                self.fetch_dataset(file_)
            except CalledProcessError as caught_exc:
                exc = DataAccessError(
                    file_,
                    ("Running external command {} when fetching {} @ {} "
                    "returned error: {} (status {}). Did not retry").format(
                        caught_exc.cmd, file_.name_in_model, file_.date_freq,
                        caught_exc.output, caught_exc.returncode
                    )
                )
                self._fetch_exception_handler(exc)
            except Exception as caught_exc:
                # Deliberately broad: one failed transfer must not abort the
                # remaining ones.
                exc = DataAccessError(
                    file_,
                    ("Caught {} exception ({}) when fetching {} @ {}. "
                    "Did not retry").format(
                        type(caught_exc).__name__, caught_exc,
                        file_.name_in_model, file_.date_freq
                    )
                )
                self._fetch_exception_handler(exc)

    def _fetch_exception_handler(self, exc):
        """Print ``exc`` and mark every pod needing ``exc.dataset`` as skipped.
        """
        print(exc)
        keys_from_file = self.data_files.inverse()
        for key in keys_from_file[exc.dataset]:
            for pod in self.data_pods[key]:
                print(("\tSkipping pod {} due to data fetch error."
                    "").format(pod.name))
                pod.skipped = exc

    def _query_data(self):
        """Call :meth:`query_dataset` once per data key and record the files.

        The first DataSet stored under each key is used for the query. Keys
        whose query raises :class:`DataQueryFailure` are left unchanged.
        """
        for data_key in self.data_keys:
            try:
                var = self.data_keys[data_key][0]
                print("Calling query_dataset on {} @ {}".format(
                    var.name_in_model, var.date_freq))
                files = self.query_dataset(var)
                self.data_files[data_key].update(files)
            except DataQueryFailure:
                continue

    def _iter_populated_varlist(self, var_iter, pod_name):
        """Generator function yielding either a variable, its alternates if
        the variable was not found in the data query, or DataQueryFailure if
        the variable request can't be satisfied with found data.

        Args:
            var_iter: Iterable of DataSets to check.
            pod_name: Name of the requesting pod, used in messages only.

        Raises:
            DataQueryFailure: If a variable has no data and no alternates.
        """
        for var in var_iter:
            if var._remote_data:
                print("Found {} (= {}) @ {} for {}".format(
                    var.name_in_model, var.name, var.date_freq, pod_name
                ))
                yield var
            elif not var.alternates:
                raise DataQueryFailure(
                    var,
                    ("Couldn't find {} (= {}) @ {} for {} & no other "
                    "alternates").format(
                        var.name_in_model, var.name, var.date_freq, pod_name
                ))
            else:
                print(("Couldn't find {} (= {}) @ {} for {}, trying "
                    "alternates").format(
                        var.name_in_model, var.name, var.date_freq, pod_name
                ))
                for alt_var in self._iter_populated_varlist(
                        var.alternates, pod_name):
                    yield alt_var  # no 'yield from' in py2.7

    def remote_data_list(self):
        """Process list of requested data to make data fetching efficient.

        This is intended as a hook to be used by subclasses. Default behavior
        is to delete from the list duplicate datasets and datasets where a
        local copy of the data already exists and is current (as determined
        by :meth:`~data_manager.DataManager.local_data_is_current`).

        The result is sorted with ``_fetch_order_function`` as the key if a
        subclass defines one; otherwise by the ``_remote_data`` attribute if
        the entries have one; otherwise by the entries themselves.

        Returns: collection of :class:`~util.DataSet` objects.
        """
        # flatten list of all _remote_datas and remove duplicates
        unique_files = set(chain.from_iterable(self.data_files.values()))
        # filter out any data we've previously fetched that's up to date
        unique_files = [
            f for f in unique_files if not self.local_data_is_current(f)
        ]
        # fetch data in sorted order to make interpreting logs easier
        if unique_files:
            # elif (not a second 'if'): a subclass-supplied ordering must
            # take precedence instead of being overwritten by the defaults.
            if self._fetch_order_function is not None:
                sort_key = self._fetch_order_function
            elif hasattr(unique_files[0], '_remote_data'):
                sort_key = attrgetter('_remote_data')
            else:
                sort_key = None
            unique_files.sort(key=sort_key)
        return unique_files

    def local_data_is_current(self, dataset):
        """Determine if local copy of data needs to be refreshed.

        This is intended as a hook to be used by subclasses. Default is to
        always return `False`, ie always fetch remote data.

        Returns: `True` if local copy of data exists and remote copy hasn't
            been updated.
        """
        return False

    def plan_data_fetch_hook(self):
        """Hook called by :meth:`fetch_data` just before files are fetched.

        Default is to do nothing.
        """
        pass

    def preprocess_local_data(self, *args, **kwargs):
        """Hook for translation/transformation of data. Default does nothing.
        """
        # do translation/ transformations of data here
        pass

    # -------------------------------------
    # following are specific details that must be implemented in child class

    @abstractmethod
    def query_dataset(self, dataset):
        """Return an iterable of the data found for ``dataset``.

        :meth:`_query_data` adds the result to ``data_files`` and treats a
        raised :class:`DataQueryFailure` as "nothing found".
        """
        pass

    @abstractmethod
    def fetch_dataset(self, dataset):
        """Fetch one entry of the list returned by :meth:`remote_data_list`.
        """
        pass

    # -------------------------------------

    def tearDown(self):
        """Write the html index, back up the config, optionally make a tar
        file of the output, and move the working directory to its final
        location.
        """
        # TODO: handle OSErrors in all of these
        config = util_mdtf.ConfigManager()
        self._make_html()
        self._backup_config_file(config)
        if self.make_variab_tar:
            self._make_tar_file(config.paths.OUTPUT_DIR)
        self._copy_to_output()

    def _make_html(self, cleanup=True):
        """Assemble ``index.html`` in the working directory.

        The file is built from the header template, the accumulated
        ``TEMP_HTML`` file and the footer template; the banner image is
        copied alongside it.

        Args:
            cleanup: If true, delete ``TEMP_HTML`` after it has been used.
        """
        src_dir = os.path.join(self.code_root, 'src', 'html')
        dest = os.path.join(self.MODEL_WK_DIR, 'index.html')
        if os.path.isfile(dest):
            print("WARNING: index.html exists, deleting.")
            os.remove(dest)

        template_dict = self.envvars.copy()
        template_dict['DATE_TIME'] = \
            datetime.datetime.utcnow().strftime("%A, %d %B %Y %I:%M%p (UTC)")
        util_mdtf.append_html_template(
            os.path.join(src_dir, 'mdtf_header.html'), dest, template_dict
        )
        util_mdtf.append_html_template(self.TEMP_HTML, dest, {})
        util_mdtf.append_html_template(
            os.path.join(src_dir, 'mdtf_footer.html'), dest, template_dict
        )
        if cleanup:
            os.remove(self.TEMP_HTML)
        shutil.copy2(
            os.path.join(src_dir, 'mdtf_diag_banner.png'), self.MODEL_WK_DIR
        )

    def _backup_config_file(self, config):
        """Record settings in file variab_dir/config_save.json for rerunning.

        Args:
            config: Configuration object; ``config.config.toDict()`` is what
                gets written.

        Returns:
            Path of the file written. When ``file_overwrite`` is false this
            is the name returned by ``util_mdtf.bump_version``.
        """
        out_file = os.path.join(self.MODEL_WK_DIR, 'config_save.json')
        if not self.file_overwrite:
            out_file, _ = util_mdtf.bump_version(out_file)
        elif os.path.exists(out_file):
            print('Overwriting {}.'.format(out_file))
        util.write_json(config.config.toDict(), out_file)
        return out_file

    def _make_tar_file(self, tar_dest_dir):
        """Make tar file of web/bitmap output.

        Args:
            tar_dest_dir: Directory joined with ``MODEL_WK_DIR + '.tar'`` to
                form the output file name.

        Returns:
            Path of the tar file.
        """
        # NOTE(review): if MODEL_WK_DIR is an absolute path, os.path.join()
        # discards tar_dest_dir -- confirm the intended location.
        out_file = os.path.join(tar_dest_dir, self.MODEL_WK_DIR+'.tar')
        if not self.file_overwrite:
            out_file, _ = util_mdtf.bump_version(out_file)
            print("Creating {}.".format(out_file))
        elif os.path.exists(out_file):
            print('Overwriting {}.'.format(out_file))
        # NOTE(review): these patterns have no leading wildcard; verify that
        # they exclude files by extension as presumably intended.
        tar_flags = ' '.join(
            "--exclude=.{}".format(ext)
            for ext in ['netCDF', 'nc', 'ps', 'PS', 'eps']
        )
        util.run_shell_command(
            'tar {} -czf {} -C {} .'.format(
                tar_flags, out_file, self.MODEL_WK_DIR),
            dry_run=self.dry_run
        )
        return out_file

    def _copy_to_output(self):
        """Move the working directory to ``MODEL_OUT_DIR``.

        Nothing is done when both paths are equal. An existing
        ``MODEL_OUT_DIR`` is always removed first; a message is printed if
        that happens although ``overwrite`` is not set.
        """
        if self.MODEL_WK_DIR == self.MODEL_OUT_DIR:
            return  # no copying needed
        print("copy {} to {}".format(self.MODEL_WK_DIR, self.MODEL_OUT_DIR))
        if os.path.exists(self.MODEL_OUT_DIR):
            if not self.overwrite:
                print('Error: {} exists, overwriting anyway.'.format(
                    self.MODEL_OUT_DIR))
            shutil.rmtree(self.MODEL_OUT_DIR)
        shutil.move(self.MODEL_WK_DIR, self.MODEL_OUT_DIR)
class LocalfileDataManager(DataManager):
    """Data manager for files that already exist on the local filesystem."""
    # Assumes data files are already present in required directory structure

    # Key identifying a file: variable name in the model plus the string form
    # of the date frequency.
    DataKey = namedtuple('DataKey', ['name_in_model', 'date_freq'])

    def dataset_key(self, dataset):
        """Return the ``DataKey`` for ``dataset``.

        The key holds ``dataset.name_in_model`` and ``str(dataset.date_freq)``.
        """
        return self.DataKey(dataset.name_in_model, str(dataset.date_freq))

    def query_dataset(self, dataset):
        """Return a one-element list with the local path of ``dataset``.

        Raises:
            DataQueryFailure: If no file exists at the path given by
                ``local_path``.
        """
        path = self.local_path(self.dataset_key(dataset))
        if not os.path.isfile(path):
            raise DataQueryFailure(dataset, 'File not found at {}'.format(path))
        return [path]

    def local_data_is_current(self, dataset):
        """Always return ``True``: local files are never re-fetched."""
        return True

    def fetch_dataset(self, dataset):
        """Do nothing; there is no transfer step for local files."""
        pass