from dataclasses import fields, is_dataclass from typing import Callable, Protocol, Self, Sequence, TypeVar import operator class Interpolable(Protocol): def __add__(self: Self, other: Self, /) -> Self: ... def __mul__(self: Self, other: float, /) -> Self: ... def __neg__(self: Self, /) -> Self: ... EXTRAPOLATE_NONE = 0 EXTRAPOLATE_CONST = 1 EXTRAPOLATE_LIN = 2 T1 = TypeVar('T1', bound=Interpolable) def lin_fun_interpolate(x_values: Sequence[float], y_values: Sequence[Callable[[], T1]], x: float, extr_down = EXTRAPOLATE_NONE, extr_up = EXTRAPOLATE_NONE) -> T1: # x0,x1,y0,y1 = None if x < x_values[0]: if extr_down == EXTRAPOLATE_NONE: raise ValueError("x is outside the interpolation range!") elif extr_down == EXTRAPOLATE_CONST: return y_values[0]() elif extr_down == EXTRAPOLATE_LIN: pass else: raise ValueError('Invalid extrapoltion method') for i in range(1, len(x_values)): x0, x1 = x_values[i-1], x_values[i] y0, y1 = y_values[i-1](), y_values[i]() if x_values[i] > x: break else: if extr_up == EXTRAPOLATE_NONE: raise ValueError("x is outside the interpolation range!") elif extr_up == EXTRAPOLATE_CONST: return y1 elif extr_up == EXTRAPOLATE_LIN: pass else: raise ValueError('Invalid extrapoltion method') return y0 + (y1 - y0) * (x - x0) / (x1 - x0) # Operators def map_fields(obj, func): if not is_dataclass(obj): raise TypeError("Expected dataclass instance") cls = type(obj) return cls(**{ f.name: func(f.name, getattr(obj, f.name)) for f in fields(obj) }) def apply_binary_vector_op(a, b, op): if type(a) is not type(b): return NotImplemented return map_fields(a, lambda name, val: op(val, getattr(b, name))) def apply_scalar_op(a, scalar, op): if not isinstance(scalar, (int, float)): return NotImplemented return map_fields(a, lambda name, val: op(val, scalar)) def apply_rscalar_op(a, scalar, op): if not isinstance(scalar, (int, float)): return NotImplemented return map_fields(a, lambda name, val: op(scalar, val)) def apply_unary_op(a, op): return map_fields(a, lambda name, val: op(val)) def _add_ops(cls, op_names, method_name_fn, apply_func): for name in op_names: magic = method_name_fn(name) op = getattr(operator, name) def method(self, other=None, *, _op=op, _name=name): if other is None: return apply_func(self, _op) return apply_func(self, other, _op) setattr(cls, magic, method) return cls def with_ops_vector(*op_names): return lambda cls: _add_ops(cls, op_names, lambda name: f"__{name}__", apply_binary_vector_op) def with_ops_scalar(*op_names): return lambda cls: _add_ops(cls, op_names, lambda name: f"__{name}__", apply_scalar_op) def with_ops_rscalar(*op_names): return lambda cls: _add_ops(cls, op_names, lambda name: f"__r{name}__", apply_rscalar_op) def with_ops_unary(*op_names): return lambda cls: _add_ops(cls, op_names, lambda name: f"__{name}__", apply_unary_op)