Source code for leaspy.utils.weighted_tensor._factory

from functools import wraps
from typing import Callable, Optional, TypeVar

import torch

from ._weighted_tensor import TensorOrWeightedTensor, WeightedTensor

__all__ = [
    "factory_weighted_tensor_unary_operator",
]


VT = TypeVar("VT")


[docs] def factory_weighted_tensor_unary_operator( f: Callable[[torch.Tensor], torch.Tensor], *, fill_value: Optional[VT] = None, ) -> Callable[[TensorOrWeightedTensor[VT]], TensorOrWeightedTensor[VT]]: """Factory/decorator to create a weighted-tensor compatible function from the provided unary-tensor function.""" @wraps(f) def f_compatible( x: TensorOrWeightedTensor[VT], *args, **kws ) -> TensorOrWeightedTensor[VT]: if not isinstance(x, WeightedTensor): return f(x, *args, **kws) r = f(x.filled(fill_value), *args, **kws) conv = x.valued if isinstance(r, (tuple, list, set, frozenset)): return type(r)(map(conv, r)) return conv(r) return f_compatible