Source code for autofit.mapper.model

import copy
import logging
from functools import wraps
from typing import Optional, Union, Tuple, List, Iterable, Type, Dict

from autofit.mapper.model_object import ModelObject
from autofit.mapper.prior_model.recursion import DynamicRecursionCache

logger = logging.getLogger(__name__)


def frozen_cache(func):
    """
    Decorator that caches results from function calls when
    a model is frozen.

    Value is cached by function name, instance and arguments.

    Parameters
    ----------
    func
        Some function attached to a freezable, hashable object
        that takes hashable arguments

    Returns
    -------
    Function with cache
    """

    @wraps(func)
    def cache(self, *args, **kwargs):
        if hasattr(self, "_is_frozen") and self._is_frozen:
            key = (
                func.__name__,
                self,
                *args,
            ) + tuple(kwargs.items())

            if key not in self._frozen_cache:
                self._frozen_cache[key] = func(self, *args, **kwargs)
            return self._frozen_cache[key]
        return func(self, *args, **kwargs)

    return cache


def assert_not_frozen(func):
    """
    Decorator that asserts a function is not called when an object
    is frozen. For example, it should not be possible to set an
    attribute on a frozen model as that might invalidate the results
    in the cache.

    Parameters
    ----------
    func
        Some function

    Raises
    ------
    AssertionError
        If the function is called when the object is frozen
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        string_args = list(filter(lambda arg: isinstance(arg, str), args))
        if (
            "_is_frozen" not in string_args
            and "_frozen_cache" not in string_args
            and hasattr(self, "_is_frozen")
            and self._is_frozen
        ):
            raise AssertionError("Frozen models cannot be modified")
        return func(self, *args, **kwargs)

    return wrapper


class AbstractModel(ModelObject):
    def __init__(self, label=None, id_=None):
        self._is_frozen = False
        self._frozen_cache = dict()
        super().__init__(label=label, id_=id_)

    def __getstate__(self):
        return {
            key: value for key, value in self.__dict__.items() if key != "_frozen_cache"
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._frozen_cache = {}

    def freeze(self):
        """
        Freeze this object.

        A frozen object caches results for some function calls
        and does not allow its state to be modified.
        """
        logger.debug("Freezing model")
        tuples = self.direct_tuples_with_type(AbstractModel)
        for _, model in tuples:
            if model is not self:
                model.freeze()
        self._is_frozen = True

    def unfreeze(self):
        """
        Unfreeze this object. Allows modification and removes
        caches associated with some functions.
        """
        logger.debug("Thawing model")
        self._is_frozen = False
        tuples = self.direct_tuples_with_type(AbstractModel)
        for _, model in tuples:
            if model is not self:
                model.unfreeze()
        self._frozen_cache = dict()

    def __add__(self, other):
        instance = self.__class__()

        def add_items(item_dict):
            for key, value in item_dict.items():
                if isinstance(value, list) and hasattr(instance, key):
                    setattr(instance, key, getattr(instance, key) + value)
                else:
                    setattr(instance, key, value)

        add_items(self.__dict__)
        add_items(other.__dict__)
        return instance

    def copy(self):
        """
        Create a copy of the model. All priors remain equivalent - i.e. two
        copies of a model in a collection has the same prior count as a single
        model.
        """
        return copy.deepcopy(self)

    def object_for_path(
        self, path: Iterable[Union[str, int, type]]
    ) -> Union[object, List]:
        """
        Get the object at a given path.

        The path describes the location of some object in the model.

        String entries get an attribute.
        Int entries index an attribute.
        Type entries product a new ModelInstance which collates all of the instances
        of a given type in the path.

        Parameters
        ----------
        path
            A tuple describing the path to an object in the model tree

        Returns
        -------
        An object or Instance collating a collection of objects with a given type.
        """
        instance = self
        for name in path:
            if isinstance(name, int):
                instance = instance[name]
            elif isinstance(name, type):
                from autofit.mapper.prior_model.prior_model import Model

                instances = [
                    instance
                    for _, instance in self.path_instance_tuples_for_class(name)
                ]
                instances += [
                    instance
                    for _, instance in self.path_instance_tuples_for_class(Model)
                    if issubclass(instance.cls, name)
                ]
                instance = ModelInstance(instances)
            else:
                instance = getattr(instance, name)
        return instance

    @frozen_cache
    def path_instance_tuples_for_class(
        self,
        cls: Union[Tuple, Type],
        ignore_class: Optional[type] = None,
        ignore_children: bool = True,
    ):
        """
        Tuples containing the path tuple and instance for every instance of the class
        in the model tree.

        Parameters
        ----------
        ignore_class
            Children of instances of this class are ignored
        ignore_children
            If true do not continue to recurse the children of an object once found
        cls
            The type to find instances of

        Returns
        -------
        path_instance_tuples: [((str,), object)]
            Tuples containing the path to and instance of objects of the given type.
        """
        return path_instances_of_class(
            self, cls, ignore_class=ignore_class, ignore_children=ignore_children
        )

    @frozen_cache
    def direct_tuples_with_type(self, class_type):
        return list(
            filter(
                lambda t: t[0] != "id"
                and not t[0].startswith("_")
                and isinstance(t[1], class_type),
                self.__dict__.items(),
            )
        )

    @frozen_cache
    def models_with_type(
        self,
        cls: Union[Type, Tuple[Type, ...]],
        include_zero_dimension=False,
    ) -> List["AbstractModel"]:
        """
        Return all models of a given type in the model tree.

        Parameters
        ----------
        cls
            The type to find instances of
        include_zero_dimension
            If true, include models with zero dimensions

        Returns
        -------
        A list of models of the given type
        """
        # noinspection PyTypeChecker
        return [
            t[1]
            for t in self.model_tuples_with_type(
                cls, include_zero_dimension=include_zero_dimension
            )
        ]

    @frozen_cache
    def model_tuples_with_type(
        self, cls: Union[Type, Tuple[Type, ...]], include_zero_dimension=False
    ):
        """
        All models of the class in this model which have at least
        one free parameter, recursively.

        Parameters
        ----------
        cls
            The type of the model
        include_zero_dimension
            If true, include models with 0 free parameters

        Returns
        -------
        Models with free parameters
        """
        from .prior_model.prior_model import Model

        return [
            (path, model)
            for path, model in self.attribute_tuples_with_type(
                Model, ignore_children=False
            )
            if issubclass(model.cls, cls)
            and (include_zero_dimension or model.prior_count > 0)
        ]

    @frozen_cache
    def attribute_tuples_with_type(
        self,
        class_type,
        ignore_class=None,
        ignore_children=True,
    ) -> List[tuple]:
        """
        Tuples describing the name and instance for attributes in the model
        with a given type, recursively.

        Parameters
        ----------
        ignore_children
            If True then recursion stops at instances with the type
        class_type
            The type of the objects to find
        ignore_class
            Any classes which should not be recursively searched

        Returns
        -------
        Tuples containing the name and instance of each attribute with the type
        """
        return [
            (path[-1] if len(path) > 0 else "", value)
            for path, value in self.path_instance_tuples_for_class(
                class_type,
                ignore_class=ignore_class,
                ignore_children=ignore_children,
            )
        ]


@DynamicRecursionCache()
def path_instances_of_class(
    obj,
    cls: type,
    ignore_class: Optional[Union[type, Tuple[type]]] = None,
    ignore_children: bool = False,
):
    """
    Recursively search the object for instances of a given class

    Parameters
    ----------
    obj
        The object to recursively search
    cls
        The type to search for
    ignore_class
        A type or tuple of classes to skip
    ignore_children
        If true stop recursion at found objects

    Returns
    -------
    instance of type
    """
    if ignore_class is not None and isinstance(obj, ignore_class):
        return []

    results = []
    if isinstance(obj, cls):
        results.append((tuple(), obj))
        if ignore_children:
            return results

    if isinstance(obj, list):
        for i, item in enumerate(obj):
            for path, instance in path_instances_of_class(
                item, cls, ignore_class=ignore_class, ignore_children=ignore_children
            ):
                results.append(((i,) + path, instance))
        return results

    try:
        from autofit.mapper.prior_model.annotation import AnnotationPriorModel

        if isinstance(obj, dict):
            d = obj
        else:
            d = obj.__dict__

        for key, value in d.items():
            if key.startswith("_"):
                continue
            for item in path_instances_of_class(
                value, cls, ignore_class=ignore_class, ignore_children=ignore_children
            ):
                if isinstance(value, AnnotationPriorModel):
                    path = (key,)
                else:
                    path = (key, *item[0])
                results.append((path, item[1]))
        return results
    except (AttributeError, TypeError):
        return results


[docs] class ModelInstance(AbstractModel): """ An instance of a Collection or Model. This is created by optimisers and correspond to a point in the parameter space. @DynamicAttrs """ __dictable_type__ = "instance" def __init__(self, child_items: Optional[Union[List, Dict]] = None, id_=None): """ An instance of a Collection or Model. This is created by optimisers and correspond to a point in the parameter space. Parameters ---------- child_items The child items of the instance. This can be a list or dict. If a list, the items are assigned to the instance in order. If a dict, the items are assigned to the instance by key and accessed by attribute. """ super().__init__() self.child_items = child_items self.id = id_ def __eq__(self, other): try: return self.__dict__ == other.__dict__ except AttributeError: return False def __getitem__(self, item): if isinstance(item, int): return list(self.values())[item] if isinstance(item, slice): return ModelInstance(list(self.values())[item]) return self.__dict__[item] def __setitem__(self, key, value): self.__dict__[key] = value @property def child_items(self): return self.dict @child_items.setter def child_items(self, child_items): if isinstance(child_items, list): for i, item in enumerate(child_items): self[i] = item if isinstance(child_items, dict): for key, value in child_items.items(): self[key] = value def items(self): return self.dict.items() def __hash__(self): return self.id @property def dict(self): return { key: value for key, value in self.__dict__.items() if key not in ("id", "component_number", "item_number") and not (isinstance(key, str) and key.startswith("_")) }
[docs] def tree_flatten(self) -> Tuple[List, Tuple]: """ Flatten the instance into a PyTree """ keys, values = zip(*self.dict.items()) return values, ( *keys, self.id, )
[docs] @classmethod def tree_unflatten( cls, aux_data: Tuple, children: List, ): """ Create an instance from a flattened PyTree Parameters ---------- aux_data Auxiliary information that remains unchanged including the keys of the dict children Child objects subject to change Returns ------- An instance of this class """ *keys, id_ = aux_data instance = cls(id_=id_) for key, value in zip(keys, children): instance[key] = value return instance
def values(self): return self.dict.values() def __len__(self): return len(self.values())
[docs] def as_model( self, model_classes: Union[type, Iterable[type]] = tuple(), excluded_classes: Union[type, Iterable[type]] = tuple(), ): """ Convert this instance to a model Parameters ---------- model_classes The classes to convert to models excluded_classes The classes to exclude from conversion Returns ------- A model """ from autofit.mapper.prior_model.abstract import AbstractPriorModel return AbstractPriorModel.from_instance( self, model_classes, excluded_classes, )