import builtins
import collections.abc
import copy
import inspect
import logging
from typing import List
import typing
from autofit.jax_wrapper import register_pytree_node_class, register_pytree_node
from autoconf.class_path import get_class_path
from autoconf.exc import ConfigException
from autofit.mapper.model import assert_not_frozen
from autofit.mapper.model_object import ModelObject
from autofit.mapper.prior.abstract import Prior
from autofit.mapper.prior.deferred import DeferredInstance
from autofit.mapper.prior.tuple_prior import TuplePrior
from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.tools.namer import namer
# Logger for this module.
logger = logging.getLogger(__name__)

# Lazily populated cache mapping each wrapped class/function to its constructor
# argument names (filled in by ``Model.constructor_argument_names``).
class_args_dict = {}
@register_pytree_node_class
class Model(AbstractPriorModel):
    """
    A model component wrapping a Python class (or function) whose constructor
    arguments become attributes of the model: priors, tuple priors, nested
    models, collections or fixed values.

    @DynamicAttrs
    """

    @property
    def name(self):
        """The ``__name__`` of the wrapped class or function."""
        return self.cls.__name__

    def __str__(self):
        prior_string = ", ".join(map(str, self.prior_tuples))
        return f"{self.name} {prior_string}"

    def __repr__(self):
        return f"<{self.__class__.__name__} {self}>"

    def as_model(self):
        """
        Return a new ``Model`` wrapping the same class, constructed without any
        of the keyword arguments or attribute changes applied to this one.
        """
        return Model(self.cls)

    def __hash__(self):
        # NOTE(review): the hash is the model's ``id`` whereas ``__eq__`` compares
        # the wrapped class and priors, so two models that compare equal can hash
        # differently. Confirm this is intended before relying on value-based
        # set/dict membership.
        return self.id

    def __add__(self, other):
        """
        Add another model of the same class, delegating to
        ``AbstractPriorModel.__add__``.

        Raises
        ------
        TypeError
            If ``other`` wraps a different class.
        """
        if self.cls != other.cls:
            raise TypeError(
                f"Cannot add PriorModels with different classes "
                f"({self.cls.__name__} and {other.cls.__name__})"
            )
        return super().__add__(other)

    def __init__(
        self,
        cls,
        **kwargs,
    ):
        """
        The object a Python class is input into to create a model-component, which has free parameters that are fitted
        by a non-linear search.

        The ``Model`` object is flexible, and can create models from many input Python data structures
        (e.g. a list of classes, dictionary of classes, hierarchy of classes).

        For a complete description of the model composition API, see the **PyAutoFit** model API cookbooks:

        https://pyautofit.readthedocs.io/en/latest/cookbooks/cookbook_1_basics.html

        The Python class input into a ``Model`` to create a model component is written using the following format:

        - The name of the class is the name of the model component (e.g. ``Gaussian``).
        - The input arguments of the constructor are the parameters of the model (e.g. ``centre``, ``normalization``
          and ``sigma``).
        - The default values of the input arguments tell PyAutoFit whether a parameter is a single-valued float or a
          multi-valued tuple.

        Each constructor argument becomes an attribute of the model, checked in this order:

        - Arguments whose default is a string get no attribute here.
        - Arguments given in ``kwargs`` use that value (lists and dicts are wrapped in a ``Collection``, classes
          in a ``Model``, ints are converted to floats).
        - Arguments whose default is a tuple become a ``TuplePrior`` with one prior per element.
        - Arguments annotated with something other than ``float`` are handled by annotation: ``Tuple[...]`` gives a
          ``TuplePrior``, a ``float`` subclass gives an ``AnnotationPriorModel``, an annotation admitting ``None``
          (e.g. ``Optional[...]``) gives ``None``, a collection type gives an empty ``Collection``, and anything else
          gives a nested ``Model``.
        - Everything else gets a prior from the prior configuration via ``make_prior``. If no configuration exists
          and the class lists the argument in ``__default_fields__``, the argument's default value is used.

        An argument named ``settings`` never gets a prior. Any ``kwargs`` that have not become attributes by the end
        are also set as attributes.

        Parameters
        ----------
        cls
            The class (or function) associated with this model.
        kwargs
            Values for constructor arguments, or extra attributes, overriding the default behaviour above.

        Raises
        ------
        AssertionError
            If ``cls`` is neither a class nor a function.

        Examples
        --------
        class Gaussian:
            def __init__(
                self,
                centre=0.0,        # <- PyAutoFit recognises these
                normalization=0.1, # <- constructor arguments are
                sigma=0.01,        # <- the Gaussian's parameters.
            ):
                self.centre = centre
                self.normalization = normalization
                self.sigma = sigma

        model = af.Model(Gaussian)
        """
        super().__init__(
            label=namer(cls.__name__) if inspect.isclass(cls) else None,
        )
        if cls is self:
            return

        if not (inspect.isclass(cls) or inspect.isfunction(cls)):
            raise AssertionError(f"{cls} is not a class or function")

        self.cls = cls

        try:
            arg_spec = inspect.getfullargspec(cls)
        except TypeError:
            # getfullargspec raises TypeError for callables it cannot inspect.
            arg_spec = None

        annotations = {}
        defaults = {}
        if arg_spec is not None:
            for key, value in arg_spec.annotations.items():
                if isinstance(value, str):
                    # String annotations (e.g. from ``from __future__ import
                    # annotations``) are resolved against builtins only. Ones that
                    # do not name a builtin are dropped, so the argument is handled
                    # as if unannotated instead of raising AttributeError.
                    if not hasattr(builtins, value):
                        continue
                    value = getattr(builtins, value)
                annotations[key] = value
            # ``defaults`` is None when no positional argument has a default; the
            # defaults belong to the last ``len(defaults)`` positional arguments.
            if arg_spec.defaults:
                defaults = dict(
                    zip(arg_spec.args[-len(arg_spec.defaults) :], arg_spec.defaults)
                )

        args = self.constructor_argument_names
        defaults.pop("settings", None)
        if "settings" in args:
            # NOTE(review): ``args`` is the list cached in ``class_args_dict``, so
            # this also removes "settings" from the cache for later lookups of this
            # class. Kept as-is because other code may depend on it.
            args.remove("settings")

        for arg in args:
            if isinstance(defaults.get(arg), str):
                # No attribute is created for arguments with a string default; a
                # matching keyword argument is still applied by the final loop.
                continue
            if arg in kwargs:
                keyword_arg = kwargs[arg]
                if isinstance(keyword_arg, (list, dict)):
                    # Local import, presumably to avoid a circular import.
                    from autofit.mapper.prior_model.collection import Collection

                    setattr(self, arg, Collection(keyword_arg))
                else:
                    setattr(self, arg, self._convert_value(keyword_arg))
            elif arg in defaults and isinstance(defaults[arg], tuple):
                setattr(self, arg, self.make_tuple_prior(arg, len(defaults[arg])))
            elif arg in annotations and annotations[arg] is not float:
                spec = annotations[arg]
                if isinstance(spec, typing._GenericAlias) and spec.__origin__ is tuple:
                    # e.g. ``Tuple[float, float]``: one prior per element.
                    setattr(self, arg, self.make_tuple_prior(arg, len(spec.__args__)))
                # noinspection PyUnresolvedReferences
                elif inspect.isclass(spec) and issubclass(spec, float):
                    from autofit.mapper.prior_model.annotation import (
                        AnnotationPriorModel,
                    )

                    setattr(self, arg, AnnotationPriorModel(spec, cls, arg))
                elif hasattr(spec, "__args__") and type(None) in spec.__args__:
                    # Annotations that admit ``None`` (e.g. ``Optional[...]``).
                    setattr(self, arg, None)
                else:
                    if (
                        hasattr(spec, "__origin__")
                        and issubclass(spec.__origin__, collections.abc.Collection)
                    ) or isinstance(spec, collections.abc.Collection):
                        from autofit.mapper.prior_model.collection import Collection

                        value = Collection()
                    else:
                        value = Model(spec)
                    setattr(self, arg, value)
            else:
                prior = self.make_prior(arg)
                if (
                    isinstance(prior, ConfigException)
                    and hasattr(cls, "__default_fields__")
                    and arg in cls.__default_fields__
                ):
                    # No prior configuration, but the class declares a default field.
                    prior = defaults[arg]
                setattr(self, arg, prior)

        # Apply any keyword arguments that did not become attributes above.
        for key, value in kwargs.items():
            if not hasattr(self, key):
                setattr(self, key, self._convert_value(value))

        try:
            # noinspection PyTypeChecker
            register_pytree_node(
                self.cls,
                self.instance_flatten,
                self.instance_unflatten,
            )
        except ValueError:
            # NOTE(review): presumably raised when ``self.cls`` is already
            # registered (e.g. a second Model of the same class) — confirm.
            pass

    @staticmethod
    def _convert_value(value):
        """
        Normalise a user-supplied attribute value: classes are wrapped in a
        ``Model`` and ints are converted to floats; anything else is returned
        unchanged.
        """
        # NOTE(review): ``bool`` is a subclass of ``int``, so ``True``/``False``
        # also become ``1.0``/``0.0`` here.
        if inspect.isclass(value):
            value = Model(value)
        if isinstance(value, int):
            value = float(value)
        return value

    @property
    def direct_argument_names(self) -> List[str]:
        """
        The names of priors, prior models, instances and deferred values that are
        direct attributes of this model, each listed once.

        Used to flatten and unflatten instances of the wrapped class as PyTrees.
        """
        # NOTE(review): tuple priors (``tuple_prior_tuples``) are not included —
        # confirm whether they should be.
        return [
            t.name
            for t in self.direct_prior_tuples
            + self.direct_prior_model_tuples
            + self.direct_instance_tuples
            + self.direct_deferred_tuples
        ]

    def instance_flatten(self, instance):
        """
        Flatten an instance of this model's class as a PyTree.

        Returns
        -------
        The attribute values named by ``direct_argument_names``, read from
        ``instance``, as the children, and ``None`` as the auxiliary data.
        """
        return (
            [getattr(instance, name) for name in self.direct_argument_names],
            None,
        )

    def instance_unflatten(self, aux_data, children):
        """
        Unflatten a PyTree into an instance of this model's class.

        Parameters
        ----------
        aux_data
            Unused; ``instance_flatten`` always produces ``None``.
        children
            Values in the same order as ``direct_argument_names``.

        Returns
        -------
        An instance of the wrapped class, constructed with ``children`` passed as
        keyword arguments named by ``direct_argument_names``.
        """
        return self.cls(**dict(zip(self.direct_argument_names, children)))

    def tree_flatten(self):
        """
        Flatten this model as a PyTree.

        Returns
        -------
        The direct priors as children, and ``(names, cls)`` as auxiliary data.
        A model with no direct priors flattens to empty children.
        """
        prior_tuples = list(self.direct_prior_tuples)
        names = tuple(name for name, _ in prior_tuples)
        priors = tuple(prior for _, prior in prior_tuples)
        return priors, (names, self.cls)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Unflatten a PyTree produced by ``tree_flatten`` back into a model, passing
        each child as the keyword argument of the same name.
        """
        names, cls_ = aux_data
        arguments = {name: child for name, child in zip(names, children)}
        return cls(cls_, **arguments)

    def dict(self):
        """
        Return ``AbstractPriorModel.dict()`` with a ``class_path`` entry for the
        wrapped class.
        """
        return {"class_path": get_class_path(self.cls), **super().dict()}

    @property
    def constructor_argument_names(self) -> List[str]:
        """
        The argument names of the constructor of the class of this model,
        excluding the first one (``self``).

        Results are cached per class in the module-level ``class_args_dict``. The
        cached list itself is returned, so mutating it changes the cache.
        Callables whose signature cannot be inspected have no argument names.
        """
        if self.cls not in class_args_dict:
            try:
                class_args_dict[self.cls] = inspect.getfullargspec(self.cls).args[1:]
            except TypeError:
                class_args_dict[self.cls] = []
        return class_args_dict[self.cls]

    def __eq__(self, other):
        """Models are equal when they wrap the same class and have equal prior tuples."""
        return (
            isinstance(other, Model)
            and self.cls == other.cls
            and self.prior_tuples == other.prior_tuples
        )

    def make_prior(self, attribute_name):
        """
        Returns a prior for an attribute of a class with a given name. The prior is
        created by searching the default prior config for the attribute.

        Entries in configuration with a u become uniform priors; with a g become
        gaussian priors; with a c become instances.

        If prior configuration for a given attribute is not specified in the
        configuration for a class then the configuration corresponding to the parents
        of that class is searched.

        Parameters
        ----------
        attribute_name: str
            The name of the attribute for which a prior is created

        Returns
        -------
        prior: p.Prior or ConfigException
            The prior or, if no configuration can be found, the ``ConfigException``
            raised by the lookup. The exception is returned rather than raised so
            that callers can fall back to a default value.
        """
        cls = self.cls
        if not inspect.isclass(cls):
            # For a function, use the class it belongs to (as found by inspect's
            # private ``_findclass``) for the configuration lookup.
            # noinspection PyProtectedMember
            cls = inspect._findclass(cls)
        try:
            return Prior.for_class_and_attribute_name(cls, attribute_name)
        except ConfigException as e:
            return e

    def make_tuple_prior(self, name, length):
        """
        Create a ``TuplePrior`` for attribute ``name`` with ``length`` elements,
        named ``{name}_0``, ``{name}_1``, ..., each given a prior by ``make_prior``.
        """
        tuple_prior = TuplePrior()
        for i in range(length):
            attribute_name = f"{name}_{i}"
            setattr(tuple_prior, attribute_name, self.make_prior(attribute_name))
        return tuple_prior

    @assert_not_frozen
    def __setattr__(self, key, value):
        """
        Set an attribute on the model.

        The value's ``label`` is set from ``key`` where the value allows it. Keys
        of the form ``<tuple name>_<suffix>`` (e.g. ``centre_0``) are stored on the
        matching tuple prior. An ``AttributeError`` from the final ``setattr`` is
        logged rather than raised.
        """
        try:
            value.label = namer(key)
        except (AttributeError, TypeError):
            pass

        if key not in (
            "component_number",
            "phase_property_position",
            "mapping_name",
            "id",
            "_is_frozen",
            "_frozen_cache",
        ):
            try:
                if "_" in key:
                    # NOTE(review): only the text before the first underscore is
                    # matched against tuple prior names.
                    name = key.split("_")[0]
                    tuple_prior = [v for k, v in self.tuple_prior_tuples if name == k][0]
                    setattr(tuple_prior, key, value)
                    return
            except IndexError:
                # No tuple prior with that name; fall through to a normal set.
                pass
        try:
            super().__setattr__(key, value)
        except AttributeError as e:
            logger.exception(e)
            logger.exception(key)

    def __getattr__(self, item):
        """
        Called only when normal attribute lookup fails. Resolves names of the form
        ``<tuple name>_<suffix>`` (e.g. ``centre_0``) against this model's tuple
        priors, and otherwise defers to ``__getattribute__``.
        """
        try:
            if (
                "_" in item
                # Excluded to avoid recursing while ``tuple_prior_tuples`` is
                # itself being looked up.
                and item not in ("_is_frozen", "tuple_prior_tuples")
                and not item.startswith("_")
            ):
                return getattr(
                    [v for k, v in self.tuple_prior_tuples if item.split("_")[0] == k][0],
                    item,
                )
        except IndexError:
            pass
        return self.__getattribute__(item)

    @property
    def is_deferred_arguments(self):
        """True if any direct attribute of this model is deferred."""
        return len(self.direct_deferred_tuples) > 0

    # noinspection PyUnresolvedReferences
    def _instance_for_arguments(
        self,
        arguments: {ModelObject: object},
        ignore_assertions=False,
    ):
        """
        Returns an instance of the associated class for a set of arguments.

        Parameters
        ----------
        arguments
            Dictionary mapping each prior (model object) to the value it takes.
        ignore_assertions
            Passed through to ``instance_for_arguments`` of nested models.

        Returns
        -------
        An instance of the class, or a ``DeferredInstance`` if any direct
        attribute of this model is deferred.

        Raises
        ------
        KeyError
            If ``arguments`` has no value for one of this model's direct priors.
        """
        constructor_names = set(self.constructor_argument_names)
        # Everything stored on the model under a constructor-argument name; priors
        # and nested models are overridden below by their resolved values.
        attribute_arguments = {
            key: value
            for key, value in self.__dict__.items()
            if key in constructor_names
        }

        model_arguments = {}
        for tuple_prior in self.tuple_prior_tuples:
            model_arguments[tuple_prior.name] = tuple_prior.prior.value_for_arguments(
                arguments
            )
        for prior_model_tuple in self.direct_prior_model_tuples:
            model_arguments[
                prior_model_tuple.name
            ] = prior_model_tuple.prior_model.instance_for_arguments(
                arguments,
                ignore_assertions=ignore_assertions,
            )

        prior_arguments = {}
        for name, prior in self.direct_prior_tuples:
            try:
                prior_arguments[name] = arguments[prior]
            except KeyError as e:
                raise KeyError(f"No argument given for prior {name}") from e

        # Later entries take precedence over earlier ones.
        constructor_arguments = {
            **attribute_arguments,
            **model_arguments,
            **prior_arguments,
        }

        if self.is_deferred_arguments:
            return DeferredInstance(self.cls, constructor_arguments)

        if not inspect.isclass(self.cls):
            # A function is treated like an unbound ``__init__``: create a bare
            # object of the class it belongs to and call the function on it.
            # noinspection PyProtectedMember
            result = object.__new__(inspect._findclass(self.cls))
            self.cls(result, **constructor_arguments)
        else:
            result = self.cls(**constructor_arguments)

        # Copy across public, non-prior model attributes that the constructor did
        # not set (e.g. extra keyword arguments), instantiating nested models.
        for key, value in self.__dict__.items():
            if (
                not hasattr(result, key)
                and not isinstance(value, Prior)
                and key != "cls"
                and not key.startswith("_")
            ):
                if isinstance(value, Model):
                    value = value.instance_for_arguments(
                        arguments,
                        ignore_assertions=ignore_assertions,
                    )
                try:
                    setattr(result, key, value)
                except AttributeError:
                    pass

        return result

    def gaussian_prior_model_for_arguments(self, arguments):
        """
        Returns a new model with its priors replaced using ``arguments``, which
        per the original documentation hold Gaussian priors based on results
        from a previous non-linear search.

        Note that this unfreezes ``self`` before copying it.

        Parameters
        ----------
        arguments
            Dictionary mapping each prior of this model (including those inside
            tuple priors and nested models) to the value that replaces it.

        Returns
        -------
        new_model: Model
            A deep copy of this model with no assertions and its priors replaced.

        Raises
        ------
        KeyError
            If ``arguments`` has no entry for one of this model's direct priors.
        """
        self.unfreeze()
        new_model = copy.deepcopy(self)
        new_model._assertions = []

        # Looked up before any attribute is changed, so a missing prior fails early.
        model_arguments = {t.name: arguments[t.prior] for t in self.direct_prior_tuples}

        for tuple_prior_tuple in self.tuple_prior_tuples:
            setattr(
                new_model,
                tuple_prior_tuple.name,
                tuple_prior_tuple.prior.gaussian_tuple_prior_for_arguments(arguments),
            )
        for prior_tuple in self.direct_prior_tuples:
            setattr(new_model, prior_tuple.name, model_arguments[prior_tuple.name])
        for instance_tuple in self.direct_instance_tuples:
            setattr(new_model, instance_tuple.name, instance_tuple.instance)
        for name, prior_model in self.direct_prior_model_tuples:
            setattr(
                new_model,
                name,
                prior_model.gaussian_prior_model_for_arguments(arguments),
            )

        return new_model