Source code for leaspy.models.utils.attributes.abstract_attributes

from abc import ABC, abstractmethod

import torch

from leaspy.exceptions import LeaspyModelInputError
from leaspy.utils.typing import DictParamsTorch, ParamType

__all__ = ["AbstractAttributes"]


[docs] class AbstractAttributes(ABC): """ Abstract base class for attributes of models. Contains the common attributes & methods of the different attributes classes. Such classes are used to update the models' attributes. Parameters ---------- name : str dimension : int (default None) source_dimension : int (default None) Attributes ---------- name : str Name of the associated leaspy model. dimension : int Number of features of the model source_dimension : int Number of sources of the model TODO? move to AbstractManifoldModelAttributes? univariate : bool Whether model is univariate or not (i.e. dimension == 1) has_sources : bool Whether model has sources or not (not univariate and source_dimension >= 1) TODO? move to AbstractManifoldModelAttributes? update_possibilities : set[str] (default empty) Contains the available parameters to update. Different models have different parameters. Raises ------ :exc:`.LeaspyModelInputError` if any inconsistent parameter. """ def __init__(self, name: str, dimension: int, source_dimension: int): if not (isinstance(name, str) and len(name)): raise LeaspyModelInputError( "In model attributes, you must provide a non-empty string for the parameter `name`." ) self.name = name if not (isinstance(dimension, int) and dimension >= 1): raise LeaspyModelInputError( "In model attributes, you must provide an integer >= 1 for the parameter `dimension`." ) self.dimension = dimension self.univariate = dimension == 1 self.source_dimension = source_dimension self.has_sources = bool(source_dimension) if self.univariate and self.has_sources: raise LeaspyModelInputError( "Inconsistent attributes: presence of sources for a univariate model." ) if not ( isinstance(source_dimension, int) and (0 <= source_dimension < dimension) ): raise LeaspyModelInputError( "In model attributes, you must provide an integer in [0, dimension - 1] for the parameter `source_dimension`." ) self.update_possibilities: set[ParamType] = set()
[docs] @abstractmethod def get_attributes( self, ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """ Returns the attributes of the model, which is a tuple of three torch tensors. For the precise definitions of those attributes please refer to the exact attributes class associated to your model. Returns ------- positions: `torch.Tensor` velocities: `torch.Tensor` mixing_matrix: `torch.Tensor` """
[docs] @abstractmethod def update( self, names_of_changed_values: set[ParamType], values: DictParamsTorch ) -> None: """ Update model group average parameter(s). Parameters ---------- names_of_changed_values : set[str] Values to be updated values : dict [str, `torch.Tensor`] New values used to update the model's group average parameters Raises ------ :exc:`.LeaspyModelInputError` If `names_of_changed_values` contains unknown values to update. """
[docs] def move_to_device(self, device: torch.device): """ Move the tensor attributes of this class to the specified device. Parameters ---------- device : torch.device """ for attribute_name in dir(self): if attribute_name.startswith("__"): continue attribute = getattr(self, attribute_name) if isinstance(attribute, torch.Tensor): setattr(self, attribute_name, attribute.to(device))
def _check_names(self, names_of_changed_values: set[ParamType]): """ Check if the name of the parameter(s) to update are in the possibilities allowed by the model. Parameters ---------- names_of_changed_values : set[str] Raises ------ :exc:`.LeaspyModelInputError` If `names_of_changed_values` contains unknown values to update. """ unknown_update_possibilities = names_of_changed_values.difference( self.update_possibilities ) if len(unknown_update_possibilities): raise LeaspyModelInputError( f"{unknown_update_possibilities} not in the attributes that can be updated" )