Source code for autofit.non_linear.analysis.analysis

import inspect
import logging
from abc import ABC
import functools
import numpy as np
import time
from typing import Optional, Dict

from autofit import exc
from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.non_linear.samples.summary import SamplesSummary
from autofit.non_linear.samples.pdf import SamplesPDF
from autofit.non_linear.result import Result
from autofit.non_linear.samples.samples import Samples
from autofit.non_linear.samples.sample import Sample

from .visualize import Visualizer
from ..samples.util import simple_model_for_kwargs

logger = logging.getLogger(__name__)


[docs] class Analysis(ABC): """ Protocol for an analysis. Defines methods that can or must be implemented to define a class that compute the likelihood that some instance fits some data. """ Result = Result Visualizer = Visualizer LATENT_KEYS = [] def __init__( self, use_jax: bool = False, use_jax_for_visualization: bool = False, **kwargs, ): import os if os.environ.get("PYAUTO_DISABLE_JAX") == "1": use_jax = False use_jax_for_visualization = False # If the user requested JAX but it isn't installed (e.g. Python <3.11 # without the [jax] extra), fall back to numpy with a loud warning # rather than crashing later when the analysis tries to jit-compile. if use_jax: import importlib.util import warnings if importlib.util.find_spec("jax") is None: warnings.warn( "\n" "+----------------------------------------------------------------------+\n" "| use_jax=True was requested but JAX is not installed. |\n" "| |\n" "| Falling back to numpy. The fit will run, but JAX acceleration |\n" "| (typically 10-100x for large lens models) is unavailable. |\n" "| |\n" "| To enable JAX, install on Python 3.11+ via your library's [jax] |\n" "| extra, e.g.: pip install autolens[jax] |\n" "+----------------------------------------------------------------------+", UserWarning, stacklevel=2, ) use_jax = False use_jax_for_visualization = False if use_jax_for_visualization and not use_jax: logger.warning( "use_jax_for_visualization=True requires use_jax=True; " "disabling use_jax_for_visualization." ) use_jax_for_visualization = False self._use_jax = use_jax self._use_jax_for_visualization = use_jax_for_visualization self.kwargs = kwargs
[docs] def fit_for_visualization(self, instance): """ Build the fit used by the visualizer. Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path: * ``use_jax_for_visualization=False`` (default) — plain ``self.fit_from(instance)``. Untouched by JAX. * ``use_jax_for_visualization=True`` — lazily construct ``jax.jit(self.fit_from)`` on the first call and cache it on the instance as ``_jitted_fit_from``, then call that for every subsequent visualization. The first call pays the compile cost; subsequent calls reuse the cached compiled function. Caching is per-``Analysis`` instance so each analysis gets its own compiled function keyed off that instance's closed-over state (``self.dataset``, ``self.settings``, etc. — these ride as pytree aux data via ``register_instance_pytree(FitImaging, no_flatten=...)`` in PyAutoLens). ``fit_from`` is defined by Analysis subclasses (e.g. ``AnalysisImaging``), not the base class — this method is only callable on subclasses that provide it. Downstream visualizers should prefer this over calling ``fit_from`` directly so the JIT seam stays in one place. For the JIT path to succeed, the ``Fit*`` return type (and every nested autoarray / galaxy / lens type it carries) must be pytree- registered. That wiring lives in each analysis subclass (see ``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens). Variants that have not yet been pytree-audited must leave ``use_jax_for_visualization`` at its default of ``False``. """ if not self._use_jax_for_visualization: return self.fit_from(instance=instance) if getattr(self, "_jitted_fit_from", None) is None: import jax self._jitted_fit_from = jax.jit(self.fit_from) return self._jitted_fit_from(instance)
def __getattr__(self, item: str): """ If a method starts with 'visualize_' then we assume it is associated with the Visualizer and forward the call to the visualizer. It may be desirable to remove this behaviour as the visualizer component of the system becomes more sophisticated. """ if item.startswith("visualize") or item.startswith("should_visualize"): _method = getattr(self.Visualizer, item) else: raise AttributeError(f"Analysis has no attribute {item}") def method(*args, **kwargs): parameters = inspect.signature(_method).parameters if "analyses" in parameters: logger.debug(f"Skipping {item} as this is not a combined analysis") return return _method(self, *args, **kwargs) return method @property def _xp(self): if self._use_jax: import jax.numpy as jnp return jnp return np
[docs] def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = None) -> Optional[Samples]: """ Compute latent variables from a model instance. A latent variable is not itself a free parameter of the model but can be derived from it. Latent variables may provide physically meaningful quantities that help interpret a model fit, and their values (with errors) are stored in `latent.csv` in parallel with `samples.csv`. This implementation is designed to be compatible with both NumPy and JAX: - It is written to be side-effect free, so it can be JIT-compiled with `jax.jit`. - It can be vectorized over many parameter sets at once using `jax.vmap`, enabling efficient batched evaluation of latent variables for multiple samples. - Returned values should be simple JAX/NumPy scalars or arrays (no Python objects), so they can be stacked into arrays of shape `(n_samples, n_latents)` for batching. - Any NaNs introduced (e.g. from invalid model states) can be masked or replaced downstream. Parameters ---------- parameters : array-like The parameter vector of the model sample. This will typically come from the non-linear search. Inside this method it is mapped back to a model instance via `model.instance_from_vector`. model : Model The model object defining how the parameter vector is mapped to an instance. Passed explicitly so that this function can be used inside JAX transforms (`vmap`, `jit`) with `functools.partial`. Returns ------- tuple of (float or jax.numpy scalar) A tuple containing the latent variables in a fixed order: `(intensity_total, magnitude, angle)`. Each entry may be NaN if the corresponding component of the model is not present. """ batch_size = batch_size or 10 try: start_latent = time.time() compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model) if self._use_jax: import jax start = time.time() logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.") batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model)) logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.") else: n_latents = len(self.LATENT_KEYS) nan_row = np.full(n_latents, np.nan) def _safe_compute(xx): try: return compute_latent_for_model(xx) except exc.FitException: return nan_row def batched_compute_latent(x): return np.array([_safe_compute(xx) for xx in x]) parameter_array = np.array(samples.parameter_lists) latent_samples = [] # process in batches for i in range(0, len(parameter_array), batch_size): batch = parameter_array[i:i + batch_size] batch_samples = samples.sample_list[i:i + batch_size] # batched JAX call on this chunk latent_values_batch = batched_compute_latent(batch) if self._use_jax: import jax.numpy as jnp latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) latent_values_batch = latent_values_batch[:, mask] else: # Drop samples whose latent computation failed (e.g. FitException from # model assertions surfaced as a NaN row in _safe_compute). This leaves # the per-latent column mask to continue handling degenerate latent # dimensions that produce NaN for all remaining samples. row_mask = np.all(np.isfinite(latent_values_batch), axis=1) latent_values_batch = latent_values_batch[row_mask] batch_samples = [s for s, keep in zip(batch_samples, row_mask) if keep] if len(latent_values_batch): col_mask = np.all(np.isfinite(latent_values_batch), axis=0) latent_values_batch = latent_values_batch[:, col_mask] for sample, values in zip(batch_samples, latent_values_batch): kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)} latent_samples.append( Sample( log_likelihood=sample.log_likelihood, log_prior=sample.log_prior, weight=sample.weight, kwargs=kwargs, ) ) print(f"Time to compute latent variables: {time.time() - start_latent} seconds for {len(samples)} samples.") return type(samples)( sample_list=latent_samples, model=simple_model_for_kwargs(latent_samples[0].kwargs), samples_info=samples.samples_info, ) except NotImplementedError: return None
[docs] def compute_latent_variables(self, parameters, model) -> Dict[str, float]: """ Override to compute latent variables from the instance. Latent variables are expressed as a dictionary: {"name": value} More complex models can be expressed by separating variables names by '.' {"name.attribute": value} Parameters ---------- instance An instance of the model. Returns ------- The computed latent variables. """ raise NotImplementedError()
[docs] def with_model(self, model): """ Associate an explicit model with this analysis. Instances of the model will be used to compute log likelihood in place of the model passed from the search. Parameters ---------- model A model to associate with this analysis Returns ------- An analysis for that model """ from .model_analysis import ModelAnalysis return ModelAnalysis(analysis=self, model=model)
def log_likelihood_function(self, instance): raise NotImplementedError() def save_attributes(self, paths: AbstractPaths): pass def save_results(self, paths: AbstractPaths, result: Result): pass def save_results_combined(self, paths: AbstractPaths, result: Result): pass
[docs] def modify_before_fit(self, paths: AbstractPaths, model: AbstractPriorModel): """ Overwrite this method to modify the attributes of the `Analysis` class before the non-linear search begins. An example use-case is using properties of the model to alter the `Analysis` class in ways that can speed up the fitting performed in the `log_likelihood_function`. """ return self
def modify_model(self, model): return model
[docs] def modify_after_fit( self, paths: AbstractPaths, model: AbstractPriorModel, result: Result ): """ Overwrite this method to modify the attributes of the `Analysis` class before the non-linear search begins. An example use-case is using properties of the model to alter the `Analysis` class in ways that can speed up the fitting performed in the `log_likelihood_function`. """ return self
[docs] def make_result( self, samples_summary: SamplesSummary, paths: AbstractPaths, samples: Optional[SamplesPDF] = None, search_internal: Optional[object] = None, analysis: Optional[object] = None, ) -> Result: """ Returns the `Result` of the non-linear search after it is completed. The result type is defined as a class variable in the `Analysis` class. It can be manually overwritten by a user to return a user-defined result object, which can be extended with additional methods and attributes specific to the model-fit. The standard `Result` object may include: - The samples summary, which contains the maximum log likelihood instance and median PDF model. - The paths of the search, which are used for loading the samples and search internal below when a search is resumed. - The samples of the non-linear search (e.g. MCMC chains) also stored in `samples.csv`. - The non-linear search used for the fit in its internal representation, which is used for resuming a search and making bespoke visualization using the search's internal results. - The analysis used to fit the model (default disabled to save memory, but option may be useful for certain projects). Parameters ---------- samples_summary The summary of the samples of the non-linear search, which include the maximum log likelihood instance and median PDF model. paths An object describing the paths for saving data (e.g. hard-disk directories or entries in sqlite database). samples The samples of the non-linear search, for example the chains of an MCMC run. search_internal The internal representation of the non-linear search used to perform the model-fit. analysis The analysis used to fit the model. Returns ------- Result The result of the non-linear search, which is defined as a class variable in the `Analysis` class. """ return self.Result( samples_summary=samples_summary, paths=paths, samples=samples, search_internal=search_internal, analysis=analysis, )
@property def supports_background_update(self) -> bool: """Whether this analysis supports background quick updates.""" return False @property def supports_jax_visualization(self) -> bool: """ Whether the visualizer can work directly with JAX arrays. Derived from the ``use_jax_for_visualization`` flag passed at construction time. Subclasses may override to force a specific answer (e.g. an Analysis that has been audited to support JAX visualization unconditionally). """ return self._use_jax_for_visualization def perform_quick_update(self, paths, instance): raise NotImplementedError
[docs] def print_vram_use(self, model, batch_size : int) -> str: """ Print JAX VRAM use for a given batch size. Parameters ---------- batch_size The batch size to profile, which is the number of model evaluations JAX will perform simultaneously. """ from autofit.non_linear.test_mode import skip_fit_output if skip_fit_output(): return if not self._use_jax: print("use_jax=False for this analysis, therefore does not use GPU and VRAM use cannot be profiled.") return import jax import jax.numpy as jnp from autofit.non_linear.fitness import Fitness fitness = Fitness( model=model, analysis=self, fom_is_log_likelihood=True, use_jax_vmap=True, batch_size=batch_size, ) parameters = np.zeros((batch_size, model.total_free_parameters)) for i in range(batch_size): parameters[i, :] = model.physical_values_from_prior_medians parameters = jnp.array(parameters) batched_call = jax.jit(jax.vmap(fitness.call)) lowered = batched_call.lower(parameters) compiled = lowered.compile() memory_analysis = compiled.memory_analysis() vram_bytes = ( memory_analysis.output_size_in_bytes + memory_analysis.temp_size_in_bytes ) if vram_bytes == 0: print( "VRAM USE = 0.000 GB " "(this likely means JAX is running in CPU-only mode)" ) else: print( f"VRAM USE = {vram_bytes / 1024 ** 3:.3f} GB" )