Skip to content

Commit

Permalink
Merge branch '85-add-interpolationrebinning' into 99-move-fractional-…
Browse files Browse the repository at this point in the history
…binning-stuff-to-new-binning-branch
  • Loading branch information
lucas-wilkins authored Oct 15, 2024
2 parents 6df7930 + a1db35f commit 209ba83
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 32 deletions.
21 changes: 12 additions & 9 deletions sasdata/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,32 @@
from typing import TypeVar, Any, Self
from dataclasses import dataclass

import numpy as np

from quantities.quantity import NamedQuantity
from sasdata.metadata import Metadata
from sasdata.quantities.accessors import AccessorTarget
from sasdata.data_backing import Group, key_tree


class SasData:
def __init__(self, name: str, data_contents: list[NamedQuantity], raw_metadata: Group, verbose: bool=False):
def __init__(self, name: str,
data_contents: list[NamedQuantity],
raw_metadata: Group,
verbose: bool=False):

self.name = name
self._data_contents = data_contents
self._raw_metadata = raw_metadata
self._verbose = verbose

self.metadata = Metadata(AccessorTarget(raw_metadata, verbose=verbose))

# TO IMPLEMENT

# abscissae: list[NamedQuantity[np.ndarray]]
# ordinate: NamedQuantity[np.ndarray]
# other: list[NamedQuantity[np.ndarray]]
#
# metadata: Metadata
# model_requirements: ModellingRequirements
# Components that need to be organised after creation
self.ordinate: NamedQuantity[np.ndarray] = None # TODO: fill out
self.abscissae: list[NamedQuantity[np.ndarray]] = None # TODO: fill out
self.mask = None # TODO: fill out
self.model_requirements = None # TODO: fill out

def summary(self, indent = " ", include_raw=False):
s = f"{self.name}\n"
Expand Down
2 changes: 1 addition & 1 deletion sasdata/model_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from sasdata.metadata import Metadata
from transforms.operation import Operation
from sasdata.quantities.operations import Operation


@dataclass
Expand Down
113 changes: 112 additions & 1 deletion sasdata/quantities/operations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, TypeVar, Union
import numpy as np

import json

Expand Down Expand Up @@ -702,9 +703,119 @@ def __eq__(self, other):
if isinstance(other, Pow):
return self.a == other.a and self.power == other.power



#
# Matrix operations
#

class Transpose(UnaryOperation):
""" Transpose operation - as per numpy"""

serialisation_name = "transpose"

def evaluate(self, variables: dict[int, T]) -> T:
return np.transpose(self.a.evaluate(variables))

def _derivative(self, hash_value: int) -> Operation:
return Transpose(self.a.derivative(hash_value)) # TODO: Check!

def _clean(self):
clean_a = self.a._clean()
return Transpose(clean_a)

@staticmethod
def _deserialise(parameters: dict) -> "Operation":
return Transpose(Operation.deserialise_json(parameters["a"]))

def _summary_open(self):
return "Transpose"

def __eq__(self, other):
if isinstance(other, Transpose):
return other.a == self.a


class Dot(BinaryOperation):
""" Dot product - backed by numpy's dot method"""

serialisation_name = "dot"

def _self_cls(self) -> type:
return Dot

def evaluate(self, variables: dict[int, T]) -> T:
return np.dot(self.a.evaluate(variables) + self.b.evaluate(variables))

def _derivative(self, hash_value: int) -> Operation:
return Add(
Dot(self.a,
self.b._derivative(hash_value)),
Dot(self.a._derivative(hash_value),
self.b))

def _clean_ab(self, a, b):
return Dot(a, b) # Do nothing for now

@staticmethod
def _deserialise(parameters: dict) -> "Operation":
return Dot(*BinaryOperation._deserialise_ab(parameters))

def _summary_open(self):
return "Dot"


# TODO: Add to base operation class, and to quantities
class MatMul(BinaryOperation):
""" Matrix multiplication, using __matmul__ dunder"""

serialisation_name = "matmul"

def _self_cls(self) -> type:
return MatMul

def evaluate(self, variables: dict[int, T]) -> T:
return self.a.evaluate(variables) @ self.b.evaluate(variables)

def _derivative(self, hash_value: int) -> Operation:
return Add(
MatMul(self.a,
self.b._derivative(hash_value)),
MatMul(self.a._derivative(hash_value),
self.b))

def _clean_ab(self, a, b):

if isinstance(a, AdditiveIdentity) or isinstance(b, AdditiveIdentity):
# Convert 0*b or a*0 to 0
return AdditiveIdentity()

elif isinstance(a, ConstantBase) and isinstance(b, ConstantBase):
# Convert constant "a"@"b" to "a@b"
return Constant(a.evaluate({}) @ b.evaluate({}))._clean()

elif isinstance(a, Neg):
return Neg(Mul(a.a, b))

elif isinstance(b, Neg):
return Neg(Mul(a, b.a))

return MatMul(a, b)


@staticmethod
def _deserialise(parameters: dict) -> "Operation":
return MatMul(*BinaryOperation._deserialise_ab(parameters))

def _summary_open(self):
return "MatMul"



_serialisable_classes = [AdditiveIdentity, MultiplicativeIdentity, Constant,
Variable,
Neg, Inv,
Add, Sub, Mul, Div, Pow]
Add, Sub, Mul, Div, Pow,
Transpose, Dot, MatMul]

_serialisation_lookup = {cls.serialisation_name: cls for cls in _serialisable_classes}
43 changes: 41 additions & 2 deletions sasdata/quantities/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ def __init__(self,
self.hash_value = -1
""" Hash based on value and uncertainty for data, -1 if it is a derived hash value """

""" Contains the variance if it is data driven, else it is """
self._variance = None
""" Contains the variance if it is data driven """

if standard_error is None:
self._variance = None
self.hash_value = hash_data_via_numpy(hash_seed, value)
else:
self._variance = standard_error ** 2
Expand Down Expand Up @@ -233,6 +233,45 @@ def __rmul__(self: Self, other: ArrayLike | Self):
self.history.operation_tree),
self.history.references))


def __matmul__(self, other: ArrayLike | Self):
if isinstance(other, Quantity):
return DerivedQuantity(
self.value @ other.value,
self.units * other.units,
history=QuantityHistory.apply_operation(
operations.MatMul,
self.history,
other.history))
else:
return DerivedQuantity(
self.value @ other,
self.units,
QuantityHistory(
operations.MatMul(
self.history.operation_tree,
operations.Constant(other)),
self.history.references))

def __rmatmul__(self, other: ArrayLike | Self):
if isinstance(other, Quantity):
return DerivedQuantity(
other.value @ self.value,
other.units * self.units,
history=QuantityHistory.apply_operation(
operations.MatMul,
other.history,
self.history))

else:
return DerivedQuantity(other @ self.value, self.units,
QuantityHistory(
operations.MatMul(
operations.Constant(other),
self.history.operation_tree),
self.history.references))


def __truediv__(self: Self, other: float | Self) -> Self:
if isinstance(other, Quantity):
return DerivedQuantity(
Expand Down
19 changes: 0 additions & 19 deletions sasdata/transforms/operation.py

This file was deleted.

Empty file.
Loading

0 comments on commit 209ba83

Please sign in to comment.