"""Extensions to Python :py:mod:`dataclasses`, for streamlined class definition.
"""
import collections
import copy
import dataclasses
import enum
import functools
import re
import typing
from . import basic
from . import exceptions
import logging
_log = logging.getLogger(__name__)
# The ClassMaker is cribbed from SO
# https://stackoverflow.com/questions/1176136/convert-string-to-python-class-object
# ClassMaker and the @catalog_class.maker decorator allow class instantiation from
# strings: the main block can look up and call the desired class by the string
# value of the convention argument, instead of using messy if/then/else blocks.
# Instantiate the class maker with catalog_class = ClassMaker()
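# A minimal sketch of that convention (the class name "MyParser" is hypothetical):
#
#   catalog_class = ClassMaker()
#
#   @catalog_class.maker
#   class MyParser:
#       pass
#
#   parser_cls = catalog_class['MyParser']  # look up the class by name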
class ClassMaker:
""" Class to instantiate other classes from strings"""
def __init__(self):
self.classes = {}
def add_class(self, c):
self.classes[c.__name__] = c
# define the class decorator to return the class passed
def maker(self, c):
self.add_class(c)
return c
def __getitem__(self, n):
return self.classes[n]
class RegexPatternBase:
"""Dummy parent class for :class:`RegexPattern` and
:class:`ChainedRegexPattern`.
"""
pass
class RegexPattern(collections.UserDict, RegexPatternBase):
"""Wraps :py:class:`re.Pattern` with more convenience methods for the use case
of parsing information in a string, using a regex with named capture groups
corresponding to the data fields being collected from the string.
"""
data: dict
fields: frozenset
input_string: str = ""
input_field: str = ""
is_matched: bool = False
_defaults: dict
def __init__(self, regex, defaults=None, input_field=None,
match_error_filter=None):
"""Constructor.
Args:
regex (str or :py:class:`re.Pattern`): regex to use for string
parsing. Should contain named match groups corresponding to the
fields to parse.
defaults (dict): Optional. If supplied, any fields not matched by the
named match groups in *regex* will be set equal to their values here.
input_field (str): Optional. If supplied, add a field to the match with
the supplied name which will be set equal to the contents of the
input string on a successful match.
match_error_filter (bool or :class:`RegexPattern` or :class:`ChainedRegexPattern`):
Optional. If supplied, determines whether a ValueError is raised
when the :meth:`match` method fails to parse a string (see below.)
Attributes:
data (dict): Key:value pairs corresponding to the contents of the
matching groups from the last successful call to :meth:`match`, or
empty if no successful call has been made. From
:py:class:`collections.UserDict`.
fields (frozenset): Set of fields matched by the pattern. Consists of the
union of named match groups in *regex*, and all keys in *defaults*.
input_string (str): Contains string that was input to last call of
:meth:`match`, whether successful or not.
is_matched (bool): True if the last call to :meth:`match` was
successful, False otherwise.
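Example (a minimal usage sketch; the pattern and field names are
hypothetical)::

    pat = RegexPattern(r"(?P<year>\d{4})-(?P<month>\d{2})",
        input_field="date_str")
    pat.match("2000-01")
    assert pat.is_matched
    assert pat.data == {'year': '2000', 'month': '01',
        'date_str': '2000-01'}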
"""
self.data = dict()
self.input_string = ""
self.is_matched = False
try:
if isinstance(regex, re.Pattern):
self.regex = regex
else:
self.regex = re.compile(regex, re.VERBOSE)
except re.error as exc:
raise ValueError('Malformed input regex.') from exc
if self.regex.groups != len(self.regex.groupindex):
# _log.warning("Unnamed match groups in regex")
pass
if self.regex.groups == 0:
# _log.warning("No named match groups in regex")
pass
if not defaults:
self._defaults = dict()
else:
self._defaults = defaults.copy()
self.input_field = input_field
self._match_error_filter = match_error_filter
self._update_fields()
def clear(self):
"""Erase field values parsed from a pre-existing match.
"""
self.data = dict()
self.input_string = ""
self.is_matched = False
def _update_fields(self):
self.regex_fields = frozenset(self.regex.groupindex.keys())
self.fields = self.regex_fields.union(self._defaults.keys())
if self.input_field:
self.fields = self.fields.union((self.input_field,))
self.clear()
def update_defaults(self, d):
"""Update the default values used for the match with the values in *d*.
"""
if d:
self._defaults.update(d)
self._update_fields()
def match(self, str_, *args):
"""Match *str_* using Python :py:func:`re.fullmatch` with *regex* and
populate object's fields according to the values captured by the named
capture groups in *regex*.
Args:
str_ (str): Input string to parse.
args: Optional. Additional positional arguments (*pos*, *endpos*) passed
to the :py:meth:`re.Pattern.fullmatch` method of *regex* and of
*match_error_filter* (if defined).
Raises:
:class:`~exceptions.RegexParseError`: If :meth:`match` fails to parse
the input string and *match_error_filter* doesn't suppress the error:
i.e., if *match_error_filter* wasn't supplied (the default) or is
False, or if it's a :class:`RegexPattern` or
:class:`ChainedRegexPattern` that also fails to match the input
string.
:class:`~exceptions.RegexSuppressedError`: If :meth:`match` fails to
parse the input string and *match_error_filter* suppresses the error:
i.e., if *match_error_filter* is True, or if it's a
:class:`RegexPattern` or :class:`ChainedRegexPattern` that
successfully matches the input string. One of RegexParseError or
RegexSuppressedError is always raised when parsing fails.
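Example (a sketch of the two failure modes, with a hypothetical pattern)::

    pat = RegexPattern(r"(?P<year>\d{4})")
    pat.match("abcd")    # raises RegexParseError
    pat = RegexPattern(r"(?P<year>\d{4})", match_error_filter=True)
    pat.match("abcd")    # raises RegexSuppressedError instead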
"""
self.clear() # to be safe
self.input_string = str_
m = self.regex.fullmatch(str_, *args)
if not m:
self.is_matched = False
if hasattr(self._match_error_filter, 'match'):
# if the filter pattern matches, suppress the error; if the filter
# pattern also fails, raise RegexParseError
try:
self._match_error_filter.match(str_, *args)
except Exception as exc:
raise exceptions.RegexParseError(
f"Couldn't match {str_} against {self.regex}.") from exc
raise exceptions.RegexSuppressedError(str_)
elif self._match_error_filter:
raise exceptions.RegexSuppressedError(str_)
else:
raise exceptions.RegexParseError(
f"Couldn't match {str_} against {self.regex}.")
else:
self.data = m.groupdict(default=NOTSET)
for k, v in self._defaults.items():
if self.data.get(k, NOTSET) is NOTSET:
self.data[k] = v
if self.input_field:
self.data[self.input_field] = m.string
self._validate_match(m)
if any(self.data[f] is NOTSET for f in self.fields):
bad_names = [f for f in self.fields if self.data[f] is NOTSET]
raise exceptions.RegexParseError((f"Couldn't match the "
f"following fields in {str_}: " + ', '.join(bad_names)))
self.is_matched = True
def _validate_match(self, match_obj):
"""Hook for post-processing of match, running after all fields are
assigned but before final check that all fields are set.
"""
pass
def __str__(self):
if not self.is_matched:
str_ = ', '.join(self.fields)
else:
str_ = ', '.join([f'{k}={v}' for k, v in self.data.items()])
return f"<{self.__class__.__name__}({str_})>"
def __copy__(self):
if hasattr(self._match_error_filter, 'copy'):
match_error_filter_copy = self._match_error_filter.copy()
else:
# bool or None
match_error_filter_copy = self._match_error_filter
obj = self.__class__(
self.regex.pattern,
defaults=self._defaults.copy(),
input_field=self.input_field,
match_error_filter=match_error_filter_copy,
)
obj.data = self.data.copy()
return obj
def __deepcopy__(self, memo):
obj = self.__class__(
copy.deepcopy(self.regex.pattern, memo),
defaults=copy.deepcopy(self._defaults, memo),
input_field=copy.deepcopy(self.input_field, memo),
match_error_filter=copy.deepcopy(self._match_error_filter, memo)
)
obj.data = copy.deepcopy(self.data, memo)
return obj
class RegexPatternWithTemplate(RegexPattern):
"""Adds formatted output to :class:`RegexPattern`.
"""
template: str = ""
def __init__(self, regex, defaults=None, input_field=None,
match_error_filter=None, template=None, log=_log):
"""Constructor.
Args:
template (str): Optional. Template string to use for formatting the
contents of a match in the :meth:`format` method. Values of the matched
fields will be substituted using the {}-syntax of python string
formatting.
log: Optional. Logger to use for warnings about fields missing from
*template*.
Other arguments are the same as in :class:`RegexPattern`.
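Example (a usage sketch; the pattern and template are hypothetical)::

    pat = RegexPatternWithTemplate(r"(?P<year>\d{4})-(?P<month>\d{2})",
        template="{month}/{year}")
    pat.match("2000-01")
    pat.template.format(**pat.data)   # returns '01/2000'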
"""
super(RegexPatternWithTemplate, self).__init__(regex, defaults=defaults,
input_field=input_field,
match_error_filter=match_error_filter)
self.template = template
if self.template is not None:
for f in self.fields:
if f not in self.template:
log.warning("Field %s not included in output template.", f)
def __copy__(self):
if hasattr(self._match_error_filter, 'copy'):
match_error_filter_copy = self._match_error_filter.copy()
else:
# bool or None
match_error_filter_copy = self._match_error_filter
obj = self.__class__(
self.regex.pattern,
defaults=self._defaults.copy(),
input_field=self.input_field,
match_error_filter=match_error_filter_copy,
template=self.template
)
obj.data = self.data.copy()
return obj
def __deepcopy__(self, memo):
obj = self.__class__(
copy.deepcopy(self.regex.pattern, memo),
defaults=copy.deepcopy(self._defaults, memo),
input_field=copy.deepcopy(self.input_field, memo),
match_error_filter=copy.deepcopy(self._match_error_filter, memo),
template=copy.deepcopy(self.template, memo)
)
obj.data = copy.deepcopy(self.data, memo)
return obj
class ChainedRegexPattern(RegexPatternBase):
"""Class which takes an 'or' of multiple :class:`RegexPatterns to parse
data that may be represented as a string in one of multiple formats.
Matches are attempted on the supplied RegexPatterns in order, with the first
one that succeeds determining the parsed field values. Public methods work
the same as on :class:`RegexPattern`.
"""
def __init__(self, *string_patterns, defaults=None, input_field=None,
match_error_filter=None):
"""Constructor.
Args:
string_patterns (iterable of :class:`RegexPattern`): Individual
regexes which will be tried, in order, when :meth:`match` is
called. Parsing will be done by the first RegexPattern whose
:meth:`match` succeeds.
.. note::
The constructor changes attributes on :class:`RegexPattern` objects
passed as *string_patterns*, so once the object is created its
component :class:`RegexPattern` objects shouldn't be accessed on
their own.
Other arguments and attributes are the same as in :class:`RegexPattern`.
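Example (a sketch with hypothetical patterns; note that both patterns
must parse the same set of fields)::

    p1 = RegexPattern(r"(?P<year>\d{4})-(?P<month>\d{2})")
    p2 = RegexPattern(r"(?P<month>\d{2})/(?P<year>\d{4})")
    chained = ChainedRegexPattern(p1, p2)
    chained.match("01/2000")   # p1 fails to match; p2 parses the string
    assert chained.data['year'] == '2000'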
"""
# NB, changes attributes on patterns passed as arguments, so
# once created they can't be used on their own
new_pats = []
self.input_string = ""
self._match = -1
for pat in string_patterns:
if isinstance(pat, RegexPattern):
new_pats.append(pat)
elif isinstance(pat, ChainedRegexPattern):
# flatten any ChainedRegexPatterns passed as input
new_pats.extend(pat._patterns)
else:
raise ValueError(f"Couldn't interpret {pat} as a RegexPattern.")
self._patterns = tuple(new_pats)
self.input_field = input_field
self._match_error_filter = match_error_filter
for pat in self._patterns:
if defaults:
pat.update_defaults(defaults)
if input_field:
pat.input_field = input_field
pat._match_error_filter = None
pat._update_fields()
self._update_fields()
@property
def is_matched(self):
return self._match >= 0
@property
def data(self):
if self.is_matched:
return self._patterns[self._match].data
else:
return dict()
def clear(self):
for pat in self._patterns:
pat.clear()
self._match = -1
self.input_string = ""
def _update_fields(self):
self.fields = self._patterns[0].fields
for pat in self._patterns:
if pat.fields != self.fields:
raise ValueError("Incompatible fields.")
self.clear()
def update_defaults(self, d):
if d:
for pat in self._patterns:
pat.update_defaults(d)
self._update_fields()
def match(self, str_, *args):
self.clear()
self.input_string = str_
for i, pat in enumerate(self._patterns):
try:
pat.match(str_, *args)
if not pat.is_matched:
raise ValueError()
self._match = i
break # stop at the first pattern that matches
except ValueError:
continue
if not self.is_matched:
if hasattr(self._match_error_filter, 'match'):
try:
self._match_error_filter.match(str_, *args)
except Exception as exc:
raise exceptions.RegexParseError((f"Couldn't match {str_} "
f"against any pattern in {self.__class__.__name__}.")) from exc
raise exceptions.RegexSuppressedError(str_)
elif self._match_error_filter:
raise exceptions.RegexSuppressedError(str_)
else:
raise exceptions.RegexParseError((f"Couldn't match {str_} "
f"against any pattern in {self.__class__.__name__}."))
def __str__(self):
if not self.is_matched:
str_ = ', '.join(self.fields)
else:
str_ = ', '.join([f'{k}={v}' for k, v in self.data.items()])
return f"<{self.__class__.__name__}({str_})>"
def __copy__(self):
new_pats = (copy.copy(pat) for pat in self._patterns)
if hasattr(self._match_error_filter, 'copy'):
match_error_filter_copy = self._match_error_filter.copy()
else:
# bool or None
match_error_filter_copy = self._match_error_filter
return self.__class__(
*new_pats,
match_error_filter=match_error_filter_copy
)
def __deepcopy__(self, memo):
new_pats = (copy.deepcopy(pat, memo) for pat in self._patterns)
return self.__class__(
*new_pats,
match_error_filter=copy.deepcopy(self._match_error_filter, memo)
)
# ---------------------------------------------------------
NOTSET = basic.sentinel_object_factory('NotSet')
"""
Sentinel object to detect uninitialized values for fields in :func:`mdtf_dataclass`
objects, for use in cases where ``None`` is a valid value for the field.
"""
MANDATORY = basic.sentinel_object_factory('Mandatory')
"""
Sentinel object to mark all :func:`mdtf_dataclass` fields that do not take a default
value. This is a workaround to avoid errors with non-default fields coming after
default fields in the dataclass auto-generated ``__init__`` method under
`inheritance <https://docs.python.org/3/library/dataclasses.html#inheritance>`__:
we use the second solution described in `<https://stackoverflow.com/a/53085935>`__.
"""
def _mdtf_dataclass_get_field_types(obj, f, log):
"""Common functionality for :func:`_mdtf_dataclass_type_coercion` and
:func:`_mdtf_dataclass_type_check`. Given a :py:class:`dataclasses.Field`
object *f*, return a tuple of the type its value should be coerced to
and a list of the valid types its value can have, or (None, None) to signal
a case we don't handle.
"""
if not f.init:
# ignore fields that aren't handled at init
return None, None
value = getattr(obj, f.name)
# ignore unset field values, regardless of type
if value is None or value is NOTSET:
return None, None
# guess what types are valid
new_type = None
if f.type is typing.Any or isinstance(f.type, typing.TypeVar):
return None, None
if dataclasses.is_dataclass(f.type):
# ignore if type is a dataclass: use this type annotation to
# implement dataclass inheritance
if not isinstance(obj, f.type):
raise exceptions.DataclassParseError((f"Field {f.name} specified "
f"as dataclass {f.type.__name__}, which isn't a parent class "
f"of {obj.__class__.__name__}."))
return None, None
elif isinstance(f.type, typing._GenericAlias) \
or isinstance(f.type, typing._SpecialForm):
# type is a generic from typing module, eg "typing.List"
if f.type.__origin__ is typing.Union:
new_type = None # can't do coercion, but can test type
valid_types = list(f.type.__args__)
else:
try:
new_type = f.type.__origin__
valid_types = [new_type]
except Exception as exc:
log.debug("Caught exception when checking types for %s: %r; "
"returning (None, None).", f.type, exc)
return None, None # can't do anything in this case
else:
new_type = f.type
valid_types = [new_type]
# Get types of field's default value, if present. Dataclass doesn't
# require defaults to be same type as what's given for field.
if not isinstance(f.default, dataclasses._MISSING_TYPE):
valid_types.append(type(f.default))
if not isinstance(f.default_factory, dataclasses._MISSING_TYPE):
valid_types.append(type(f.default_factory()))
return new_type, valid_types
def _mdtf_dataclass_type_coercion(self, log):
"""Do type checking on all dataclass fields after the auto-generated
``__init__`` method, but before any ``__post_init__`` method.
.. warning::
Type checking logic used is specific to the ``typing`` module in python
3.7. It may or may not work on newer pythons, and definitely will not
work with 3.5 or 3.6. See `<https://stackoverflow.com/a/52664522>`__.
"""
for f in dataclasses.fields(self):
value = getattr(self, f.name, NOTSET)
new_type, valid_types = _mdtf_dataclass_get_field_types(self, f, log)
try:
if valid_types is None or isinstance(value, tuple(valid_types)):
continue # don't coerce if we're already a valid type
if new_type is None or hasattr(new_type, '__abstract_methods__'):
continue # can't do type coercion
else:
if hasattr(new_type, 'from_struct'):
new_value = new_type.from_struct(value)
elif isinstance(new_type, type) and issubclass(new_type, enum.Enum):
# need to use item syntax to create an enum member from its name
new_value = new_type[value]
else:
new_value = new_type(value)
# https://stackoverflow.com/a/54119384 for implementation
object.__setattr__(self, f.name, new_value)
except (TypeError, ValueError, dataclasses.FrozenInstanceError) as exc:
raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
f"Couldn't coerce value {repr(value)} for field {f.name} from "
f"type {type(value)} to type {new_type}.")) from exc
except Exception as exc:
log.exception("%s: Caught exception: %r", self.__class__.__name__, exc)
raise exc
def _mdtf_dataclass_type_check(self, log):
"""Do type checking on all dataclass fields after ``__init__`` and
``__post_init__`` methods.
.. warning::
Type checking logic used is specific to the ``typing`` module in python
3.7. It may or may not work on newer pythons, and definitely will not
work with 3.5 or 3.6. See `<https://stackoverflow.com/a/52664522>`__.
"""
for f in dataclasses.fields(self):
value = getattr(self, f.name, NOTSET)
if value is None or value is NOTSET:
continue
if value is MANDATORY:
raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
f"No value supplied for mandatory field {f.name}."))
_, valid_types = _mdtf_dataclass_get_field_types(self, f, log)
if valid_types is not None and not isinstance(value, tuple(valid_types)):
log.exception("%s: Failed type check for field '%s': %s != %s.",
self.__class__.__name__, f.name, type(value), valid_types)
raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
f"Expected {f.name} to be {f.type}, got {type(value)} "
f"({repr(value)})."))
DEFAULT_MDTF_DATACLASS_KWARGS = {'init': True, 'repr': True, 'eq': True,
'order': False, 'unsafe_hash': False, 'frozen': False}
# declaration to allow calling with and without args: python cookbook 9.6
# https://github.com/dabeaz/python-cookbook/blob/master/src/9/defining_a_decorator_that_takes_an_optional_argument/example.py
def mdtf_dataclass(cls=None, **deco_kwargs):
"""Wrap the Python :py:func:`~dataclasses.dataclass` class decorator to customize
dataclasses to provide rudimentary type checking and conversion. This
is hacky, since dataclasses don't enforce type annotations for their fields.
A better solution would be to use the third-party
`cattrs <https://github.com/Tinche/cattrs>`__ package, which has essentially
the same aim.
The decorator rewrites the class's constructor as follows:
1. Execute the auto-generated ``__init__`` method from Python
:py:func:`~dataclasses.dataclass`.
2. Check each field's value to see if it's consistent with the given
type information. If not, attempt to coerce it to that type, using a
``from_struct`` method on that type if it exists.
3. Execute the class's ``__post_init__`` method, if defined, which can do
more complex type coercion and validation.
4. Finally, verify that fields with the ``MANDATORY`` default have been
assigned values, and that all field values are consistent with their
annotated types. We have to work around the usual
:py:func:`~dataclasses.dataclass` way of handling mandatory fields,
because it leads to errors in the signature of the auto-generated
``__init__`` method under inheritance (mandatory fields can't come after
optional fields in the signature.)
.. warning::
Unlike :py:func:`~dataclasses.dataclass`, all fields **must** have a
*default* or *default_factory* defined. Fields which are mandatory must
have their default value set to the sentinel object ``MANDATORY``.
This is necessary in order for dataclass inheritance to work properly, and
is not currently enforced when the class is decorated.
Args:
cls (class): Class to be decorated.
deco_kwargs: Optional. Keyword arguments to pass to the Python
:py:func:`~dataclasses.dataclass` class decorator.
Raises:
:class:`~exceptions.DataclassParseError`: If we attempted to construct an
instance without giving values for ``MANDATORY`` fields, or if values
of some fields after ``__post_init__`` could not be coerced into the
types given in their annotation.
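Example (a minimal sketch; the class and field names are hypothetical)::

    @mdtf_dataclass
    class Config:
        name: str = MANDATORY
        threshold: float = 0.5

    c = Config(name="test", threshold="0.25")
    assert c.threshold == 0.25   # str value coerced to float
    Config()   # raises DataclassParseError: no value for MANDATORY 'name'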
"""
dc_kwargs = DEFAULT_MDTF_DATACLASS_KWARGS.copy()
dc_kwargs.update(deco_kwargs)
if cls is None:
# called without arguments
return functools.partial(mdtf_dataclass, **dc_kwargs)
if not hasattr(cls, '__post_init__'):
# create dummy __post_init__ if none defined, so we can wrap it.
# contrast with what we do below in regex_dataclass()
def _dummy_post_init(self, *args, **kwargs): pass
type.__setattr__(cls, '__post_init__', _dummy_post_init)
# apply dataclasses' decorator
cls = dataclasses.dataclass(cls, **dc_kwargs)
# Do type coercion after dataclass' __init__, but before user __post_init__
# Do type check after __init__ and __post_init__
_old_post_init = cls.__post_init__
@functools.wraps(_old_post_init)
def _new_post_init(self, *args, **kwargs):
if hasattr(self, 'log'):
_post_init_log = self.log # for object hierarchy
else:
_post_init_log = _log # fallback: use module-level logger
_mdtf_dataclass_type_coercion(self, _post_init_log)
_old_post_init(self, *args, **kwargs)
_mdtf_dataclass_type_check(self, _post_init_log)
type.__setattr__(cls, '__post_init__', _new_post_init)
return cls
def is_regex_dataclass(obj):
"""Returns True if *obj* is a :func:`regex_dataclass`.
"""
return hasattr(obj, '_is_regex_dataclass') and obj._is_regex_dataclass == True
def _regex_dataclass_preprocess_kwargs(self, kwargs):
"""Edit kwargs going to the auto-generated __init__ method of this dataclass.
If any fields are regex_dataclasses, construct and parse their values first.
Raises a DataclassParseError if different regex_dataclasses (at any level of
inheritance) try to assign different values to a field of the same name. We
do this by assigning to a :class:`~src.util.basic.ConsistentDict`.
"""
new_kw = filter_dataclass(kwargs, self, init='all')
new_kw = basic.ConsistentDict.from_struct(new_kw)
for cls_ in self.__class__.__bases__:
if not is_regex_dataclass(cls_):
continue
for f in dataclasses.fields(self):
if not f.type == cls_:
continue
if f.name in kwargs:
val = kwargs[f.name]
elif not isinstance(f.default, dataclasses._MISSING_TYPE):
val = f.default
elif not isinstance(f.default_factory, dataclasses._MISSING_TYPE):
val = f.default_factory()
else:
raise exceptions.DataclassParseError(f"Can't set value for {f.name}.")
new_d = dataclasses.asdict(f.type.from_string(val))
new_d = filter_dataclass(new_d, self, init='all')
try:
new_kw.update(new_d)
except exceptions.WormKeyError as exc:
raise exceptions.DataclassParseError((f"{self.__class__.__name__}: "
f"Tried to make inconsistent field assignment when parsing "
f"{f.name} as an instance of {f.type.__name__}.")) from exc
post_init = dict()
for f in dataclasses.fields(self):
if not f.init and f.name in new_kw:
post_init[f.name] = new_kw.pop(f.name)
return new_kw, post_init
def regex_dataclass(pattern, **deco_kwargs):
"""Decorator combining the functionality of :class:`RegexPattern` and
:func:`mdtf_dataclass`: dataclass fields are parsed from a regex and coerced
to appropriate classes.
Specifically, this is done via a ``from_string`` classmethod, added by this
decorator, which creates dataclass instances by parsing an input string with
a :class:`RegexPattern` or :class:`ChainedRegexPattern`. The values of all
fields returned by the :meth:`~RegexPattern.match` method of the pattern are
passed to the ``__init__`` method of the dataclass as kwargs.
Additionally, if the type of one or more fields is set to a class that's
also been decorated with regex_dataclass, the parsing logic for that field's
regex_dataclass will be invoked on that field's value (i.e., a string obtained
by regex matching in *this* regex_dataclass), and the parsed values of those
fields will be supplied to this regex_dataclass constructor. This is our
implementation of composition for regex_dataclasses.
.. note::
Unlike :func:`mdtf_dataclass`, type coercion here is done *after*
``__post_init__`` for these dataclasses. This is necessary due to
composition: if a regex_dataclass is being instantiated as a field of
another regex_dataclass, all values being passed to it will be strings
(the regex fields), and type coercion is the job of ``__post_init__``.
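Example (a sketch with a hypothetical pattern and class)::

    date_pat = RegexPattern(r"(?P<year>\d{4})-(?P<month>\d{2})")

    @regex_dataclass(date_pat)
    class DateField:
        year: int = MANDATORY
        month: int = MANDATORY

    d = DateField.from_string("2000-01")
    assert d.year == 2000   # '2000' coerced to int via the annotation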
"""
dc_kwargs = DEFAULT_MDTF_DATACLASS_KWARGS.copy()
dc_kwargs.update(deco_kwargs)
def _dataclass_decorator(cls):
if '__post_init__' not in cls.__dict__:
# Prevent class from inheriting __post_init__ from parents if it
# doesn't overload it (which is why we use __dict__ and not
# hasattr().) __post_init__ of all parents will have been called when
# the parent classes are instantiated by _regex_dataclass_preprocess_kwargs.
def _dummy_post_init(self, *args, **kwargs): pass
type.__setattr__(cls, '__post_init__', _dummy_post_init)
# apply dataclasses' decorator
cls = dataclasses.dataclass(cls, **dc_kwargs)
# check that all DCs specified as fields are also in class hierarchy
# so that we inherit their fields; probably no way this could happen though
for f in dataclasses.fields(cls):
if is_regex_dataclass(f.type) and f.type not in cls.__mro__:
raise TypeError((f"{cls.__name__}: Field {f.name} specified as "
f"{f.type.__name__}, but we don't inherit from it."))
_old_init = cls.__init__
@functools.wraps(_old_init)
def _new_init(self, first_arg=None, *args, **kwargs):
if isinstance(first_arg, str) and not args and not kwargs:
# instantiate from running regex on string, if a string is the
# only argument to the constructor
self._pattern.match(first_arg)
first_arg = None
kwargs = self._pattern.data
new_kw, other_kw = _regex_dataclass_preprocess_kwargs(self, kwargs)
for k, v in other_kw.items():
# set field values that aren't arguments to _old_init
object.__setattr__(self, k, v)
if first_arg is None:
_old_init(self, *args, **new_kw)
else:
_old_init(self, first_arg, *args, **new_kw)
_mdtf_dataclass_type_coercion(self, _log)
_mdtf_dataclass_type_check(self, _log)
type.__setattr__(cls, '__init__', _new_init)
def _from_string(cls_, str_, *args):
"""Create an object instance from a string representation *str_*.
Used by :func:`regex_dataclass` for parsing field values and automatic
type coercion.
"""
cls_._pattern.match(str_, *args)
return cls_(**cls_._pattern.data)
type.__setattr__(cls, 'from_string', classmethod(_from_string))
type.__setattr__(cls, '_is_regex_dataclass', True)
type.__setattr__(cls, '_pattern', pattern)
return cls
return _dataclass_decorator
def filter_dataclass(d, dc, init=False):
"""Return a dict of the subset of fields or entries in *d* that correspond to
the fields in dataclass *dc*.
Args:
d (dict, dataclass or dataclass instance): Object to take field values from.
dc (dataclass or dataclass instance): Dataclass defining the set of fields
that are returned. Values of fields in *d* that are not fields of *dc*
are discarded.
init (bool or 'all'): Optional, default False. Controls whether `init-only fields
<https://docs.python.org/3/library/dataclasses.html#init-only-variables>`__
are included:
- If False: Include only the fields of *dc* as returned by
:py:func:`dataclasses.fields`.
- If True: Include only the arguments to *dc*'s constructor (i.e.,
include any init-only fields and exclude any of *dc*'s fields
with *init*=False.)
- If 'all': Include the union of the above two options.
Returns:
dict: The subset of key:value pairs from *d* such that the keys are
included in the set of *dc*'s fields specified by the value of *init*.
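Example (a sketch with a hypothetical dataclass)::

    @dataclasses.dataclass
    class Point:
        x: int = 0
        y: int = 0

    filter_dataclass({'x': 1, 'z': 9}, Point)   # returns {'x': 1}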
"""
assert dataclasses.is_dataclass(dc)
if dataclasses.is_dataclass(d):
if isinstance(d, type):
d = d() # d is a class; instantiate with default field values
d = dataclasses.asdict(d)
if not init or (init == 'all'):
ans = {f.name: d[f.name] for f in dataclasses.fields(dc) if f.name in d}
else:
ans = {f.name: d[f.name] for f in dataclasses.fields(dc)
if (f.name in d and f.init)}
if init or (init == 'all'):
init_fields = filter(
(lambda f: f.type == dataclasses.InitVar),
dc.__dataclass_fields__.values()
)
ans.update({f.name: d[f.name] for f in init_fields if f.name in d})
return ans
def coerce_to_dataclass(d, dc, **kwargs):
"""Given a dataclass *dc* (may be the class or an instance of it), and a dict,
dataclass or dataclass instance *d*, return an instance of *dc*'s class with
field values initialized from those in *d*, along with any extra values
passed in *kwargs*.
Because this constructs a new dataclass instance, it copies field values
according to the *init*=True logic in :func:`filter_dataclass`.
Args:
d (dict, dataclass or dataclass instance): Object to take field values from.
dc (dataclass or dataclass instance): Class to instantiate.
kwargs: Optional. If provided, override field values provided in *d*.
Returns:
Instance of dataclass *dc* with field values populated from *kwargs* and *d*.
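Example (a sketch continuing the hypothetical Point dataclass from the
:func:`filter_dataclass` example)::

    p = coerce_to_dataclass({'x': 1, 'y': 2, 'z': 9}, Point, y=5)
    # returns Point(x=1, y=5): 'z' is discarded, and kwargs override *d*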
"""
new_kwargs = filter_dataclass(d, dc, init=True)
if kwargs:
new_kwargs.update(kwargs)
new_kwargs = filter_dataclass(new_kwargs, dc, init=True)
if not isinstance(dc, type):
dc = dc.__class__
return dc(**new_kwargs)