Source code for autofit.mapper.prior_model.collection

from collections.abc import Iterable

from autofit.jax_wrapper import register_pytree_node_class

from autofit.mapper.model import ModelInstance, assert_not_frozen
from autofit.mapper.prior.abstract import Prior
from autofit.mapper.prior_model.abstract import AbstractPriorModel


[docs]@register_pytree_node_class class Collection(AbstractPriorModel):
[docs] def name_for_prior(self, prior: Prior) -> str: """ Construct a name for the prior. This is the path taken to get to the prior. Parameters ---------- prior Returns ------- A string of object names joined by underscores """ for name, prior_model in self.prior_model_tuples: prior_name = prior_model.name_for_prior(prior) if prior_name is not None: return "{}_{}".format(name, prior_name) for name, direct_prior in self.direct_prior_tuples: if prior == direct_prior: return name
def tree_flatten(self): keys, values = zip(*self.items()) return values, keys @classmethod def tree_unflatten(cls, aux_data, children): instance = cls() for key, value in zip(aux_data, children): setattr(instance, key, value) return instance def __contains__(self, item): return item in self._dict or item in self._dict.values() def __getitem__(self, item): if isinstance(item, str): return self._dict[item] return self.values[item] def __len__(self): return len(self.values) def __str__(self): return "\n".join(f"{key} = {value}" for key, value in self.items()) def __hash__(self): return self.id def __repr__(self): return f"<{self.__class__.__name__} {self}>" @property def values(self): return list(self._dict.values()) def items(self): return self._dict.items()
[docs] def with_prefix(self, prefix: str): """ Filter members of the collection, only returning those that start with a given prefix as a new collection. """ return Collection( {key: value for key, value in self.items() if key.startswith(prefix)} )
def as_model(self): return Collection( { key: value.as_model() if isinstance(value, AbstractPriorModel) else value for key, value in self.dict().items() } ) def __init__( self, *arguments, **kwargs, ): """ The object multiple Python classes are input into to create model-components, which has free parameters that are fitted by a non-linear search. Multiple Python classes can be input into a `Collection` in order to compose high dimensional models made of multiple model-components. The ``Collection`` object is highly flexible, and can create models from many input Python data structures (e.g. a list of classes, dictionary of classes, hierarchy of classes). For a complete description of the model composition API, see the **PyAutoFit** model API cookbooks: https://pyautofit.readthedocs.io/en/latest/cookbooks/cookbook_1_basics.html The Python class input into a ``Model`` to create a model component is written using the following format: - The name of the class is the name of the model component (e.g. ``Gaussian``). - The input arguments of the constructor are the parameters of the mode (e.g. ``centre``, ``normalization`` and ``sigma``). - The default values of the input arguments tell PyAutoFit whether a parameter is a single-valued float or a multi-valued tuple. [Rich document more clearly] A prior model used to represent a list of prior models for convenience. Arguments are flexibly converted into a collection. Parameters ---------- arguments Classes, prior models, instances or priors Examples -------- class Gaussian: def __init__( self, centre=0.0, # <- PyAutoFit recognises these normalization=0.1, # <- constructor arguments are sigma=0.01, # <- the Gaussian's parameters. ): self.centre = centre self.normalization = normalization self.sigma = sigma model = af.Collection(gaussian_0=Gaussian, gaussian_1=Gaussian) """ super().__init__() self.item_number = 0 arguments = list(arguments) if len(arguments) == 0: self.add_dict_items(kwargs) elif len(arguments) == 1: arguments = arguments[0] if isinstance(arguments, dict): self.add_dict_items(arguments) elif isinstance(arguments, Iterable): for argument in arguments: self.append(argument) else: self.append(arguments) else: self.__init__(arguments) @assert_not_frozen def add_dict_items(self, item_dict): for key, value in item_dict.items(): if isinstance(key, tuple): key = ".".join(key) setattr(self, key, AbstractPriorModel.from_object(value)) def __eq__(self, other): if other is None: return False if len(self) != len(other): return False for i, item in enumerate(self): if item != other[i]: return False return True @assert_not_frozen def append(self, item): setattr(self, str(self.item_number), AbstractPriorModel.from_object(item)) self.item_number += 1 @assert_not_frozen def __setitem__(self, key, value): obj = AbstractPriorModel.from_object(value) try: obj.id = getattr(self, str(key)).id except AttributeError: pass setattr(self, str(key), obj) @assert_not_frozen def __setattr__(self, key, value): if key.startswith("_"): super().__setattr__(key, value) else: try: super().__setattr__(key, AbstractPriorModel.from_object(value)) except AttributeError: pass def remove(self, item): for key, value in self.__dict__.copy().items(): if value == item: del self.__dict__[key] def gaussian_prior_model_for_arguments(self, arguments): """ Create a new collection, updating its priors according to the argument dictionary. Parameters ---------- arguments A dictionary of arguments Returns ------- A new collection """ collection = Collection() for key, value in self.items(): if key in ("component_number", "item_number", "id") or key.startswith("_"): continue if isinstance(value, AbstractPriorModel): collection[key] = value.gaussian_prior_model_for_arguments(arguments) if isinstance(value, Prior): collection[key] = arguments[value] return collection def _instance_for_arguments( self, arguments, ignore_assertions=False, ): """ Parameters ---------- arguments: {Prior: float} A dictionary of arguments Returns ------- model_instances: [object] A list of instances constructed from the list of prior models. """ result = ModelInstance() for key, value in self.__dict__.items(): if key.startswith("_"): continue if isinstance(value, AbstractPriorModel): value = value.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, ) elif isinstance(value, Prior): value = arguments[value] setattr(result, key, value) return result
[docs] def gaussian_prior_model_for_arguments(self, arguments): """ Create a new collection, updating its priors according to the argument dictionary. Parameters ---------- arguments A dictionary of arguments Returns ------- A new collection """ collection = Collection() for key, value in self.items(): if key in ("component_number", "item_number", "id") or key.startswith("_"): continue if isinstance(value, AbstractPriorModel): collection[key] = value.gaussian_prior_model_for_arguments(arguments) if isinstance(value, Prior): collection[key] = arguments[value] return collection
@property def prior_class_dict(self): return { **{ prior: cls for prior_model in self.direct_prior_model_tuples for prior, cls in prior_model[1].prior_class_dict.items() }, **{prior: ModelInstance for _, prior in self.direct_prior_tuples}, }