diff --git a/sasdata/data.py b/sasdata/data.py index 7f0cbfb..544ba27 100644 --- a/sasdata/data.py +++ b/sasdata/data.py @@ -2,6 +2,8 @@ from typing import TypeVar, Any, Self from dataclasses import dataclass +import numpy as np + from quantities.quantity import NamedQuantity from sasdata.metadata import Metadata from sasdata.quantities.accessors import AccessorTarget @@ -9,7 +11,11 @@ class SasData: - def __init__(self, name: str, data_contents: list[NamedQuantity], raw_metadata: Group, verbose: bool=False): + def __init__(self, name: str, + data_contents: list[NamedQuantity], + raw_metadata: Group, + verbose: bool=False): + self.name = name self._data_contents = data_contents self._raw_metadata = raw_metadata @@ -17,14 +23,11 @@ def __init__(self, name: str, data_contents: list[NamedQuantity], raw_metadata: self.metadata = Metadata(AccessorTarget(raw_metadata, verbose=verbose)) - # TO IMPLEMENT - - # abscissae: list[NamedQuantity[np.ndarray]] - # ordinate: NamedQuantity[np.ndarray] - # other: list[NamedQuantity[np.ndarray]] - # - # metadata: Metadata - # model_requirements: ModellingRequirements + # Components that need to be organised after creation + self.ordinate: NamedQuantity[np.ndarray] = None # TODO: fill out + self.abscissae: list[NamedQuantity[np.ndarray]] = None # TODO: fill out + self.mask = None # TODO: fill out + self.model_requirements = None # TODO: fill out def summary(self, indent = " ", include_raw=False): s = f"{self.name}\n" diff --git a/sasdata/model_requirements.py b/sasdata/model_requirements.py index f186d2d..d043b2c 100644 --- a/sasdata/model_requirements.py +++ b/sasdata/model_requirements.py @@ -3,7 +3,7 @@ import numpy as np from sasdata.metadata import Metadata -from transforms.operation import Operation +from sasdata.quantities.operations import Operation @dataclass diff --git a/sasdata/quantities/operations.py b/sasdata/quantities/operations.py index 724e55d..35f6fc7 100644 --- a/sasdata/quantities/operations.py +++ b/sasdata/quantities/operations.py @@ -1,4 +1,5 @@ from typing import Any, TypeVar, Union +import numpy as np import json @@ -702,9 +703,119 @@ def __eq__(self, other): if isinstance(other, Pow): return self.a == other.a and self.power == other.power + + +# +# Matrix operations +# + +class Transpose(UnaryOperation): + """ Transpose operation - as per numpy""" + + serialisation_name = "transpose" + + def evaluate(self, variables: dict[int, T]) -> T: + return np.transpose(self.a.evaluate(variables)) + + def _derivative(self, hash_value: int) -> Operation: + return Transpose(self.a.derivative(hash_value)) # TODO: Check! + + def _clean(self): + clean_a = self.a._clean() + return Transpose(clean_a) + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Transpose(Operation.deserialise_json(parameters["a"])) + + def _summary_open(self): + return "Transpose" + + def __eq__(self, other): + if isinstance(other, Transpose): + return other.a == self.a + + +class Dot(BinaryOperation): + """ Dot product - backed by numpy's dot method""" + + serialisation_name = "dot" + + def _self_cls(self) -> type: + return Dot + + def evaluate(self, variables: dict[int, T]) -> T: + return np.dot(self.a.evaluate(variables) + self.b.evaluate(variables)) + + def _derivative(self, hash_value: int) -> Operation: + return Add( + Dot(self.a, + self.b._derivative(hash_value)), + Dot(self.a._derivative(hash_value), + self.b)) + + def _clean_ab(self, a, b): + return Dot(a, b) # Do nothing for now + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return Dot(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "Dot" + + +# TODO: Add to base operation class, and to quantities +class MatMul(BinaryOperation): + """ Matrix multiplication, using __matmul__ dunder""" + + serialisation_name = "matmul" + + def _self_cls(self) -> type: + return MatMul + + def evaluate(self, variables: dict[int, T]) -> T: + return self.a.evaluate(variables) @ self.b.evaluate(variables) + + def _derivative(self, hash_value: int) -> Operation: + return Add( + MatMul(self.a, + self.b._derivative(hash_value)), + MatMul(self.a._derivative(hash_value), + self.b)) + + def _clean_ab(self, a, b): + + if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity): + # Convert 0*b or a*0 to 0 + return AdditiveIdentity() + + elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase): + # Convert constant "a"@"b" to "a@b" + return Constant(a.evaluate({}) @ b.evaluate({}))._clean() + + elif isinstance(a, Neg): + return Neg(Mul(a.a, b)) + + elif isinstance(b, Neg): + return Neg(Mul(a, b.a)) + + return MatMul(a, b) + + + @staticmethod + def _deserialise(parameters: dict) -> "Operation": + return MatMul(*BinaryOperation._deserialise_ab(parameters)) + + def _summary_open(self): + return "MatMul" + + + _serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant, Variable, Neg, Inv, - Add, Sub, Mul, Div, Pow] + Add, Sub, Mul, Div, Pow, + Transpose, Dot, MatMul] _serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes} diff --git a/sasdata/quantities/quantity.py b/sasdata/quantities/quantity.py index df70830..6ee80d8 100644 --- a/sasdata/quantities/quantity.py +++ b/sasdata/quantities/quantity.py @@ -132,10 +132,10 @@ def __init__(self, self.hash_value = -1 """ Hash based on value and uncertainty for data, -1 if it is a derived hash value """ - """ Contains the variance if it is data driven, else it is """ + self._variance = None + """ Contains the variance if it is data driven """ if standard_error is None: - self._variance = None self.hash_value = hash_data_via_numpy(hash_seed, value) else: self._variance = standard_error ** 2 @@ -233,6 +233,45 @@ def __rmul__(self: Self, other: ArrayLike | Self): self.history.operation_tree), self.history.references)) + + def __matmul__(self, other: ArrayLike | Self): + if isinstance(other, Quantity): + return DerivedQuantity( + self.value @ other.value, + self.units * other.units, + history=QuantityHistory.apply_operation( + operations.MatMul, + self.history, + other.history)) + else: + return DerivedQuantity( + self.value @ other, + self.units, + QuantityHistory( + operations.MatMul( + self.history.operation_tree, + operations.Constant(other)), + self.history.references)) + + def __rmatmul__(self, other: ArrayLike | Self): + if isinstance(other, Quantity): + return DerivedQuantity( + other.value @ self.value, + other.units * self.units, + history=QuantityHistory.apply_operation( + operations.MatMul, + other.history, + self.history)) + + else: + return DerivedQuantity(other @ self.value, self.units, + QuantityHistory( + operations.MatMul( + operations.Constant(other), + self.history.operation_tree), + self.history.references)) + + def __truediv__(self: Self, other: float | Self) -> Self: if isinstance(other, Quantity): return DerivedQuantity( diff --git a/sasdata/transforms/operation.py b/sasdata/transforms/operation.py deleted file mode 100644 index 5912188..0000000 --- a/sasdata/transforms/operation.py +++ /dev/null @@ -1,19 +0,0 @@ -import numpy as np -from sasdata.quantities.quantity import Quantity - -class Operation: - """ Sketch of what model post-processing classes might look like """ - - children: list["Operation"] - named_children: dict[str, "Operation"] - - @property - def name(self) -> str: - raise NotImplementedError("No name for transform") - - def evaluate(self) -> Quantity[np.ndarray]: - pass - - def __call__(self, *children, **named_children): - self.children = children - self.named_children = named_children \ No newline at end of file diff --git a/sasdata/transforms/post_process.py b/sasdata/transforms/post_process.py new file mode 100644 index 0000000..e69de29 diff --git a/sasdata/transforms/rebinning.py b/sasdata/transforms/rebinning.py new file mode 100644 index 0000000..cd05cfc --- /dev/null +++ b/sasdata/transforms/rebinning.py @@ -0,0 +1,142 @@ +""" Algorithms for interpolation and rebinning """ +from typing import TypeVar + +import numpy as np +from numpy._typing import ArrayLike +from scipy.interpolate import interp1d + +from sasdata.quantities.quantity import Quantity +from scipy.sparse import coo_matrix + +from enum import Enum + +class InterpolationOptions(Enum): + NEAREST_NEIGHBOUR = 0 + LINEAR = 1 + + + +def calculate_interpolation_matrix_1d(input_axis: Quantity[ArrayLike], + output_axis: Quantity[ArrayLike], + mask: ArrayLike | None = None, + order: InterpolationOptions = InterpolationOptions.NEAREST_NEIGHBOUR, + is_density=False): + + # We want the input values in terms of the output units, will implicitly check compatability + + working_units = output_axis.units + + input_x = input_axis.in_units_of(working_units) + output_x = output_axis.in_units_of(working_units) + + # Get the array indices that will map the array to a sorted one + input_sort = np.argsort(input_x) + output_sort = np.argsort(output_x) + + input_unsort = np.arange(len(output_x), dtype=int)[input_sort] + output_unsort = np.arange(len(input_x), dtype=int)[output_sort] + + sorted_in = input_x[input_sort] + sorted_out = output_x[output_sort] + + n_in = len(sorted_in) + n_out = len(sorted_out) + + conversion_matrix = None # output + + match order: + case InterpolationOptions.NEAREST_NEIGHBOUR: + + # COO Sparse matrix definition data + i_entries = [] + j_entries = [] + + crossing_points = 0.5*(sorted_out[1:] + sorted_out[:-1]) + + # Find the output values nearest to each of the input values + i=0 + for k, crossing_point in enumerate(crossing_points): + while i < n_in and sorted_in[i] < crossing_point: + i_entries.append(i) + j_entries.append(k) + i += 1 + + # All the rest in the last bin + while i < n_in: + i_entries.append(i) + j_entries.append(n_out-1) + i += 1 + + i_entries = input_unsort[np.array(i_entries, dtype=int)] + j_entries = output_unsort[np.array(j_entries, dtype=int)] + values = np.ones_like(i_entries, dtype=float) + + conversion_matrix = coo_matrix((values, (i_entries, j_entries)), shape=(n_in, n_out)) + + case InterpolationOptions.LINEAR: + + # Leverage existing linear interpolation methods to get the mapping + # do a linear interpolation on indices + # the floor should give the left bin + # the ceil should give the right bin + # the fractional part should give the relative weightings + + input_indices = np.arange(n_in, dtype=int) + output_indices = np.arange(n_out, dtype=int) + + fractional = np.interp(x=sorted_out, xp=sorted_in, fp=input_indices, left=0, right=n_in-1) + + left_bins = np.floor(fractional, dtype=int) + right_bins = np.ceil(fractional, dtype=int) + + right_weight = fractional % 1 + left_weight = 1 - right_weight + + # There *should* be no repeated entries for both i and j in the main part, but maybe at the ends + # If left bin is the same as right bin, then we only want one entry, and the weight should be 1 + + same = left_bins == right_bins + not_same = ~same + + same_bins = left_bins[same] # could equally be right bins, they're the same + + same_indices = output_indices[same] + not_same_indices = output_indices[not_same] + + j_entries_sorted = np.concatenate((same_indices, not_same_indices, not_same_indices)) + i_entries_sorted = np.concatenate((same_bins, left_bins[not_same], right_bins[not_same])) + + i_entries = input_unsort[i_entries_sorted] + j_entries = output_unsort[j_entries_sorted] + + # weights don't need to be unsorted # TODO: check this is right, it should become obvious if we use unsorted data + weights = np.concatenate((np.ones_like(same_bins, dtype=float), left_weight[not_same], right_weight[not_same])) + + conversion_matrix = coo_matrix((weights, (i_entries, j_entries)), shape=(n_in, n_out)) + + case _: + raise ValueError(f"Unsupported interpolation order: {order}") + + + return conversion_matrix + +def calculate_interpolation_matrix(input_axes: list[Quantity[ArrayLike]], + output_axes: list[Quantity[ArrayLike]], + data: ArrayLike | None = None, + mask: ArrayLike | None = None): + + pass + + + +def rebin(data: Quantity[ArrayLike], + axes: list[Quantity[ArrayLike]], + new_axes: list[Quantity[ArrayLike]], + mask: ArrayLike | None = None, + interpolation_order: int = 1): + + """ This algorithm is only for operations that preserve dimensionality, + i.e. non-projective rebinning. + """ + + pass \ No newline at end of file