# Source code for leaspy.models.factory
from enum import Enum
from typing import Optional, Union
from .base import BaseModel
from .constant import ConstantModel
from .joint import JointModel
from .linear import LinearModel
from .lme import LMEModel
from .logistic import LogisticModel
from .mixture import LogisticMultivariateMixtureModel
from .shared_speed_logistic import SharedSpeedLogisticModel
__all__ = [
"ModelName",
"model_factory",
]
# [docs]  (Sphinx viewcode link artifact)
class ModelName(str, Enum):
    """Enumerate the model identifiers that users can instantiate in Leaspy.

    Each member's value is the lowercase string used to select a model in
    :func:`model_factory`. Because the enum also derives from :class:`str`,
    members compare equal to their plain-string values, e.g.
    ``ModelName.LOGISTIC == "logistic"``.
    """

    # Selects JointModel in model_factory.
    JOINT = "joint"
    # Selects LogisticModel.
    LOGISTIC = "logistic"
    # Selects LinearModel.
    LINEAR = "linear"
    # Selects SharedSpeedLogisticModel.
    SHARED_SPEED_LOGISTIC = "shared_speed_logistic"
    # Selects LMEModel.
    LME = "lme"
    # Selects ConstantModel.
    CONSTANT = "constant"
    # Selects LogisticMultivariateMixtureModel.
    MIXTURE_LOGISTIC = "mixture_logistic"
# [docs]  (Sphinx viewcode link artifact)
def model_factory(
    name: Union[str, ModelName], instance_name: Optional[str] = None, **kwargs
) -> BaseModel:
    """
    Return the model object corresponding to ``name`` arg with possible ``kwargs``.

    Parameters
    ----------
    name : :obj:`str` or ModelName
        The name of the model class to be instantiated. Valid options include:
            - ``"joint"``
            - ``"logistic"``
            - ``"linear"``
            - ``"shared_speed_logistic"``
            - ``"lme"``
            - ``"constant"``
            - ``"mixture_logistic"``
    instance_name : :obj:`str`, optional
        A custom name for the model instance. If not provided (``None`` or an
        empty string), the model's name (the ``ModelName`` value) is used as
        the instance name.
    **kwargs
        Additional keyword arguments forwarded unchanged to the model's
        constructor (the model's hyperparameters).

    Returns
    -------
    :class:`~leaspy.models.base.BaseModel`
        An instance of the :class:`~leaspy.models.base.BaseModel` subclass
        selected by ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not one of the valid model names listed above. The
        message lists the accepted options.
    NotImplementedError
        If ``name`` is a valid ``ModelName`` member that has no model class
        registered in this factory (an internal inconsistency).

    Notes
    -----
    Any exception raised by the model constructor itself (for example, for
    invalid ``kwargs``) is propagated unchanged.
    """
    try:
        name = ModelName(name)
    except ValueError:
        # Replace Enum's generic message with one that lists the accepted names;
        # the original error adds no information, so its context is suppressed.
        valid_names = ", ".join(repr(member.value) for member in ModelName)
        raise ValueError(
            f"Unknown model name {name!r}. Valid options are: {valid_names}."
        ) from None

    # One entry per ModelName member. A lookup table keeps the name -> class
    # mapping in one place and makes a missing entry detectable below.
    model_classes = {
        ModelName.JOINT: JointModel,
        ModelName.LOGISTIC: LogisticModel,
        ModelName.LINEAR: LinearModel,
        ModelName.SHARED_SPEED_LOGISTIC: SharedSpeedLogisticModel,
        ModelName.LME: LMEModel,
        ModelName.CONSTANT: ConstantModel,
        ModelName.MIXTURE_LOGISTIC: LogisticMultivariateMixtureModel,
    }
    model_class = model_classes.get(name)
    if model_class is None:
        # Fail loudly instead of returning None if a ModelName member is added
        # without registering its class here.
        raise NotImplementedError(
            f"No model class is registered for model name {name.value!r}."
        )

    # `or` (not `is None`) so that an empty string also falls back to the default.
    instance_name = instance_name or name.value
    return model_class(instance_name, **kwargs)