Source code for autofit.database.aggregator.aggregator

import logging
from abc import ABC, abstractmethod
from typing import Optional, List, Union, cast

from ..sqlalchemy_ import sa

from autofit.database import query as q
from autofit.database.aggregator.info import Info
from .scrape import Scraper
from autofit.database import model as m
from ..query.query import AbstractQuery, Attribute
from ..query.query.attribute import BestFitQuery

logger = logging.getLogger(__name__)


class NullPredicate(AbstractQuery):
    """
    A predicate that applies no filtering.

    Its query selects the id of every row in the fit table, and combining
    it with another predicate using ``&`` yields that other predicate
    unchanged.
    """

    @property
    def fit_query(self) -> str:
        """SQL selecting the id of every fit."""
        select_all_ids = "SELECT id FROM fit"
        return select_all_ids

    def __and__(self, other):
        # Nothing is filtered here, so the conjunction is just the other
        # operand.
        return other


class Query:
    """
    API for creating a query on the best fit instance.

    Any attribute that is not otherwise defined on this object is resolved
    by passing its name to `for_name`.
    """

    @staticmethod
    def for_name(name: str) -> q.Q:
        """
        Create a query for fits based on the name of a top level
        instance attribute.

        Parameters
        ----------
        name
            The name of the attribute. e.g. galaxies

        Returns
        -------
        A query generating object
        """
        named_query = q.Q(name)
        return named_query

    def __getattr__(self, name):
        # Python only calls __getattr__ once normal lookup has failed, so
        # real attributes and methods are unaffected. Dispatching through
        # self lets subclasses substitute their own for_name.
        return self.for_name(name)


class FitQuery(Query):
    """
    API for creating a query on the attributes of a fit,
    such as:
        name
        unique_tag
        path_prefix
        is_complete
        is_grid_search
    """

    @staticmethod
    def for_name(name: str) -> Union[AbstractQuery, Attribute]:
        """
        Create a query based on some attribute of the Fit.

        Parameters
        ----------
        name
            The name of an attribute of the Fit class. It must be a key
            of ``m.fit_attributes``.

        Returns
        -------
        ``q.BA(name)`` when the attribute's column type maps to the Python
        type ``bool``, otherwise ``q.A(name)``.

        Raises
        ------
        AttributeError
            If ``name`` is not a key of ``m.fit_attributes``.

        Examples
        --------
        aggregator.search.name == 'example name'
        """
        if name not in m.fit_attributes:
            raise AttributeError(f"Fit has no attribute {name}")

        column = m.fit_attributes[name]
        # Boolean columns get their own query type; everything else uses
        # the general attribute query.
        query_type = q.BA if column.type.python_type == bool else q.A
        return query_type(name)


class Reverse:
    """
    Marker wrapper around an ordering item.

    The wrapped object is kept on ``item`` and its ``attribute`` is exposed
    unchanged, so code that only needs the attribute name can treat wrapped
    and unwrapped items alike and use ``isinstance`` to detect the wrapper.
    """

    def __init__(self, item):
        """
        Parameters
        ----------
        item
            The object being wrapped; it must expose an ``attribute``.
        """
        self.item = item

    @property
    def attribute(self):
        """The ``attribute`` of the wrapped item, read on each access."""
        wrapped = self.item
        return wrapped.attribute


class AbstractAggregator(ABC):
    """
    Shared behaviour for objects that expose a collection of fits.

    Subclasses supply the `fits` property; iteration, length, equality
    with lists and the value helpers are all defined in terms of it.
    """

    @property
    @abstractmethod
    def fits(self) -> List[m.Fit]:
        """The fits exposed by this aggregator; supplied by subclasses."""

    def values(self, name: str, parser=lambda o: o) -> list:
        """
        Look up ``fit[name]`` on every fit and collect the results,
        skipping fits for which the lookup gives None.

        Parameters
        ----------
        name
            The name of some pickle, such as 'samples'
        parser
            A function applied to each value that is not None. The
            default returns the value unchanged.

        Returns
        -------
        A list of parsed values in fit order. Fits whose value is None
        contribute nothing, so the list can be shorter than the number
        of fits.
        """
        looked_up = (fit[name] for fit in self)
        return [parser(found) for found in looked_up if found is not None]

    def child_values(self, name: str) -> List[list]:
        """
        Call ``child_values(name)`` on every fit and collect the results.

        Parameters
        ----------
        name
            The name of some pickle, such as 'samples'

        Returns
        -------
        A list with one entry per fit, each entry being whatever that
        fit's ``child_values(name)`` returned.
        """
        per_fit = []
        for fit in self:
            per_fit.append(fit.child_values(name))
        return per_fit

    def __iter__(self):
        fits = self.fits
        return iter(fits)

    def __len__(self):
        fits = self.fits
        return len(fits)

    def __eq__(self, other):
        # Only lists are compared by content; anything else gets the
        # default object comparison.
        if not isinstance(other, list):
            return super().__eq__(other)
        return self.fits == other


class Aggregator(AbstractAggregator):
    def __init__(
        self,
        session: sa.orm.Session,
        filename: Optional[str] = None,
        predicate: AbstractQuery = NullPredicate(),
        offset=0,
        limit=None,
        order_bys=None,
        top_level_only=True,
    ):
        """
        Query results from an intermediary SQLite database.

        Results can be scraped from a directory structure and stored in the
        database.

        Parameters
        ----------
        session
            A session for communicating with the database.
        filename
            The path to the database file. If None, the database is in memory.
        predicate
            A predicate to filter the results by.
        offset
            The number of results to skip
        limit
            The maximum number of results to return
        order_bys
            A list of attributes to order the results by
        top_level_only
            If True, only return the top level fits
        """
        self.session = session
        self.filename = filename
        # Cache for the fits property; populated on first access.
        self._fits = None
        self._predicate = predicate
        self._offset = offset
        self._limit = limit
        self._order_bys = order_bys or list()
        self._top_level_only = top_level_only

    def order_by(self, item: Attribute, reverse=False) -> "Aggregator":
        """
        Order the results by a given attribute of the search. Can be applied
        multiple times with the first application taking precedence.

        Parameters
        ----------
        item
            An attribute of the search
        reverse
            If True reverse the results

        Returns
        -------
        An aggregator with ordering applied

        Examples
        --------
        aggregator = aggregator.order_by(
            aggregator.search.unique_tag
        )
        """
        if reverse:
            item = Reverse(item)
        # Build a new list so the ordering of this aggregator is untouched.
        return self._new_with(order_bys=self._order_bys + [item])

    @property
    def search(self) -> FitQuery:
        """
        An object facilitating queries on fit attributes such as:
            name
            unique_tag
            path_prefix
            is_complete
            is_grid_search
        """
        return FitQuery()

    @property
    def info(self):
        """
        Query info associated with the fit in the info dictionary
        """
        return q.AnonymousInfo()

    @property
    def fits(self) -> List[m.Fit]:
        """
        Lazily query the database for a list of Fit objects that
        match the aggregator's predicate.

        The result is cached on this instance after the first access.
        """
        if self._fits is None:
            self._fits = self._fits_for_query(self._predicate.fit_query)
        return self._fits

    def map(self, func):
        """
        Lazily apply a function to each fit.

        Parameters
        ----------
        func
            A callable taking a single fit

        Returns
        -------
        A generator yielding ``func(fit)`` for each fit in order
        """
        for fit in self.fits:
            yield func(fit)

    def __repr__(self):
        # len(self) evaluates the fits property, so this runs the query if
        # the fits have not been loaded yet.
        return f"<{self.__class__.__name__} {self.filename} {len(self)}>"

    @property
    def model(self) -> Query:
        """
        Facilitates construction of queries on the best fit instance.

        Attribute access on the returned object creates a query for the
        attribute of that name. Queries on attributes of the fit itself are
        created through the `search` property instead.

        Returns
        -------
        A query
        """
        return Query()

    def __call__(self, predicate) -> "Aggregator":
        """
        Concise query syntax
        """
        return self.query(predicate)

    def query(self, predicate: AbstractQuery) -> "Aggregator":
        # noinspection PyUnresolvedReferences
        """
        Apply a query on the model.

        Parameters
        ----------
        predicate
            A predicate constructed to express which models should be included.

        Returns
        -------
        A new aggregator whose predicate is the conjunction of this
        aggregator's predicate and the given one

        Examples
        --------
        >>>
        >>> aggregator = Aggregator.from_database(
        >>>     "my_database.sqlite"
        >>> )
        >>>
        >>> lens = aggregator.model.galaxies.lens
        >>>
        >>> aggregator.query((lens.bulge == SersicCore) & (lens.disk == Sersic))
        >>> aggregator.query((lens.bulge == SersicCore) | (lens.disk == Sersic))
        """
        return self._new_with(predicate=self._predicate & predicate)

    def _new_with(
        self,
        type_=None,
        **kwargs,
    ) -> "Aggregator":
        """
        Create a new instance with the same session, filename, predicate,
        ordering and top_level_only flag, except for those overridden by
        kwargs.

        The offset and limit of this instance are not copied; unless they
        are passed in kwargs the new instance uses the constructor defaults.

        Parameters
        ----------
        type_
            The type of the new instance (defaults to the type of this
            instance)
        kwargs
            Names and values of attributes to override

        Returns
        -------
        A new Aggregator with the same attributes except where they have
        been overridden
        """
        kwargs = {
            "session": self.session,
            "filename": self.filename,
            "predicate": self._predicate,
            "order_bys": self._order_bys,
            "top_level_only": self._top_level_only,
            **kwargs,
        }
        type_ = type_ or type(self)
        return type_(**kwargs)

    def __getitem__(self, item):
        """
        Retrieve a single fit by index or a sub-range of fits by slice.

        Parameters
        ----------
        item
            An int, which indexes the list of fits directly, or a slice
            with no step (or a step of 1). Slice bounds follow normal
            sequence rules and are relative to the fits of this
            aggregator, so slices of slices compose.

        Returns
        -------
        A Fit for an int index; a new aggregator with the corresponding
        offset and limit for a slice.

        Raises
        ------
        TypeError
            If item is neither an int nor a slice
        ValueError
            If the slice has a step other than 1
        """
        if isinstance(item, int):
            return self.fits[item]
        if not isinstance(item, slice):
            raise TypeError(
                f"Aggregator indices must be integers or slices, "
                f"not {type(item).__name__}"
            )
        if item.step not in (None, 1):
            raise ValueError("Aggregator slices do not support a step")

        # slice.indices clamps the bounds and resolves negative values
        # against the current number of fits, giving 0 <= start and
        # 0 <= stop <= len(self).
        # NOTE(review): len(self) counts fits after the top-level filter,
        # whereas offset and limit act on the query before that filter, so
        # the two can disagree if non-top-level rows match the predicate.
        start, stop, _ = item.indices(len(self))
        return self._new_with(
            offset=self._offset + start,
            limit=max(stop - start, 0),
        )

    def _fits_for_query(self, query: str) -> List[m.Fit]:
        """
        Execute a raw SQL query and return a Fit object
        for each Fit id returned by the query

        Parameters
        ----------
        query
            A SQL query that selects ids from the fit table

        Returns
        -------
        A list of fit objects, one for each id returned by the query
        that falls within the offset and limit and, when top_level_only
        is set, has no parent
        """
        logger.debug("Executing query: %s", query)
        fit_ids = {row[0] for row in self.session.execute(query)}

        logger.info("%d fit(s) found matching query", len(fit_ids))
        query = self.session.query(m.Fit).filter(m.Fit.id.in_(fit_ids))

        # Apply orderings in the order they were requested, so the first
        # is the primary sort key.
        for order_by in self._order_bys:
            attribute = getattr(m.Fit, order_by.attribute)
            if isinstance(order_by, Reverse):
                attribute = sa.desc(attribute)
            query = query.order_by(attribute)

        fits = query.offset(self._offset).limit(self._limit).all()

        # The top-level filter runs after offset and limit have been
        # applied, so fewer than `limit` fits may be returned.
        if self._top_level_only:
            return [fit for fit in fits if fit.parent is None]
        return fits

    def add_directory(
        self,
        directory: str,
        auto_commit=True,
        reference: Optional[dict] = None,
        completed_only: bool = False,
    ):
        """
        Recursively search a directory for autofit results
        and add them to this database.

        Any pickles found in the pickles file are implicitly added
        to the fit object.

        Warnings
        --------
        If a directory is added twice then that will result in
        duplicate entries in the database.

        Parameters
        ----------
        auto_commit
            If True the session is committed writing the new objects
            to the database
        directory
            A directory containing autofit results embedded in a
            file structure
        reference
            A dictionary mapping the names of objects in the model to
            their class path.
        completed_only
            If true only searches that have completed are added
        """
        scraper = Scraper(
            directory,
            self.session,
            reference=reference,
            completed_only=completed_only,
        )
        scraper.scrape()

        if auto_commit:
            self.session.commit()

        # Written regardless of auto_commit.
        Info(self.session).write()

    @classmethod
    def from_database(
        cls,
        filename: str,
        completed_only: bool = False,
        top_level_only: bool = True,
    ) -> "Aggregator":
        """
        Create an instance from a sqlite database file.

        If no file exists then one is created with the schema of the database.

        Parameters
        ----------
        completed_only
            If True only completed fits are returned
        filename
            The name of the database file.
        top_level_only
            If True only top level fits are returned

        Returns
        -------
        An aggregator connected to the database specified by the file.
        """
        # Imported here rather than at module level.
        from autofit.database import open_database

        session = open_database(str(filename))
        aggregator = Aggregator(session, filename, top_level_only=top_level_only)
        if completed_only:
            return aggregator(aggregator.search.is_complete)
        return aggregator

    def grid_searches(self) -> "GridSearchAggregator":
        """
        Filter to only grid searches and return an aggregator
        with grid search specific functionality.

        Grid searches are initially implicitly ordered by their id
        """
        return cast(
            GridSearchAggregator,
            self._new_with(
                type_=GridSearchAggregator,
                predicate=self._predicate & self.search.is_grid_search,
                order_bys=[Attribute("id")],
                top_level_only=False,
            ),
        )
class GridSearchAggregator(Aggregator):
    def __init__(
        self,
        session: sa.orm.Session,
        filename: Optional[str] = None,
        predicate: AbstractQuery = NullPredicate(),
        offset=0,
        limit=None,
        order_bys=None,
        top_level_only=False,
    ):
        """
        An aggregator with additional functionality for grid searches.

        Takes the same arguments as Aggregator, but top_level_only
        defaults to False.

        Parameters
        ----------
        session
            A session for communicating with the database.
        filename
            The path to the database file. If None, the database is in memory.
        predicate
            A predicate to filter the results by.
        offset
            The number of results to skip
        limit
            The maximum number of results to return
        order_bys
            A list of attributes to order the results by
        top_level_only
            If True, only return the top level fits
        """
        super().__init__(
            session=session,
            filename=filename,
            predicate=predicate,
            offset=offset,
            limit=limit,
            order_bys=order_bys,
            top_level_only=top_level_only,
        )

    def best_fits(self) -> "GridSearchAggregator":
        """
        The best fit from each of the grid searches

        Best fits are initially implicitly ordered by their parent id
        """
        return self._new_with(
            predicate=BestFitQuery(self._predicate),
            order_bys=[Attribute("parent_id")],
        )

    def children(self) -> "GridSearchAggregator":
        """
        An aggregator comprising the children of the fits encapsulated
        by this aggregator. This is used to query children in a grid search.

        Children are initially implicitly ordered by their parent id
        """
        return self._new_with(
            predicate=q.ChildQuery(self._predicate),
            order_bys=[Attribute("parent_id")],
        )

    def cell_number(self, number: int) -> "CellAggregator":
        """
        Create an aggregator for accessing all values for child fits
        with a given index, ordered by parameter values.

        Parameters
        ----------
        number
            The number of the fit in the grid search

        Returns
        -------
        An aggregator comprising fits for a given cell for each grid search
        """
        return CellAggregator(number, self)


class CellAggregator(AbstractAggregator):
    def __init__(self, number: int, aggregator: GridSearchAggregator):
        """
        Aggregator for accessing data for a specific fit number in
        each grid search.

        Parameters
        ----------
        number
            The number of the fit
        aggregator
            An aggregator comprising 0 or more grid searches
        """
        self.number = number
        self.aggregator = aggregator
        # Cache for the fits property; populated on first successful access.
        self._fits = None

    @property
    def fits(self) -> List[m.Fit]:
        """
        Retrieve one fit for each grid search matching the number of the cell.

        The children of each grid search are sorted by the order_no of
        their model (0 when a child has no model) and the child at
        position `number` is taken.

        Raises
        ------
        IndexError
            If a grid search has no child at position `number`. Nothing is
            cached in that case, so a later access evaluates again rather
            than returning a partial list.
        """
        if self._fits is None:
            # Build the whole list before caching it so that a failure
            # part-way through cannot leave a truncated list behind.
            self._fits = [
                sorted(
                    fit.children,
                    key=lambda f: f.model.order_no if f.model is not None else 0,
                )[self.number]
                for fit in self.aggregator
            ]
        return self._fits