"""Common functions and classes used in multiple places in the MDTF code.
Specifically, util.py implements general functionality that's not MDTF-specific.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import io
from src import six
import re
import shlex
import glob
import shutil
import collections
from distutils.spawn import find_executable
if os.name == 'posix' and six.PY2:
try:
import subprocess32 as subprocess
except ImportError:
import subprocess
else:
import subprocess
import signal
import threading
import errno
import json
from six.moves import getcwd, collections_abc
class _Singleton(type):
    """Private metaclass used to build the :class:`~util.Singleton` base class.

    Instantiating any class created by this metaclass returns the one cached
    instance for that class, creating it on first use. Adapted from
    `<https://stackoverflow.com/a/6798042>`__; works on both Python 2 and 3.
    """
    # Registry shared by every class built with this metaclass: class -> instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of `cls`, constructing it if needed.

        Constructor arguments are only used on the first call; later calls
        ignore them and hand back the existing instance.
        """
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super(_Singleton, cls).__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
class Singleton(_Singleton(six.ensure_str('SingletonMeta'), (object,), {})):
    """Base class implementing the
    `Singleton <https://en.wikipedia.org/wiki/Singleton_pattern>`_ pattern.

    Subclasses have at most one instance; it is used as a safer way to pass
    around global state.
    """

    @classmethod
    def _reset(cls):
        """Forget the cached instance of this class, if there is one.

        Private method of all :class:`~util.Singleton`-derived classes,
        intended for unit testing only: calling it on test teardown means the
        next test constructs a fresh instance instead of inheriting state set
        by a previous test.
        """
        # pylint: disable=maybe-no-member
        cls._instances.pop(cls, None)
class ExceptionPropagatingThread(threading.Thread):
    """Thread subclass that hands the target's result, or the exception it
    raised, back to whichever thread calls :meth:`join`.

    Adapted from `<https://stackoverflow.com/a/31614591>`__.
    """

    def run(self):
        """Invoke the thread's target, recording its return value in
        ``self.ret`` or the exception it raised in ``self.exc``.
        """
        self.ret = None
        self.exc = None
        try:
            # Thread stores its target under name-mangled attributes prior to
            # Python 3 and under single-underscore attributes afterwards.
            py2_style = hasattr(self, '_Thread__target')
            if py2_style:
                result = self._Thread__target(
                    *self._Thread__args, **self._Thread__kwargs
                )
            else:
                result = self._target(*self._args, **self._kwargs)
            self.ret = result
        except BaseException as err:
            # Held until join() so it surfaces in the caller's thread.
            self.exc = err

    def join(self, timeout=None):
        """Wait for the thread, then re-raise any exception captured by
        :meth:`run`; otherwise return the target's return value.
        """
        super(ExceptionPropagatingThread, self).join(timeout)
        if not self.exc:
            return self.ret
        raise self.exc
class MultiMap(collections.defaultdict):
    """Extension of the :obj:`dict` class that allows doing dictionary lookups
    from either keys or values.

    Every value is stored as a :obj:`set`. Lookup from keys is unchanged
    (``bd['key']`` gives the set of values; :meth:`get_` unwraps a single
    value), while lookup from values is done through the :meth:`inverse` and
    :meth:`inverse_get_` methods, e.g. ``bd.inverse_get_('val')`` gives
    ``'key1'``, or a list of all matching keys if more than one is present.
    See `<https://stackoverflow.com/a/21894086>`__.
    """

    def __init__(self, *args, **kwargs):
        """Initialize :class:`~util.MultiMap` by passing an ordinary :py:obj:`dict`.
        """
        super(MultiMap, self).__init__(set, *args, **kwargs)
        # dict.__init__ bypasses our __setitem__, so coerce the initial values
        # to sets here. Only existing keys are reassigned, so size is unchanged.
        for key in list(self):
            super(MultiMap, self).__setitem__(key, coerce_to_iter(self[key], set))

    def __setitem__(self, key, value):
        """Store `value` under `key`, coerced to a :obj:`set`."""
        super(MultiMap, self).__setitem__(key, coerce_to_iter(value, set))

    def get_(self, key):
        """Return the value(s) for `key`: the bare value if there is exactly
        one, otherwise a list.

        Raises:
            KeyError: if `key` is not present. (Plain indexing would instead
                silently insert an empty set, since this is a defaultdict.)
        """
        # Hash-based membership test; does not trigger the default factory.
        if key not in self:
            raise KeyError(key)
        return coerce_from_iter(self[key])

    def to_dict(self):
        """Return an ordinary dict mapping each key to ``get_(key)``."""
        return {key: self.get_(key) for key in self}

    def inverse(self):
        """Return an ordinary dict mapping each stored value to the set of
        keys it appears under.
        """
        d = collections.defaultdict(set)
        for key, val_set in self.items():
            for v in val_set:
                d[v].add(key)
        return dict(d)

    def inverse_get_(self, val):
        """Return the key(s) whose value set contains `val`: the bare key if
        there is exactly one, otherwise a list.

        Raises:
            KeyError: if `val` is not stored under any key.
        """
        # NOTE(review): the original comment here read "don't raise keyerror
        # if empty", but inverse() returns a plain dict, so a missing value
        # does raise KeyError. That behavior is preserved; confirm intent.
        inv_lookup = self.inverse()
        return coerce_from_iter(inv_lookup[val])
class NameSpace(dict):
    """ A dictionary that provides attribute-style access.

    For example, `d['key'] = value` becomes `d.key = value`. All methods of
    :py:obj:`dict` are supported.

    Note: recursive access (`d.key.subkey`, as in C-style languages) is not
    supported.

    Implementation is based on `<https://github.com/Infinidat/munch>`__.
    """

    # only called if k not found in normal places
    def __getattr__(self, k):
        """ Gets key if it exists, otherwise throws AttributeError.
        nb. __getattr__ is only called if key is not found in normal places.
        """
        try:
            # Throws exception if not in prototype chain
            return object.__getattribute__(self, k)
        except AttributeError:
            try:
                return self[k]
            except KeyError:
                raise AttributeError(k)

    def __setattr__(self, k, v):
        """ Sets attribute k if it exists, otherwise sets key k. A KeyError
        raised by set-item (only likely if you subclass NameSpace) will
        propagate as an AttributeError instead.
        """
        try:
            # Throws exception if not in prototype chain
            object.__getattribute__(self, k)
        except AttributeError:
            try:
                self[k] = v
            except KeyError:
                # Only KeyError is translated, as documented; anything else
                # (including KeyboardInterrupt) propagates unchanged.
                raise AttributeError(k)
        else:
            object.__setattr__(self, k, v)

    def __delattr__(self, k):
        """ Deletes attribute k if it exists, otherwise deletes key k. A KeyError
        raised by deleting the key--such as when the key is missing--will
        propagate as an AttributeError instead.
        """
        try:
            # Throws exception if not in prototype chain
            object.__getattribute__(self, k)
        except AttributeError:
            try:
                del self[k]
            except KeyError:
                raise AttributeError(k)
        else:
            object.__delattr__(self, k)

    def __dir__(self):
        """List the dictionary keys as the object's attributes."""
        return list(self.keys())
    __members__ = __dir__  # for python2.x compatibility

    def __repr__(self):
        """ Invertible* string-form of a NameSpace.
        (*) Invertible so long as collection contents are each repr-invertible.
        """
        return '{0}({1})'.format(self.__class__.__name__, dict.__repr__(self))

    def __getstate__(self):
        """ Implement a serializable interface used for pickling.
        See `<https://docs.python.org/3.6/library/pickle.html>`__.
        """
        return dict(self.items())

    def __setstate__(self, state):
        """ Implement a serializable interface used for pickling.
        See `<https://docs.python.org/3.6/library/pickle.html>`__.
        """
        self.clear()
        self.update(state)

    def toDict(self):
        """ Recursively converts a NameSpace back into a dictionary.
        """
        return type(self)._toDict(self)

    @classmethod
    def _toDict(cls, x):
        """ Recursively converts a NameSpace back into a dictionary.
        nb. As dicts are not hashable, they cannot be nested in sets/frozensets.
        """
        if isinstance(x, dict):
            return dict((k, cls._toDict(v)) for k, v in x.items())
        elif isinstance(x, (list, tuple)):
            return type(x)(cls._toDict(v) for v in x)
        else:
            return x

    @property
    def __dict__(self):
        # Expose the contents as a plain dict to code that reads obj.__dict__.
        return self.toDict()

    @classmethod
    def fromDict(cls, x):
        """ Recursively transforms a dictionary into a NameSpace via copy.
        nb. As dicts are not hashable, they cannot be nested in sets/frozensets.
        """
        if isinstance(x, dict):
            return cls((k, cls.fromDict(v)) for k, v in x.items())
        elif isinstance(x, (list, tuple)):
            return type(x)(cls.fromDict(v) for v in x)
        else:
            return x

    def copy(self):
        """Return a recursive copy (nested dicts/lists/tuples are rebuilt)."""
        return type(self).fromDict(self)
    __copy__ = copy

    def _freeze(self):
        """Return immutable representation of (current) attributes.

        We do this to enable comparison of two Namespaces by content. The
        result is a tuple of ``(key, repr(value))`` pairs sorted by key, so
        that it is hashable even when values are not.

        Keys are part of the frozen form: the previous namedtuple-based
        implementation compared only the values (tuple equality ignores field
        names), so namespaces with different keys but equal values compared
        equal, and keys that are not valid Python identifiers raised
        ValueError.
        See `<https://stackoverflow.com/a/45170549>`__.
        """
        d = self.toDict()
        return tuple((k, repr(d[k])) for k in sorted(d))

    def __eq__(self, other):
        """Equal iff `other` is exactly the same type and has the same frozen
        contents (same keys, same value reprs)."""
        if type(other) is type(self):
            return (self._freeze() == other._freeze())
        else:
            return False

    def __ne__(self, other):
        return (not self.__eq__(other))  # more foolproof

    def __hash__(self):
        """Hash of the frozen contents; consistent with :meth:`__eq__`.
        nb. the hash changes if the NameSpace is mutated."""
        return hash(self._freeze())
# ------------------------------------
def read_json(file_path):
    """Read the file at `file_path` and parse its contents with
    :func:`parse_json`.

    Args:
        file_path (:py:obj:`str`): path of the JSON file to read (decoded as
            UTF-8).

    Returns: the value returned by :func:`parse_json`.

    Raises:
        AssertionError: if `file_path` doesn't exist.
        SystemExit: with exit status 1 if the file can't be read.
    """
    assert os.path.exists(file_path), \
        "Couldn't find JSON file {}.".format(file_path)
    try:
        with io.open(file_path, 'r', encoding='utf-8') as file_:
            str_ = file_.read()
    except IOError:
        print('Fatal IOError when trying to read {}. Exiting.'.format(file_path))
        # A bare exit() reports status 0 (success) to the calling shell; this
        # is a fatal error, so exit with a nonzero status.
        raise SystemExit(1)
    return parse_json(str_)
def parse_json(str_):
    """Parse a JSON string, after removing ``//`` comments from it.

    Args:
        str_ (:py:obj:`str`): text to parse; ``//``-delimited comments are
            removed with ``strip_comments`` before parsing.

    Returns: the parsed structure, with JSON objects returned as
        :py:obj:`collections.OrderedDict` so key order is preserved.

    Raises:
        SystemExit: with exit status 1 if decoding raises UnicodeDecodeError.
    """
    str_ = strip_comments(str_, delimiter= '//') # JSONC quasi-standard
    try:
        parsed_json = json.loads(str_, object_pairs_hook=collections.OrderedDict)
    except UnicodeDecodeError:
        print('{} contains non-ascii characters. Exiting.'.format(str_))
        # A bare exit() reports status 0 (success) to the calling shell; this
        # is a fatal error, so exit with a nonzero status.
        raise SystemExit(1)
    return parsed_json
def write_json(struct, file_path, verbose=0, sort_keys=False):
    """Serialize `struct` as indented JSON and write it to `file_path`.

    Wrapping file I/O simplifies unit testing.

    Args:
        struct (:py:obj:`dict`): structure to serialize.
        file_path (:py:obj:`str`): path of the JSON file to write (encoded as
            UTF-8).
        verbose (:py:obj:`int`, optional): Logging verbosity level. Default 0.
            Currently not used by this function.
        sort_keys (:py:obj:`bool`, optional): passed to :py:func:`json.dumps`.
            Default False.

    Raises:
        SystemExit: with exit status 1 if the file can't be written.
    """
    try:
        str_ = json.dumps(struct,
            sort_keys=sort_keys, indent=2, separators=(',', ': '))
        with io.open(file_path, 'w', encoding='utf-8') as file_:
            file_.write(six.ensure_text(str_, encoding='utf-8', errors='strict'))
    except IOError:
        print('Fatal IOError when trying to write {}. Exiting.'.format(file_path))
        # A bare exit() reports status 0 (success) to the calling shell; this
        # is a fatal error, so exit with a nonzero status.
        raise SystemExit(1)
def pretty_print_json(struct, sort_keys=False):
    """Render `struct` as pseudo-YAML text, for human-readable debugging
    output only. The result is not valid JSON.

    Args:
        struct: JSON-serializable structure to display.
        sort_keys (:py:obj:`bool`, optional): passed to :py:func:`json.dumps`.

    Returns: :py:obj:`str` with quotes, commas, brackets and closing braces
        removed, opening braces replaced by a leading dash, and blank lines
        dropped.
    """
    text = json.dumps(struct, sort_keys=sort_keys, indent=2)
    # Strip JSON punctuation one character at a time.
    for punctuation in ('"', ',', '}', '[', ']'):
        text = text.replace(punctuation, '')
    # An opening brace plus the whitespace after it becomes a list-item dash.
    text = re.sub(r"{\s+", "- ", text)
    # Keep only lines with visible content.
    visible_lines = (line for line in text.splitlines() if line.strip())
    return os.linesep.join(visible_lines)
def find_files(src_dirs, filename_globs):
    """Return list of files in `src_dirs` matching any of `filename_globs`.

    Wraps glob.glob for the use cases encountered in cleaning up POD output.

    Args:
        src_dirs: Directory, or a list of directories, to search for files in.
            The function will also search all subdirectories.
        filename_globs: Glob, or a list of globs, for filenames to match. This
            is a shell globbing pattern, not a full regex.

    Returns: :py:obj:`list` of paths to files matching any of the criteria,
        without duplicates and in no particular order. If no files are found,
        the list is empty.
    """
    def _glob_subdirs(pattern):
        # glob only treats '**' as "any depth of subdirectories" when
        # recursive=True is given; without it '**' behaves like '*' and only
        # one level of subdirectories is searched. The keyword doesn't exist
        # in older Pythons, where we fall back to the single-level search.
        try:
            return glob.glob(pattern, recursive=True)
        except TypeError:
            return glob.glob(pattern)

    src_dirs = coerce_to_iter(src_dirs)
    filename_globs = coerce_to_iter(filename_globs)
    files = set()
    for d in src_dirs:
        for g in filename_globs:
            files.update(glob.glob(os.path.join(d, g)))
            files.update(_glob_subdirs(os.path.join(d, '**', g)))
    return list(files)
def recursive_copy(src_files, src_root, dest_root, copy_function=None,
    overwrite=False):
    """Copy src_files to dest_root, preserving relative subdirectory structure.

    Copies a subset of files in a directory subtree rooted at src_root to an
    identical subtree structure rooted at dest_root, creating any subdirectories
    as needed. For example, `recursive_copy('/A/B/C.txt', '/A', '/D')` will
    first create the destination subdirectory `/D/B` and copy `/A/B/C.txt` to
    `/D/B/C.txt`.

    Args:
        src_files: Absolute path, or list of absolute paths, to files to copy.
        src_root: Root subtree of all files in src_files. Raises a ValueError
            if all files in src_files are not contained in the src_root directory.
        dest_root: Destination directory in which to create the copied subtree.
        copy_function: Function to use to copy individual files. Must take two
            arguments, the source and destination paths, respectively. Defaults
            to :py:meth:`shutil.copy2`.
        overwrite: Boolean, default False. If False, raise an OSError if
            any destination files already exist, otherwise silently overwrite.

    Raises:
        ValueError: if a source file isn't located under `src_root`.
        OSError: if `overwrite` is False and a destination file exists, or if
            a destination directory can't be created.
    """
    if copy_function is None:
        copy_function = shutil.copy2
    src_files = coerce_to_iter(src_files)
    rel_paths = []
    for f in src_files:
        # Compare path components rather than string prefixes: a plain
        # startswith() test would accept e.g. '/A/Bfoo/x' as being under '/A/B'.
        rel = os.path.relpath(f, start=src_root)
        if rel == os.pardir or rel.startswith(os.pardir + os.sep):
            raise ValueError('{} not a sub-path of {}'.format(f, src_root))
        rel_paths.append(rel)
    dest_files = [os.path.join(dest_root, rel) for rel in rel_paths]
    for f in dest_files:
        if not overwrite and os.path.exists(f):
            raise OSError('{} exists.'.format(f))
        dest_dir = os.path.normpath(os.path.dirname(f))
        # Equivalent of makedirs(exist_ok=True), spelled so that it also works
        # on Python 2, where that keyword doesn't exist.
        try:
            os.makedirs(dest_dir)
        except OSError as exc:
            if exc.errno != errno.EEXIST or not os.path.isdir(dest_dir):
                raise
    for src, dest in zip(src_files, dest_files):
        copy_function(src, dest)
def resolve_path(path, root_path="", env=None):
    """Abbreviation to resolve relative paths.

    Args:
        path (:obj:`str`): path to resolve. ``~`` and shell environment
            variables in it are expanded first.
        root_path (:obj:`str`, optional): root path to resolve `path` with. If
            not given, resolves relative to `cwd`. Must be absolute.
        env (:obj:`dict`, optional): additional ``$key``/``${key}``
            substitutions applied after shell environment variables.

    Returns: Absolute version of `path`, relative to `root_path` if given,
        otherwise relative to `os.getcwd`. An empty `path`, a path that still
        contains ``$`` after expansion (a warning is printed), or a path that
        is already absolute is returned without normalization.
    """
    def _substitute(text, lookup):
        """Replace $key and ${key} in `text` with ``lookup[key]``, similar to
        os.path.expandvars. Unrecognized variables are left untouched and
        escaped characters are not skipped.
        See `<https://stackoverflow.com/a/30777398>`__.
        """
        def _replacement(match):
            # group(2) is the name inside ${...}; group(1) is the bare name.
            var_name = match.group(2) or match.group(1)
            return lookup.get(var_name, match.group(0))

        return re.sub(r'\$(\w+|\{([^}]*)\})', _replacement, text)

    if path == '':
        return path  # default value set elsewhere

    # '~' -> home directory, then $VAR / ${VAR} from the shell environment.
    path = os.path.expandvars(os.path.expanduser(path))
    if isinstance(env, dict):
        path = _substitute(path, env)

    if '$' in path:
        print("Warning: couldn't resolve all env vars in '{}'".format(path))
        return path
    if os.path.isabs(path):
        return path

    base = getcwd() if root_path == "" else root_path
    assert os.path.isabs(base)
    return os.path.normpath(os.path.join(base, path))
def check_executable(exec_name):
    """Tests if <exec_name> is found on the current $PATH.

    Args:
        exec_name (:py:obj:`str`): Name of the executable to search for.

    Returns: :py:obj:`bool` True/false if executable was found on $PATH.
    """
    # Prefer shutil.which where it exists: unlike distutils' find_executable
    # it requires the file to actually be executable and doesn't report a
    # same-named file in the current directory as found. (distutils is also
    # deprecated and removed from newer Python releases.)
    which = getattr(shutil, 'which', None)
    if which is not None:
        return (which(exec_name) is not None)
    # Older Pythons without shutil.which.
    return (find_executable(exec_name) is not None)
def poll_command(command, shell=False, env=None):
    """Runs a shell command and prints stdout in real-time.

    Optional ability to pass a different environment to the subprocess. See
    documentation for the Python2 `subprocess
    <https://docs.python.org/2/library/subprocess.html>`_ module.

    Args:
        command: list of command + arguments, or the same as a single string.
            See `subprocess` syntax. Note this interacts with the `shell` setting.
        shell (:py:obj:`bool`, optional): shell flag, passed to Popen,
            default `False`.
        env (:py:obj:`dict`, optional): environment variables to set, passed to
            Popen, default `None`.

    Returns: :py:obj:`int` exit status of the command.
    """
    # universal_newlines=True makes the pipe yield text on Python 3 as well.
    # Without it readline() returns bytes there, which never compare equal to
    # '' at end-of-file (so the read loop never terminated) and which print as
    # b'...' literals.
    process = subprocess.Popen(
        command, shell=shell, env=env, stdout=subprocess.PIPE,
        universal_newlines=True)
    try:
        # readline() returns '' only at end-of-file; blank lines are '\n'.
        for output in iter(process.stdout.readline, ''):
            print(output.strip())
    finally:
        process.stdout.close()
    # stdout is exhausted; block until the process has actually exited.
    return process.wait()
class TimeoutAlarm(Exception):
    """Dummy exception for signal handling in run_command."""
def run_command(command, env=None, cwd=None, timeout=0, dry_run=False):
    """Subprocess wrapper to facilitate running single command without starting
    a shell.

    Note:
        We hope to save some process overhead by not running the command in a
        shell, but this means the command can't use piping, quoting, environment
        variables, or filename globbing etc.

    See documentation for the Python2 `subprocess
    <https://docs.python.org/2/library/subprocess.html>`_ module.

    Args:
        command (list of :py:obj:`str`): Command and arguments to execute. A
            single string is split with :py:func:`shlex.split`.
        env (:py:obj:`dict`, optional): environment variables to set, passed to
            `Popen`, default `None`.
        cwd (:py:obj:`str`, optional): child processes' working directory, passed
            to `Popen`. Default is `None`, which uses parent processes' directory.
        timeout (:py:obj:`int`, optional): Optionally, kill the command's subprocess
            and raise a CalledProcessError if the command doesn't finish in
            `timeout` seconds. The default of 0 means no timeout. A nonzero
            timeout relies on SIGALRM, so it can only be used from the main
            thread.
        dry_run (:py:obj:`bool`, optional): If True, only print the command
            that would be run and return `None`.

    Returns:
        :py:obj:`list` of :py:obj:`str` containing output that was written to stdout
        by each command. Note: this is split on newlines after the fact (or on
        NUL characters, if the output contains any).

    Raises:
        CalledProcessError: If any commands return with nonzero exit code.
            Stderr for that command is stored in `output` attribute.
    """
    def _timeout_handler(signum, frame):
        raise TimeoutAlarm

    def _kill_and_reap(child):
        # Kill the child and collect its exit status so it doesn't linger as a
        # zombie. Best-effort only: the child may already have exited.
        try:
            child.kill()
            child.wait()
        except OSError:
            pass

    if isinstance(command, six.string_types):
        command = shlex.split(command)
    cmd_str = ' '.join(command)
    if dry_run:
        print('DRY_RUN: call {}'.format(cmd_str))
        return
    proc = None
    pid = None
    retcode = 1
    stdout = ''
    stderr = ''
    alarm_armed = False
    prev_handler = None
    try:
        proc = subprocess.Popen(
            command, shell=False, env=env, cwd=cwd,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, bufsize=0
        )
        pid = proc.pid
        # py3 has timeout built into subprocess; this is a workaround.
        # Only touch the process-wide SIGALRM state when a timeout was
        # actually requested: signal.signal() raises outside the main thread,
        # which would otherwise fail every command run from a worker thread.
        timeout_sec = int(timeout)
        if timeout_sec > 0:
            prev_handler = signal.signal(signal.SIGALRM, _timeout_handler)
            alarm_armed = True
            signal.alarm(timeout_sec)
        (stdout, stderr) = proc.communicate()
        retcode = proc.returncode
    except TimeoutAlarm:
        if proc:
            _kill_and_reap(proc)
        retcode = errno.ETIME
        stderr = stderr+"\nKilled by timeout (>{}sec).".format(timeout)
    except Exception as exc:
        if proc:
            _kill_and_reap(proc)
        stderr = stderr+"\nCaught exception {0}({1!r})".format(
            type(exc).__name__, exc.args)
    finally:
        if alarm_armed:
            # Cancel the alarm on every exit path, so a stale alarm can't fire
            # in unrelated code later, and put back whatever handler the
            # caller had installed.
            signal.alarm(0)
            if prev_handler is None:
                # Previous handler wasn't installed from Python.
                prev_handler = signal.SIG_DFL
            signal.signal(signal.SIGALRM, prev_handler)
    if retcode != 0:
        print('run_command on {} (pid {}) exit status={}:{}\n'.format(
            cmd_str, pid, retcode, stderr
        ))
        raise subprocess.CalledProcessError(
            returncode=retcode, cmd=cmd_str, output=stderr)
    if '\0' in stdout:
        return stdout.split('\0')
    else:
        return stdout.splitlines()
def run_shell_command(command, env=None, cwd=None, dry_run=False):
    """Subprocess wrapper to facilitate running shell commands.

    See documentation for the Python2 `subprocess
    <https://docs.python.org/2/library/subprocess.html>`_ module.

    Args:
        command (:py:obj:`str` or list of :py:obj:`str`): Command to execute
            in the shell. A list is joined with spaces into a single string.
        env (:py:obj:`dict`, optional): environment variables to set, passed to
            `Popen`, default `None`.
        cwd (:py:obj:`str`, optional): child processes' working directory, passed
            to `Popen`. Default is `None`, which uses parent processes' directory.
        dry_run (:py:obj:`bool`, optional): If True, only print the command
            that would be run and return `None`.

    Returns:
        :py:obj:`list` of :py:obj:`str` containing output that was written to
        stdout, split on NUL characters if any are present and on newlines
        otherwise.

    Raises:
        CalledProcessError: If the command returns with nonzero exit code.
            Stderr for that command is stored in `output` attribute.
    """
    # shouldn't lookup on each invocation, but need abs path to bash in order
    # to pass as executable argument. Pass executable argument because we want
    # bash specifically (not default /bin/sh, and we save a bit of overhead by
    # starting bash directly instead of from sh.)
    bash_exec = find_executable('bash')
    if not isinstance(command, six.string_types):
        command = ' '.join(command)
    if dry_run:
        print('DRY_RUN: call {}'.format(command))
        return

    # Pessimistic defaults, overwritten only if the command runs to completion.
    child = None
    child_pid = None
    exit_status = 1
    err_text = ''
    try:
        child = subprocess.Popen(
            command,
            shell=True, executable=bash_exec,
            env=env, cwd=cwd,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            universal_newlines=True, bufsize=0
        )
        child_pid = child.pid
        (out_text, err_text) = child.communicate()
        exit_status = child.returncode
    except Exception as exc:
        if child:
            child.kill()
        err_text = err_text+"\nCaught exception {0}({1!r})".format(
            type(exc).__name__, exc.args)

    if exit_status != 0:
        print('run_shell_command on {} (pid {}) exit status={}:{}\n'.format(
            command, child_pid, exit_status, err_text
        ))
        raise subprocess.CalledProcessError(
            returncode=exit_status, cmd=command, output=err_text)
    # NUL-delimited output (e.g. from `find -print0`) takes precedence.
    if '\0' in out_text:
        return out_text.split('\0')
    return out_text.splitlines()
def is_iterable(obj):
    """Return True if `obj` is an iterable other than a string."""
    # py3 strings have __iter__, but we want to treat them as scalars.
    if isinstance(obj, six.string_types):
        return False
    return isinstance(obj, collections_abc.Iterable)
def coerce_to_iter(obj, coll_type=list):
    """Return `obj` as a collection of type `coll_type`.

    Args:
        obj: object to convert. `None` gives an empty collection, an instance
            of `coll_type` is returned as-is, any other non-string iterable is
            converted, and anything else is wrapped as the single element.
        coll_type: one of `list`, `set` or `tuple`. Default `list`.
    """
    assert coll_type in [list, set, tuple]  # only supported types for now
    if obj is None:
        return coll_type([])
    if isinstance(obj, coll_type):
        return obj
    members = obj if is_iterable(obj) else [obj]
    return coll_type(members)
def coerce_from_iter(obj):
    """Unwrap single-element collections.

    Args:
        obj: any object.

    Returns: `obj` itself if it isn't a (non-string) iterable; its only
        element if it's an iterable with exactly one element; otherwise a
        :py:obj:`list` of its elements.
    """
    if not is_iterable(obj):
        return obj
    # Materialize first, so that iterables without a len() (generators,
    # iterators) are handled instead of raising TypeError.
    items = list(obj)
    if len(items) == 1:
        return items[0]
    return items
def filter_kwargs(kwarg_dict, function):
    """Given a dict of kwargs, return only those kwargs accepted by function.

    Args:
        kwarg_dict (:py:obj:`dict`): candidate keyword arguments.
        function: function whose signature is inspected.

    Returns: :py:obj:`dict` with the entries of `kwarg_dict` whose key is the
        name of a parameter of `function`. Parameters named ``self``,
        ``args`` or ``kwargs`` are never included.
    """
    code = six.get_function_code(function)
    # co_varnames lists the named parameters first (positional, then
    # keyword-only), followed by the *args/**kwargs names and then every
    # local variable. Only the leading slice is the call signature; using the
    # whole tuple would let through kwargs that merely share a name with one
    # of the function's locals.
    n_params = code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)
    named_args = code.co_varnames[:n_params]
    # if 'kwargs' in named_args:
    #    return kwarg_dict # presumably can handle anything
    return dict((k, kwarg_dict[k]) for k in named_args \
        if k in kwarg_dict and k not in ['self', 'args', 'kwargs'])
def signal_logger(caller_name, signum=None, frame=None):
    """Print a debug message naming the signal that `caller_name` caught.

    The signal name is looked up from its number; see
    `<https://stackoverflow.com/a/2549950>`__. Nothing is printed if `signum`
    is None or 0.

    Args:
        caller_name (:py:obj:`str`): label identifying who caught the signal.
        signum (:py:obj:`int`, optional): signal number.
        frame (optional): stack frame passed to the signal handler; printed
            as-is.
    """
    if not signum:
        return
    # Map signal number -> name from the SIG* constants of the signal module
    # (excluding SIG_DFL-style names). Walking the names in ascending order
    # and keeping the first hit means aliases resolve to the alphabetically
    # first name.
    names_by_number = {}
    for name, number in sorted(signal.__dict__.items()):
        if name.startswith('SIG') and not name.startswith('SIG_'):
            names_by_number.setdefault(number, name)
    signal_name = names_by_number.get(signum, 'UNKNOWN')
    print("\tDEBUG: {} caught signal {} ({})".format(
        caller_name, signal_name, signum
    ))
    print("\tDEBUG: {}".format(frame))