diff --git a/sasdata/trend.py b/sasdata/trend.py index a01e5e5..26667ec 100644 --- a/sasdata/trend.py +++ b/sasdata/trend.py @@ -45,16 +45,22 @@ def __getitem__(self, item) -> SasData: def trend_axes(self) -> list[float]: return [get_metadatum_from_path(datum, self.trend_axis) for datum in self.data] + def data_axes(self, data: SasData, axis: str) -> list[NamedQuantity]: + return [content for content in data._data_contents if content.name == axis] + + def reference_data_axis(self, data: SasData, axis: str) -> NamedQuantity: + return self.data_axes(data, axis)[0] + # TODO: Assumes there are at least 2 items in data. Is this reasonable to assume? Should there be error handling for # situations where this may not be the case? def all_axis_match(self, axis: str) -> bool: reference_data = self.data[0] - reference_data_axis = [content for content in reference_data._data_contents if content.name == axis][0] + data_axis = self.reference_data_axis(reference_data, axis) for datum in self.data[1::]: contents = datum._data_contents axis_datum = [content for content in contents if content.name == axis][0] # FIXME: Linter is complaining about typing. - if not np.all(np.isclose(axis_datum.value, reference_data_axis.value)): + if not np.all(np.isclose(axis_datum.value, data_axis.value)): return False return True @@ -63,7 +69,7 @@ def interpolate(self, axis: str) -> Self: new_data: list[SasData] = [] reference_data = self.data[0] # TODO: I don't like the repetition here. Can probably abstract a function for this ot make it clearer. - reference_data_axis = [content for content in reference_data._data_contents if content.name == axis][0] + data_axis = self.reference_data_axis(reference_data, axis) for i, datum in enumerate(self.data): if i == 0: # This is already the reference axis; no need to interpolate it. @@ -71,7 +77,7 @@ def interpolate(self, axis: str) -> Self: # TODO: Again, repetition axis_datum = [content for content in datum._data_contents if content.name == axis][0] # TODO: There are other options which may need to be filled (or become new params to this method) - mat = calculate_interpolation_matrix(axis_datum, reference_data_axis) + mat = calculate_interpolation_matrix(axis_datum, data_axis) new_quantities: list[NamedQuantity] = [] for quantity in datum._data_contents: if quantity.name == axis_datum.name: