# Source code for autofit.non_linear.paths.database

import shutil
from typing import Optional, Union

from autoconf.output import conditional_output
from autofit.database.sqlalchemy_ import sa
from .abstract import AbstractPaths
import numpy as np

from autofit.database.model import Fit
from autoconf.dictable import to_dict
from autofit.database.aggregator.info import Info


class DatabasePaths(AbstractPaths):
    """
    Paths implementation that stores search output through an SQLAlchemy
    session, using a ``Fit`` row keyed by this object's identifier.

    JSON, arrays, FITS HDUs, arbitrary objects and samples are written to and
    read from that ``Fit`` row rather than being handled as separate files.
    """

    def __init__(
        self,
        session,
        name: Optional[str] = None,
        path_prefix: Optional[str] = None,
        is_identifier_in_paths=True,
        parent=None,
        save_all_samples=False,
        unique_tag: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        session
            SQLAlchemy session used to query, add and commit ``Fit`` rows.
        name
            The name of the search.
        path_prefix
            A prefix for the output path.
        is_identifier_in_paths
            If False then this path's identifier will not be added to its
            output path.
        parent
            The parent paths object. Must be a ``DatabasePaths`` or None.
        save_all_samples
            If False, samples and latent samples are reduced with
            ``minimise()`` before being stored.
        unique_tag
            Tag recorded on the ``Fit`` row when a new row is created.
        """
        super().__init__(
            name=name,
            path_prefix=path_prefix,
            is_identifier_in_paths=is_identifier_in_paths,
            parent=parent,
        )
        self.session = session
        # Lazily loaded or created by the ``fit`` property.
        self._fit = None
        self.save_all_samples = save_all_samples
        self.unique_tag = unique_tag

    # NOTE(review): presumably consumed by autofit's serialisation helpers to
    # drop fields that cannot be serialised (the session) — confirm.
    __nullify_fields__ = ("session",)

    parent: "DatabasePaths"

    @AbstractPaths.parent.setter
    def parent(self, parent: "DatabasePaths"):
        """
        The search performed before this search. For example, a search that
        is then compared to searches during a grid search.

        For database paths the parent must also be database paths.

        Raises
        ------
        TypeError
            If ``parent`` is neither None nor a ``DatabasePaths`` instance.
        """
        if not (parent is None or isinstance(parent, DatabasePaths)):
            raise TypeError(
                "The parent of search that uses the database must also use the database"
            )
        self._parent = parent

    @property
    def is_grid_search(self) -> bool:
        """Whether the underlying ``Fit`` row is flagged as a grid search."""
        return self.fit.is_grid_search

    def create_child(
        self,
        name: Optional[str] = None,
        path_prefix: Optional[str] = None,
        is_identifier_in_paths: Optional[bool] = None,
        identifier: Optional[str] = None,
    ) -> "DatabasePaths":
        """
        Create a paths object which is the child of this paths object.

        This is done during a GridSearch so that results can be stored in the
        correct place. It also allows database fit objects to be related
        correctly.

        This paths object's fit is flagged as a grid search. If its fit has no
        instance yet, the prior median model is stored so that the parent
        object is queryable.

        Parameters
        ----------
        name
            Name of the child; falls back to this object's name if falsy.
        path_prefix
            Path prefix of the child; falls back to this object's prefix if
            falsy.
        is_identifier_in_paths
            If False then the child's identifier will not be added to its
            output path. If None, this object's setting is inherited.
        identifier
            Identifier assigned directly to the child.

        Returns
        -------
        A new paths object whose parent is this object. The child inherits
        the session, model, search and ``save_all_samples`` setting.
        """
        self.fit.is_grid_search = True
        if self.fit.instance is None:
            self.fit.instance = self.model.instance_from_prior_medians(
                ignore_prior_limits=True
            )

        if is_identifier_in_paths is None:
            is_identifier_in_paths = self.is_identifier_in_paths

        child = type(self)(
            session=self.session,
            name=name or self.name,
            path_prefix=path_prefix or self.path_prefix,
            is_identifier_in_paths=is_identifier_in_paths,
            parent=self,
            # Without this, grid-search children would always minimise their
            # samples even when the parent asked for all samples to be kept.
            save_all_samples=self.save_all_samples,
        )
        child.model = self.model
        child.search = self.search
        child._identifier = identifier
        return child

    def zip_remove(self):
        """
        Commit the session, write the database ``Info``, and, if
        ``remove_files`` is set, delete the output directory (errors ignored).
        """
        self.session.commit()
        Info(self.session).write()
        if self.remove_files:
            shutil.rmtree(self.output_path, ignore_errors=True)

    def __getstate__(self):
        """Return the instance state without the (unpicklable) session."""
        d = self.__dict__.copy()
        del d["session"]
        return d

    @conditional_output
    def save_json(self, name, object_dict: Union[dict, list], prefix: str = ""):
        """
        Save a dictionary or list as json in the database.

        Parameters
        ----------
        name
            The name of the json.
        object_dict
            The dictionary or list to save.
        prefix
            Accepted for interface compatibility; unused by the database.
        """
        self.fit.set_json(name, object_dict)

    def load_json(self, name: str, prefix: str = "") -> Union[dict, list]:
        """
        Load json from the database.

        Parameters
        ----------
        name
            The name of the json.
        prefix
            Accepted for interface compatibility; unused by the database.

        Returns
        -------
        The loaded dictionary or list.
        """
        return self.fit.get_json(name)

    @property
    def samples(self):
        """The samples stored on the ``Fit`` row."""
        return self.fit.samples

    @conditional_output
    def save_array(self, name, array: np.ndarray, prefix: str = ""):
        """
        Save an array in the database.

        Parameters
        ----------
        name
            The name of the array.
        array
            The array to save.
        prefix
            Accepted for consistency with the other save methods; unused by
            the database.
        """
        self.fit.set_array(name, array)

    def load_array(self, name: str, prefix: str = "") -> np.ndarray:
        """
        Load an array from the database.

        Parameters
        ----------
        name
            The name of the array.
        prefix
            Accepted for consistency with the other load methods; unused by
            the database.

        Returns
        -------
        The loaded array.
        """
        return self.fit.get_array(name)

    @conditional_output
    def save_fits(self, name: str, hdu, prefix: str = ""):
        """
        Save a fits HDU in the database.

        Parameters
        ----------
        name
            The name of the fits file.
        hdu
            The hdu to save.
        prefix
            Accepted for interface compatibility; unused by the database.
        """
        self.fit.set_hdu(name, hdu)

    def load_fits(self, name: str, prefix: str = ""):
        """
        Load a fits HDU from the database.

        Parameters
        ----------
        name
            The name of the fits file.
        prefix
            Accepted for interface compatibility; unused by the database.

        Returns
        -------
        The loaded hdu.
        """
        return self.fit.get_hdu(name)

    @conditional_output
    def save_object(self, name: str, obj: object, prefix: str = ""):
        """Store ``obj`` on the fit under ``name``."""
        self.fit[name] = obj

    def load_object(self, name: str, prefix: str = ""):
        """Return the object stored on the fit under ``name``."""
        return self.fit[name]

    def remove_object(self, name: str):
        """Delete the object stored on the fit under ``name``."""
        del self.fit[name]

    def is_object(self, name: str) -> bool:
        """Whether an object is stored on the fit under ``name``."""
        return name in self.fit

    def save_search_internal(self, obj):
        """No-op: search internals are not stored in the database."""
        pass

    def load_search_internal(self):
        """No-op: search internals are not stored in the database."""
        pass

    def remove_search_internal(self):
        """No-op: search internals are not stored in the database."""
        pass

    @property
    def fit(self) -> Fit:
        """
        The ``Fit`` row for this identifier.

        On first access the row is queried by id. If none exists, a new,
        incomplete ``Fit`` is created, added to the session, and linked to
        the parent's fit when a parent is set. The result is cached.
        """
        if self._fit is None:
            try:
                self._fit = (
                    self.session.query(Fit).filter(Fit.id == self.identifier).one()
                )
            except sa.orm.exc.NoResultFound:
                self._fit = Fit(
                    id=self.identifier,
                    is_complete=False,
                    unique_tag=self.unique_tag,
                    path_prefix=self.path_prefix,
                    name=self.name,
                )
                self.session.add(self._fit)
                if self.parent is not None:
                    self._fit.parent = self.parent.fit

        return self._fit

    @property
    def is_complete(self) -> bool:
        """Whether the fit has been marked complete."""
        return self.fit.is_complete

    def completed(self):
        """Mark the fit as complete."""
        self.fit.is_complete = True

    def save_summary(
        self,
        samples,
        latent_samples,
        log_likelihood_function_time,
    ):
        """
        Store the maximum log likelihood instance and its log likelihood on
        the fit.

        ``latent_samples`` and ``log_likelihood_function_time`` are accepted
        for interface compatibility but are not used here.
        """
        self.fit.instance = samples.max_log_likelihood()
        self.fit.max_log_likelihood = samples.max_log_likelihood_sample.log_likelihood

    def save_samples(self, samples):
        """
        Store samples (minimised unless ``save_all_samples``) and their
        ``samples_info`` json on the fit.
        """
        if not self.save_all_samples:
            samples = samples.minimise()

        self.fit.samples = samples
        self.fit.set_json("samples_info", samples.samples_info)

    def save_latent_samples(self, latent_samples):
        """Store latent samples (minimised unless ``save_all_samples``)."""
        if not self.save_all_samples:
            latent_samples = latent_samples.minimise()

        self.fit.latent_samples = latent_samples

    def load_latent_samples(self):
        """Return the latent samples stored on the fit."""
        return self.fit.latent_samples

    def _load_samples(self):
        """Return the stored samples with ``model`` set to this object's model."""
        samples = self.fit.samples
        samples.model = self.model
        return samples

    def load_samples(self):
        """Return the ``sample_list`` of the stored samples."""
        return self._load_samples().sample_list

    def load_samples_info(self):
        """Return the ``samples_info`` of the stored samples."""
        return self._load_samples().samples_info

    def save_all(self, info, *_, **kwargs):
        """
        Store the info, model and search on the fit, commit the session, and
        write the database ``Info``.
        """
        self.fit.info = info
        self.fit.model = self.model
        if info:
            self.save_json("info", info)
        self.save_json("search", to_dict(self.search))
        self.save_json("model", to_dict(self.model))
        self.session.commit()
        Info(self.session).write()