aaaa
This commit is contained in:
parent
7a9ed25097
commit
b8b26f4d06
|
|
@ -0,0 +1,29 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import numbers
|
||||||
|
import operator
|
||||||
|
from typing import Callable, Protocol, Self, cast
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FFTData:
|
||||||
|
data: dict[float, float]
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def num_operation(cls, op: Callable[[float, float], float], a: Self | float, b: Self | float):
|
||||||
|
results = {}
|
||||||
|
keyset: set[float] = set.intersection(*(set(x.data.keys()) for x in (a, b) if not isinstance(x, (float, int))))
|
||||||
|
|
||||||
|
for k in keyset:
|
||||||
|
val_a: float = a if isinstance(a, (float, int)) else a.data[k]
|
||||||
|
val_b: float = b if isinstance(b, (float, int)) else b.data[k]
|
||||||
|
results[k] = op(val_a, val_b)
|
||||||
|
return cls(results)
|
||||||
|
|
||||||
|
def __add__(self, other: Self):
|
||||||
|
return self.num_operation(operator.add, self, other)
|
||||||
|
|
||||||
|
def __mul__(self, other: float):
|
||||||
|
return self.num_operation(operator.mul, self, other)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
from typing import Callable, Protocol, Self, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
class Interpolable(Protocol):
|
||||||
|
def __add__(self: Self, other: Self, /) -> Self: ...
|
||||||
|
def __mul__(self: Self, other: float, /) -> Self: ...
|
||||||
|
|
||||||
|
def lin_interpolate[T: Interpolable](arg: float, arg_l: float, arg_h: float, val_l: T, val_h: T) -> T:
|
||||||
|
if arg_l == arg_h:
|
||||||
|
return val_l
|
||||||
|
|
||||||
|
w_l = (arg_h - arg) / (arg_h - arg_l)
|
||||||
|
w_h = (arg - arg_l) / (arg_h - arg_l)
|
||||||
|
|
||||||
|
return val_l * w_l + val_h * w_h
|
||||||
|
|
||||||
|
# sequence assumed to be in order
|
||||||
|
def find_args_binary[T](seq: Sequence[T], key: Callable[[T], float], val: float) -> tuple[T, T]:
|
||||||
|
# print(seq)
|
||||||
|
if len(seq) < 2:
|
||||||
|
raise ValueError()
|
||||||
|
if val < key(seq[0]):
|
||||||
|
raise ValueError()
|
||||||
|
if val > key(seq[-1]):
|
||||||
|
raise ValueError()
|
||||||
|
if len(seq) == 2:
|
||||||
|
return (seq[0], seq[1])
|
||||||
|
|
||||||
|
index_mid = len(seq)//2
|
||||||
|
val_mid = key(seq[index_mid])
|
||||||
|
# print(f'{index_mid}: {val_mid}')
|
||||||
|
if val > val_mid:
|
||||||
|
return find_args_binary(seq[index_mid:], key, val)
|
||||||
|
else:
|
||||||
|
return find_args_binary(seq[:index_mid+1], key, val)
|
||||||
|
|
||||||
|
def interp_keyed[T, U: Interpolable](seq: Sequence[T], key_key: Callable[[T], float], key_val: Callable[[T], U], x: float) -> U:
|
||||||
|
args = find_args_binary(seq, key_key, x)
|
||||||
|
karg_l, karg_h = map(key_key, args)
|
||||||
|
val_l, val_h = map(key_val, args)
|
||||||
|
return lin_interpolate(x, karg_l, karg_h, val_l, val_h)
|
||||||
Loading…
Reference in New Issue