102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
import operator
from bisect import bisect_right
from dataclasses import fields, is_dataclass
from typing import Callable, Protocol, Self, Sequence, TypeVar
|
|
|
|
class Interpolable(Protocol):
|
|
def __add__(self: Self, other: Self, /) -> Self: ...
|
|
def __mul__(self: Self, other: float, /) -> Self: ...
|
|
def __neg__(self: Self, /) -> Self: ...
|
|
|
|
# Extrapolation modes for lin_fun_interpolate's ``extr_down`` / ``extr_up``:
#   EXTRAPOLATE_NONE  -> raise ValueError when x falls outside the knots
#   EXTRAPOLATE_CONST -> return the value at the nearest end knot
#   EXTRAPOLATE_LIN   -> extend the nearest end segment linearly
EXTRAPOLATE_NONE, EXTRAPOLATE_CONST, EXTRAPOLATE_LIN = 0, 1, 2
|
|
|
|
# Value type produced by the y-callables passed to lin_fun_interpolate.
T1 = TypeVar('T1', bound=Interpolable)


def lin_fun_interpolate(
    x_values: Sequence[float],
    y_values: Sequence[Callable[[], T1]],
    x: float,
    extr_down: int = EXTRAPOLATE_NONE,
    extr_up: int = EXTRAPOLATE_NONE,
) -> T1:
    """Piecewise-linearly interpolate lazily evaluated values at ``x``.

    ``x_values`` are the knot positions and must be sorted ascending;
    ``y_values[i]`` is a zero-argument callable producing the value at
    ``x_values[i]``. Only the callables of the (at most two) knots that
    bracket ``x`` are invoked, each at most once.

    Inside ``[x_values[0], x_values[-1]]`` (both ends inclusive) the result is
    ``y0 * (1 - t) + y1 * t`` with ``t = (x - x0) / (x1 - x0)``. This needs
    only ``+`` between values and ``*`` by a float, i.e. the operations
    declared by ``Interpolable``. Outside that range ``extr_down`` (below the
    first knot) or ``extr_up`` (above the last knot) selects the behaviour:

    * ``EXTRAPOLATE_NONE``  -- raise ``ValueError``;
    * ``EXTRAPOLATE_CONST`` -- return the value at the nearest end knot;
    * ``EXTRAPOLATE_LIN``   -- extend the first / last segment linearly.

    Raises:
        ValueError: if ``x_values`` is empty, the two sequences differ in
            length, ``x`` is out of range with ``EXTRAPOLATE_NONE``, linear
            extrapolation is requested with a single knot, or the selected
            extrapolation mode is unknown.
    """
    n = len(x_values)
    if n == 0:
        raise ValueError("x_values must not be empty")
    if len(y_values) != n:
        raise ValueError("x_values and y_values must have the same length")

    below = x < x_values[0]
    # ``not x <= last`` instead of ``x > last`` so that NaN (which compares
    # false with everything) takes the upper-extrapolation path.
    if below or not x <= x_values[-1]:
        mode = extr_down if below else extr_up
        if mode == EXTRAPOLATE_NONE:
            raise ValueError("x is outside the interpolation range!")
        if mode == EXTRAPOLATE_CONST:
            return y_values[0 if below else -1]()
        if mode != EXTRAPOLATE_LIN:
            raise ValueError('Invalid extrapolation method')
        if n < 2:
            raise ValueError("linear extrapolation requires at least two points")
    elif x == x_values[-1]:
        # The last knot belongs to the range (it is also the only knot when
        # n == 1, where no segment exists).
        return y_values[-1]()

    # Right-hand knot of the segment to use: the first knot strictly greater
    # than x, clamped to [1, n - 1] so that linear extrapolation reuses the
    # first / last segment.
    i = min(max(bisect_right(x_values, x), 1), n - 1)
    x0, x1 = x_values[i - 1], x_values[i]
    t = (x - x0) / (x1 - x0)
    return y_values[i - 1]() * (1 - t) + y_values[i]() * t
|
|
|
|
|
|
# Operators
|
|
|
|
def map_fields(obj, func):
    """Return a new instance of ``obj``'s dataclass with every field transformed.

    ``func(name, value)`` is called for each constructor field of ``obj``, in
    declaration order. Its result becomes that field's value in the new
    instance, which is built by calling the dataclass's constructor.

    Fields declared with ``init=False`` cannot be passed to the constructor.
    They are therefore not mapped; the new instance initialises them itself
    (via their default or ``__post_init__``).

    Raises:
        TypeError: if ``obj`` is not a dataclass *instance*. A dataclass type
            is rejected as well.
    """
    # is_dataclass() is also true for the dataclass type itself, which has no
    # field values to map, so reject types explicitly.
    if isinstance(obj, type) or not is_dataclass(obj):
        raise TypeError("Expected dataclass instance")
    return type(obj)(**{
        f.name: func(f.name, getattr(obj, f.name))
        for f in fields(obj)
        if f.init
    })
|
|
|
|
def apply_binary_vector_op(a, b, op):
    """Combine two instances of the same dataclass field by field.

    When ``a`` and ``b`` are not of exactly the same type, ``NotImplemented``
    is returned (the conventional result of an operator method that cannot
    handle its operand). Otherwise each field of the result is
    ``op(a.<field>, b.<field>)``.
    """
    if type(a) is type(b):
        def combine(field_name, left):
            return op(left, getattr(b, field_name))

        return map_fields(a, combine)
    return NotImplemented
|
|
|
|
def apply_scalar_op(a, scalar, op):
    """Apply ``op(field_value, scalar)`` to every field of dataclass instance ``a``.

    Only ``int`` / ``float`` scalars (including subclasses such as ``bool``)
    are accepted; any other operand yields ``NotImplemented``.
    """
    is_number = isinstance(scalar, (int, float))
    if not is_number:
        return NotImplemented
    return map_fields(a, lambda _field_name, value: op(value, scalar))
|
|
|
|
def apply_rscalar_op(a, scalar, op):
    """Apply ``op(scalar, field_value)`` to every field of dataclass instance ``a``.

    This is the reflected counterpart of ``apply_scalar_op``: the scalar is the
    left operand. Non-``int``/``float`` scalars yield ``NotImplemented``.
    """
    if isinstance(scalar, (int, float)):
        def reflected(_field_name, value):
            return op(scalar, value)

        return map_fields(a, reflected)
    return NotImplemented
|
|
|
|
def apply_unary_op(a, op):
    """Return a copy of dataclass instance ``a`` with ``op`` applied to each field."""
    def transform(_field_name, value):
        return op(value)

    return map_fields(a, transform)
|
|
|
|
def _add_ops(cls, op_names, method_name_fn, apply_func):
    """Install one operator method on ``cls`` per entry of ``op_names``.

    Each name is resolved to a function of the ``operator`` module
    (``'add'`` -> ``operator.add``). Operators whose plain name is a Python
    keyword may be given with or without the module's trailing underscore
    (``'and'`` or ``'and_'`` -> ``operator.and_``). The underscore is dropped
    before ``method_name_fn`` builds the attribute name, so both spellings
    install ``__and__`` rather than ``__and___``.

    The installed method calls ``apply_func(self, op)`` when invoked without
    an operand (unary operators) and ``apply_func(self, other, op)``
    otherwise, including when ``other`` is ``None``.

    Args:
        cls: Class to modify in place.
        op_names: Iterable of ``operator`` function names.
        method_name_fn: Maps an operator name to the attribute name to set.
        apply_func: Implementation the generated methods delegate to.

    Returns:
        ``cls``, so the function can back a class decorator.

    Raises:
        AttributeError: if a name does not exist in the ``operator`` module.
    """
    # Private sentinel: ``None`` is a legitimate right operand (e.g.
    # ``obj == None``) and must reach apply_func so it can return
    # NotImplemented, instead of being mistaken for a unary call.
    missing = object()

    for name in op_names:
        op = getattr(operator, name, None) or getattr(operator, f"{name}_", None)
        if op is None:
            raise AttributeError(f"module 'operator' has no attribute {name!r}")
        # operator spells keyword operators with one trailing underscore
        # ('and_', 'or_'); the dunder protocol does not ('__and__').
        if name.endswith("_") and not name.endswith("__"):
            base = name[:-1]
        else:
            base = name
        magic = method_name_fn(base)

        # ``op`` is bound as a default so each method keeps its own operator;
        # a plain closure would see only the last loop value.
        def method(self, other=missing, *, _op=op):
            if other is missing:
                return apply_func(self, _op)
            return apply_func(self, other, _op)

        method.__name__ = magic
        method.__qualname__ = f"{cls.__qualname__}.{magic}"
        setattr(cls, magic, method)

    return cls
|
|
|
|
def with_ops_vector(*op_names):
    """Build a class decorator adding field-wise binary operators.

    For every ``operator``-module name in ``op_names``, the decorated class
    gets ``__<name>__``, which delegates to ``apply_binary_vector_op`` to
    combine two instances of the same class.
    """
    def decorator(cls):
        return _add_ops(cls, op_names, lambda op_name: f"__{op_name}__", apply_binary_vector_op)

    return decorator
|
|
|
|
def with_ops_scalar(*op_names):
    """Build a class decorator adding ``instance <op> scalar`` operators.

    For every ``operator``-module name in ``op_names``, the decorated class
    gets ``__<name>__``, which delegates to ``apply_scalar_op`` to apply the
    operator between each field and the scalar.
    """
    def decorator(cls):
        return _add_ops(cls, op_names, lambda op_name: f"__{op_name}__", apply_scalar_op)

    return decorator
|
|
|
|
def with_ops_rscalar(*op_names):
    """Build a class decorator adding reflected ``scalar <op> instance`` operators.

    For every ``operator``-module name in ``op_names``, the decorated class
    gets ``__r<name>__``, which delegates to ``apply_rscalar_op`` so the scalar
    is the left operand of each field-wise operation.
    """
    def decorator(cls):
        return _add_ops(cls, op_names, lambda op_name: f"__r{op_name}__", apply_rscalar_op)

    return decorator
|
|
|
|
def with_ops_unary(*op_names):
    """Build a class decorator adding field-wise unary operators.

    For every ``operator``-module name in ``op_names`` (e.g. ``'neg'``), the
    decorated class gets ``__<name>__``, which delegates to
    ``apply_unary_op``.
    """
    def decorator(cls):
        return _add_ops(cls, op_names, lambda op_name: f"__{op_name}__", apply_unary_op)

    return decorator
|
|
|
|
|
|
|