diff --git a/fft.py b/fft.py new file mode 100644 index 0000000..58d05e3 --- /dev/null +++ b/fft.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +import numbers +import operator +from typing import Callable, Protocol, Self, cast + + +@dataclass(frozen=True) +class FFTData: + data: dict[float, float] + + + @classmethod + def num_operation(cls, op: Callable[[float, float], float], a: Self | float, b: Self | float): + results = {} + keyset: set[float] = set.intersection(*(set(x.data.keys()) for x in (a, b) if not isinstance(x, (float, int)))) + + for k in keyset: + val_a: float = a if isinstance(a, (float, int)) else a.data[k] + val_b: float = b if isinstance(b, (float, int)) else b.data[k] + results[k] = op(val_a, val_b) + return cls(results) + + def __add__(self, other: Self): + return self.num_operation(operator.add, self, other) + + def __mul__(self, other: float): + return self.num_operation(operator.mul, self, other) + + diff --git a/interpolation.py b/interpolation.py new file mode 100644 index 0000000..f3f8cae --- /dev/null +++ b/interpolation.py @@ -0,0 +1,41 @@ +from typing import Callable, Protocol, Self, Sequence + + +class Interpolable(Protocol): + def __add__(self: Self, other: Self, /) -> Self: ... + def __mul__(self: Self, other: float, /) -> Self: ... + +def lin_interpolate[T: Interpolable](arg: float, arg_l: float, arg_h: float, val_l: T, val_h: T) -> T: + if arg_l == arg_h: + return val_l + + w_l = (arg_h - arg) / (arg_h - arg_l) + w_h = (arg - arg_l) / (arg_h - arg_l) + + return val_l * w_l + val_h * w_h + +# sequence assumed to be in order +def find_args_binary[T](seq: Sequence[T], key: Callable[[T], float], val: float) -> tuple[T, T]: + # print(seq) + if len(seq) < 2: + raise ValueError() + if val < key(seq[0]): + raise ValueError() + if val > key(seq[-1]): + raise ValueError() + if len(seq) == 2: + return (seq[0], seq[1]) + + index_mid = len(seq)//2 + val_mid = key(seq[index_mid]) + # print(f'{index_mid}: {val_mid}') + if val > val_mid: + return find_args_binary(seq[index_mid:], key, val) + else: + return find_args_binary(seq[:index_mid+1], key, val) + +def interp_keyed[T, U: Interpolable](seq: Sequence[T], key_key: Callable[[T], float], key_val: Callable[[T], U], x: float) -> U: + args = find_args_binary(seq, key_key, x) + karg_l, karg_h = map(key_key, args) + val_l, val_h = map(key_val, args) + return lin_interpolate(x, karg_l, karg_h, val_l, val_h)