From f404c1035e4e9b82793a7b07450a1aeef8871c52 Mon Sep 17 00:00:00 2001 From: suzannejin Date: Fri, 17 Jan 2025 14:44:31 +0100 Subject: [PATCH] merge encoders with repeated functions --- src/stimulus/data/encoding/encoders.py | 221 ++++++++++-------------- tests/data/encoding/test_encoders.py | 230 +++++++++++-------------- 2 files changed, 193 insertions(+), 258 deletions(-) diff --git a/src/stimulus/data/encoding/encoders.py b/src/stimulus/data/encoding/encoders.py index f1ae183e..d8eeaef0 100644 --- a/src/stimulus/data/encoding/encoders.py +++ b/src/stimulus/data/encoding/encoders.py @@ -297,43 +297,45 @@ def decode(self, data: torch.Tensor) -> Union[str, List[str]]: raise ValueError(f"Expected 2D or 3D tensor, got {data.dim()}D") -class FloatEncoder(AbstractEncoder): - """Encoder for float data.""" +class NumericEncoder(AbstractEncoder): + """Encoder for float/int data.""" - def __init__(self, dtype: torch.dtype = torch.float) -> None: - """Initialize the FloatEncoder class. + def __init__(self, dtype: torch.dtype = torch.float32) -> None: + """Initialize the NumericEncoder class. Args: dtype (torch.dtype): the data type of the encoded data. Default = torch.float (32-bit floating point) """ self.dtype = dtype - def encode(self, data: float) -> torch.Tensor: + def encode(self, data: Union[float, int]) -> torch.Tensor: """Encodes the data. This method takes as input a single data point, should be mappable to a single output. Args: - data (float): a single data point + data (float or int): a single data point Returns: encoded_data_point (torch.Tensor): the encoded data point """ - self._check_input_dtype(data) return self.encode_all(data) # there is no difference in this case - def encode_all(self, data: Union[float, List[float]]) -> torch.Tensor: + def encode_all(self, data: Union[float, int, List[float], List[int]]) -> torch.Tensor: """Encodes the data. This method takes as input a list of data points, or a single float, and returns a torch.tensor. Args: - data (Union[float, List[float]]): a list of data points or a single data point + data (float or int): a list of data points or a single data point Returns: encoded_data (torch.Tensor): the encoded data """ - self._check_input_dtype(data) if not isinstance(data, list): data = [data] + + self._check_input_dtype(data) + self._warn_float_is_converted_to_int(data) + return torch.tensor(data, dtype=self.dtype) def decode(self, data: torch.Tensor) -> List[float]: @@ -347,56 +349,64 @@ def decode(self, data: torch.Tensor) -> List[float]: """ return data.cpu().numpy().tolist() - def _check_input_dtype(self, data: Union[float, List[int], List[float]]) -> None: + def _check_input_dtype(self, data: Union[List[float], List[int]]) -> None: """Check if the input data is int or float data. Args: - data (int or float): a single data point or a list of data points + data (float or int): a list of float or integer data points Raises: - ValueError: If the input data is not a float + ValueError: If the input data contains a non-integer or non-float data point """ - if (isinstance(data, list) and not all(isinstance(d, (int, float)) for d in data)) or ( - not isinstance(data, list) and not isinstance(data, (int, float)) - ): - err_msg = f"Expected input data to be a float, got {type(data).__name__}" + if not all(isinstance(d, (int, float)) for d in data): + err_msg = f"Expected input data to be a float or int" logger.error(err_msg) raise ValueError(err_msg) - -class IntEncoder(FloatEncoder): - """Encoder for integer data.""" - - def __init__(self, dtype: torch.dtype = torch.int) -> None: - """Initialize the IntEncoder class. + def _warn_float_is_converted_to_int(self, data: Union[List[float], List[int]]) -> None: + """Warn if float data is encoded into int data. Args: - dtype (torch.dtype): the data type of the encoded data. Default = torch.int (32-bit integer) + data (float or int): a list of float or integer data points """ - super().__init__(dtype) + if any(isinstance(d, float) for d in data) and (self.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]): + logger.warning("Encoding float data to torch.int data type.") - def _check_input_dtype(self, data: Union[int, List[int]]) -> None: - """Check if the input data is int data. - Args: - data (int): a single data point or a list of data points +class StrClassificationEncoder(AbstractEncoder): + """ + A string classification encoder that converts lists of strings into numeric labels using scikit-learn's + LabelEncoder. When scale is set to True, the labels are scaled to be between 0 and 1. - Raises: - ValueError: If the input data is not a int - """ - if (isinstance(data, list) and not all(isinstance(d, int) for d in data)) or ( - not isinstance(data, list) and not isinstance(data, int) - ): - err_msg = f"Expected input data to be a int, got {type(data).__name__}" - logger.error(err_msg) - raise ValueError(err_msg) + Attributes: + None + + Methods: + encode(data: str) -> int: + Raises a NotImplementedError, as encoding a single string is not meaningful in this context. + encode_all(data: List[str]) -> torch.tensor: + Encodes an entire list of string data into a numeric representation using LabelEncoder and + returns a torch tensor. Ensures that the provided data items are valid strings prior to encoding. + decode(data: Any) -> Any: + Raises a NotImplementedError, as decoding is not supported with the current design. + _check_dtype(data: List[str]) -> None: + Validates that all items in the data list are strings, raising a ValueError otherwise. + """ + def __init__(self, scale: bool = False) -> None: + """Initialize the StrClassificationEncoder class. -class StrClassificationIntEncoder(AbstractEncoder): - """Considering a ensemble of strings, this encoder encodes them into integers from 0 to (n-1) where n is the number of unique strings.""" + Args: + scale (bool): whether to scale the labels to be between 0 and 1. Default = False + """ + self.scale = scale def encode(self, data: str) -> int: - """Returns an error since encoding a single string does not make sense.""" + """Returns an error since encoding a single string does not make sense. + + Args: + data (str): a single string + """ raise NotImplementedError("Encoding a single string does not make sense. Use encode_all instead.") def encode_all(self, data: List[str]) -> torch.tensor: @@ -412,13 +422,19 @@ def encode_all(self, data: List[str]) -> torch.tensor: """ if not isinstance(data, list): data = [data] + self._check_dtype(data) + encoder = preprocessing.LabelEncoder() - return torch.tensor(encoder.fit_transform(data)) + encoded_data = torch.tensor(encoder.fit_transform(data)) + if self.scale: + encoded_data = encoded_data / max(len(encoded_data) - 1, 1) + + return encoded_data def decode(self, data: Any) -> Any: """Returns an error since decoding does not make sense without encoder information, which is not yet supported.""" - raise NotImplementedError("Decoding is not yet supported.") + raise NotImplementedError("Decoding is not yet supported for StrClassification.") def _check_dtype(self, data: List[str]) -> None: """Check if the input data is string data. @@ -435,62 +451,61 @@ def _check_dtype(self, data: List[str]) -> None: raise ValueError(err_msg) -class StrClassificationScaledEncoder(StrClassificationIntEncoder): - """Considering a ensemble of strings, this encoder encodes them into floats from 0 to 1 (essentially scaling the integer encoding).""" +class NumericRankEncoder(AbstractEncoder): + """Encoder for float/int data that encodes the data based on their rank. - def encode_all(self, data: List[str]) -> torch.Tensor: - """Encodes the data. - This method takes as input a list of data points, should be mappable to a single output, using LabelEncoder from scikit learn and returning a numpy array. - For more info visit : https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html - - Args: - data (List[str]): a list of strings + Attributes: + scale (bool): whether to scale the ranks to be between 0 and 1. Default = False - Returns: - encoded_data (torch.Tensor): the encoded data - """ - encoded_data = super().encode_all(data) - return encoded_data / max(len(encoded_data) - 1, 1) + Methods: + encode: encodes a single data point + encode_all: encodes a list of data points into a torch.tensor + decode: decodes a single data point + _check_input_dtype: checks if the input data is int or float data + _warn_float_is_converted_to_int: warns if float data is encoded into + """ + def __init__(self, scale: bool = False) -> None: + """Initialize the NumericRankEncoder class. -class FloatRankEncoder(AbstractEncoder): - """Considering an ensemble of float values, this encoder encodes them into floats from 0 to 1, where 1 is the maximum value and 0 is the minimum value.""" + Args: + scale (bool): whether to scale the ranks to be between 0 and 1. Default = False + """ + self.scale = scale - def encode(self, data: float, dtype: torch.dtype = torch.float) -> torch.Tensor: + def encode(self, data: Any) -> torch.Tensor: """Returns an error since encoding a single float does not make sense.""" raise NotImplementedError("Encoding a single float does not make sense. Use encode_all instead.") - def encode_all(self, data: List[float], dtype: torch.dtype = torch.float) -> torch.Tensor: + def encode_all(self, data: Union[List[float], List[int]]) -> torch.Tensor: """Encodes the data. - This method takes as input a list of data points, converts them to numpy array, and returns the ranks of the data points. - The ranks are normalized to be between 0 and 1. + This method takes as input a list of data points, and returns the ranks of the data points. + The ranks are normalized to be between 0 and 1, when scale is set to True. Args: - data (List[float]): a list of float values + data (Union[List[float], List[int]]): a list of numeric values + scale (bool): whether to scale the ranks to be between 0 and 1. Default = False Returns: encoded_data (torch.Tensor): the encoded data """ - self._check_input_dtype(data) if not isinstance(data, list): data = [data] - try: - data = np.array(data, dtype=float) - # Get ranks (0 is lowest, n-1 is highest) - # and normalize to be between 0 and 1 - ranks = np.argsort(np.argsort(data)) + self._check_input_dtype(data) + + # Get ranks (0 is lowest, n-1 is highest) + # and normalize to be between 0 and 1 + data = np.array(data) + ranks = np.argsort(np.argsort(data)) + if self.scale: ranks = ranks / max(len(ranks) - 1, 1) - return torch.tensor(ranks, dtype=dtype) - except Exception as e: - err_msg = f"Failed to encode data: {e}" - logger.error(err_msg) - raise RuntimeError(err_msg) from e + return torch.tensor(ranks) def decode(self, data: Any) -> Any: """Returns an error since decoding does not make sense without encoder information, which is not yet supported.""" - raise NotImplementedError("Decoding is not yet supported for FloatRank.") + raise NotImplementedError("Decoding is not yet supported for NumericRank.") - def _check_input_dtype(self, data: Union[float, List[int], List[float]]) -> None: + def _check_input_dtype(self, data: list) -> None: """Check if the input data is int or float data. Args: @@ -499,55 +514,7 @@ def _check_input_dtype(self, data: Union[float, List[int], List[float]]) -> None Raises: ValueError: If the input data is not a float """ - if (isinstance(data, list) and not all(isinstance(d, (int, float)) for d in data)) or ( - not isinstance(data, list) and not isinstance(data, (int, float)) - ): - err_msg = f"Expected input data to be a float, got {type(data).__name__}" - logger.error(err_msg) - raise ValueError(err_msg) - - -class IntRankEncoder(FloatRankEncoder): - """Considering an ensemble of integer values, this encoder encodes them into floats from 0 to 1, where 1 is the maximum value and 0 is the minimum value.""" - - def encode(self, data: int, dtype: torch.dtype = torch.int) -> torch.Tensor: - """Returns an error since encoding a single integer does not make sense.""" - raise NotImplementedError("Encoding a single integer does not make sense. Use encode_all instead.") - - def encode_all(self, data: list, dtype: torch.dtype = torch.int) -> torch.Tensor: - """Encodes the data. - This method takes as input a list of data points, should be mappable to a single output, using min-max scaling. - """ - self._check_input_dtype(data) - if not isinstance(data, list): - data = [data] - try: - data = np.array(data, dtype=int) - # Get ranks (0 is lowest, n-1 is highest) - # and normalize to be between 0 and 1 - ranks = np.argsort(np.argsort(data)) - return torch.tensor(ranks, dtype=dtype) - except Exception as e: - err_msg = f"Failed to encode data: {e}" - logger.error(err_msg) - raise RuntimeError(err_msg) from e - - def decode(self, data: Any) -> Any: - """Returns an error since decoding does not make sense without encoder information, which is not yet supported.""" - raise NotImplementedError("Decoding is not yet supported for IntRank.") - - def _check_input_dtype(self, data: Union[int, List[int]]) -> None: - """Check if the input data is int data. - - Args: - data (int): a single data point or a list of data points - - Raises: - ValueError: If the input data is not a int - """ - if (isinstance(data, list) and not all(isinstance(d, int) for d in data)) or ( - not isinstance(data, list) and not isinstance(data, int) - ): - err_msg = f"Expected input data to be a int, got {type(data).__name__}" + if not all(isinstance(d, (int, float)) for d in data): + err_msg = f"Expected input data to be a float or int, got {type(data).__name__}" logger.error(err_msg) raise ValueError(err_msg) diff --git a/tests/data/encoding/test_encoders.py b/tests/data/encoding/test_encoders.py index 582139ab..c2c801f0 100644 --- a/tests/data/encoding/test_encoders.py +++ b/tests/data/encoding/test_encoders.py @@ -4,12 +4,9 @@ import torch from src.stimulus.data.encoding.encoders import ( - FloatEncoder, - FloatRankEncoder, - IntEncoder, - IntRankEncoder, - StrClassificationIntEncoder, - StrClassificationScaledEncoder, + NumericEncoder, + NumericRankEncoder, + StrClassificationEncoder, TextOneHotEncoder, ) @@ -178,14 +175,20 @@ def test_decode_multiple_sequences(self, encoder_default): assert decoded[0] == "acgt-" # '-' for unknown character n -class TestFloatEncoder: - """Test suite for FloatEncoder.""" +class TestNumericEncoder: + """Test suite for NumericEncoder.""" @staticmethod @pytest.fixture def float_encoder(): - """Fixture to instantiate the FloatEncoder.""" - return FloatEncoder() + """Fixture to instantiate the NumericEncoder.""" + return NumericEncoder() + + @staticmethod + @pytest.fixture + def int_encoder(): + """Fixture to instantiate the NumericEncoder with integer dtype.""" + return NumericEncoder(dtype=torch.int32) def test_encode_single_float(self, float_encoder): """Test encoding a single float value.""" @@ -194,14 +197,24 @@ def test_encode_single_float(self, float_encoder): assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor." assert output.dtype == torch.float32, "Tensor dtype should be float32." assert output.numel() == 1, "Tensor should have exactly one element." - # Using pytest.approx for floating-point comparison assert output.item() == pytest.approx(input_val), "Encoded value does not match the input float." + + def test_encode_single_int(self, int_encoder): + """Test encoding a single int value.""" + input_val = 3 + output = int_encoder.encode(input_val) + assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor." + assert output.dtype == torch.int32, "Tensor dtype should be int32." + assert output.numel() == 1, "Tensor should have exactly one element." + assert output.item() == input_val - def test_encode_non_float_raises(self, float_encoder): + @pytest.mark.parametrize("fixture_name", ["float_encoder", "int_encoder"]) + def test_encode_non_numeric_raises(self, request, fixture_name): """Test that encoding a non-float raises a ValueError.""" + numeric_encoder = request.getfixturevalue(fixture_name) with pytest.raises(ValueError) as exc_info: - float_encoder.encode("not_a_float") - assert "Expected input data to be a float, got str" in str(exc_info.value), ( + numeric_encoder.encode("not_numeric") + assert "Expected input data to be a float or int" in str(exc_info.value), ( "Expected ValueError with specific error message." ) @@ -216,61 +229,6 @@ def test_encode_all_single_float(self, float_encoder): assert output.numel() == 1, "Tensor should have exactly one element." assert output.item() == pytest.approx(input_val), "Encoded value does not match the input." - def test_encode_all_multi_float(self, float_encoder): - """Test encode_all with a list of floats.""" - input_vals = [3.14, 4.56] - output = float_encoder.encode_all(input_vals) - assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor." - assert output.dtype == torch.float32, "Tensor dtype should be float32." - assert output.numel() == 2, "Tensor should have exactly one element." - assert output[0].item() == pytest.approx(3.14), "First element does not match." - assert output[1].item() == pytest.approx(4.56), "Second element does not match." - - def test_decode_single_float(self, float_encoder): - """Test decoding a tensor of shape (1).""" - input_tensor = torch.tensor([3.14], dtype=torch.float32) - decoded = float_encoder.decode(input_tensor) - # decode returns data.numpy().tolist() - assert isinstance(decoded, list), "Decoded output should be a list." - assert len(decoded) == 1, "Decoded list should have one element." - assert decoded[0] == pytest.approx(3.14), "Decoded value does not match." - - def test_decode_multi_float(self, float_encoder): - """Test decoding a tensor of shape (n).""" - input_tensor = torch.tensor([3.14, 2.71], dtype=torch.float32) - decoded = float_encoder.decode(input_tensor) - assert isinstance(decoded, list), "Decoded output should be a list." - assert len(decoded) == 2, "Decoded list should have two elements." - assert decoded[0] == pytest.approx(3.14), "First decoded value does not match." - assert decoded[1] == pytest.approx(2.71), "Second decoded value does not match." - - -class TestIntEncoder: - """Test suite for IntEncoder.""" - - @staticmethod - @pytest.fixture - def int_encoder(): - """Fixture to instantiate the IntEncoder.""" - return IntEncoder() - - def test_encode_single_int(self, int_encoder): - """Test encoding a single int value.""" - input_val = 3 - output = int_encoder.encode(input_val) - assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor." - assert output.dtype == torch.int32, "Tensor dtype should be int32." - assert output.numel() == 1, "Tensor should have exactly one element." - assert output.item() == input_val - - def test_encode_non_int_raises(self, int_encoder): - """Test that encoding a non-int raises a RuntimeError.""" - with pytest.raises(ValueError) as exc_info: - int_encoder.encode("not_a_int") - assert "Expected input data to be a int, got str" in str(exc_info.value), ( - "Expected ValueError with specific error message." - ) - def test_encode_all_single_int(self, int_encoder): """Test encode_all when given a single int. It should be treated as a list of one int internally. @@ -282,6 +240,16 @@ def test_encode_all_single_int(self, int_encoder): assert output.numel() == 1, "Tensor should have exactly one element." assert output.item() == input_val + def test_encode_all_multi_float(self, float_encoder): + """Test encode_all with a list of floats.""" + input_vals = [3.14, 4.56] + output = float_encoder.encode_all(input_vals) + assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor." + assert output.dtype == torch.float32, "Tensor dtype should be float32." + assert output.numel() == 2, "Tensor should have exactly one element." + assert output[0].item() == pytest.approx(3.14), "First element does not match." + assert output[1].item() == pytest.approx(4.56), "Second element does not match." + def test_encode_all_multi_int(self, int_encoder): """Test encode_all with a list of integers.""" input_vals = [3, 4] @@ -292,6 +260,15 @@ def test_encode_all_multi_int(self, int_encoder): assert output[0].item() == 3, "First element does not match." assert output[1].item() == 4, "Second element does not match." + def test_decode_single_float(self, float_encoder): + """Test decoding a tensor of shape (1).""" + input_tensor = torch.tensor([3.14], dtype=torch.float32) + decoded = float_encoder.decode(input_tensor) + # decode returns data.numpy().tolist() + assert isinstance(decoded, list), "Decoded output should be a list." + assert len(decoded) == 1, "Decoded list should have one element." + assert decoded[0] == pytest.approx(3.14), "Decoded value does not match." + def test_decode_single_int(self, int_encoder): """Test decoding a tensor of shape (1).""" input_tensor = torch.tensor([3], dtype=torch.int32) @@ -301,6 +278,15 @@ def test_decode_single_int(self, int_encoder): assert len(decoded) == 1, "Decoded list should have one element." assert decoded[0] == 3, "Decoded value does not match." + def test_decode_multi_float(self, float_encoder): + """Test decoding a tensor of shape (n).""" + input_tensor = torch.tensor([3.14, 2.71], dtype=torch.float32) + decoded = float_encoder.decode(input_tensor) + assert isinstance(decoded, list), "Decoded output should be a list." + assert len(decoded) == 2, "Decoded list should have two elements." + assert decoded[0] == pytest.approx(3.14), "First decoded value does not match." + assert decoded[1] == pytest.approx(2.71), "Second decoded value does not match." + def test_decode_multi_int(self, int_encoder): """Test decoding a tensor of shape (n).""" input_tensor = torch.tensor([3, 4], dtype=torch.int32) @@ -311,20 +297,20 @@ def test_decode_multi_int(self, int_encoder): assert decoded[1] == 4, "Second decoded value does not match." -class TestStrClassificationIntEncoder: +class TestStrClassificationEncoder: """Test suite for StrClassificationIntEncoder and StrClassificationScaledEncoder.""" @staticmethod @pytest.fixture def str_encoder(): - """Pytest fixture to instantiate StrClassificationIntEncoder.""" - return StrClassificationIntEncoder() + """Pytest fixture to instantiate StrClassificationEncoder.""" + return StrClassificationEncoder() @staticmethod @pytest.fixture def scaled_encoder(): - """Pytest fixture to instantiate StrClassificationScaledEncoder.""" - return StrClassificationScaledEncoder() + """Pytest fixture to instantiate StrClassificationEncoder with scale set to True""" + return StrClassificationEncoder(scale=True) @pytest.mark.parametrize("fixture", ["str_encoder", "scaled_encoder"]) def test_encode_raises_not_implemented(self, request, fixture): @@ -385,84 +371,66 @@ def test_decode_raises_not_implemented(self, request, fixture): encoder = request.getfixturevalue(fixture) with pytest.raises(NotImplementedError) as exc_info: encoder.decode(torch.tensor([0])) - assert "Decoding is not yet supported." in str(exc_info.value) + assert "Decoding is not yet supported for StrClassification." in str(exc_info.value) -class TestFloatRankEncoder: - """Test suite for FloatRankEncoder.""" +class TestNumericRankEncoder: + """Test suite for NumericRankEncoder.""" @staticmethod @pytest.fixture - def float_rank_encoder(): - """Fixture to instantiate the FloatRankEncoder.""" - return FloatRankEncoder() + def rank_encoder(): + """Fixture to instantiate the NumericRankEncoder.""" + return NumericRankEncoder() + + @staticmethod + @pytest.fixture + def scaled_encoder(): + """Fixture to instantiate the NumericRankEncoder with scale set to True.""" + return NumericRankEncoder(scale=True) - def test_encode_raises_not_implemented(self, float_rank_encoder): + @pytest.mark.parametrize("fixture", ["rank_encoder", "scaled_encoder"]) + def test_encode_raises_not_implemented(self, request, fixture): """Test that encoding a single float raises NotImplementedError.""" + encoder = request.getfixturevalue(fixture) with pytest.raises(NotImplementedError) as exc_info: - float_rank_encoder.encode(3.14) + encoder.encode(3.14) assert "Encoding a single float does not make sense. Use encode_all instead." in str(exc_info.value) - def test_encode_all_with_valid_floats(self, float_rank_encoder): + def test_encode_all_with_valid_rank(self, rank_encoder): """Test encoding a list of float values.""" input_vals = [3.14, 2.71, 1.41] - output = float_rank_encoder.encode_all(input_vals) + output = rank_encoder.encode_all(input_vals) assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor." - assert output.dtype == torch.float32, "Tensor dtype should be float32." assert output.numel() == 3, "Tensor should have exactly three elements." - assert torch.allclose(output, torch.tensor([1.0, 0.5, 0.0])), "Encoded values do not match expected ranks." - - def test_encode_all_with_non_float_raises(self, float_rank_encoder): - """Test that encoding a non-float raises a ValueError.""" - with pytest.raises(ValueError) as exc_info: - float_rank_encoder.encode_all(["not_a_float"]) - assert "Expected input data to be a float" in str(exc_info.value), ( - "Expected ValueError with specific error message." - ) - - def test_decode_raises_not_implemented(self, float_rank_encoder): - """Test that decoding raises NotImplementedError.""" - with pytest.raises(NotImplementedError) as exc_info: - float_rank_encoder.decode(torch.tensor([0.0])) - assert "Decoding is not yet supported for FloatRank." in str(exc_info.value) - + assert output[0] == 2, "First encoded value does not match." + assert output[1] == 1, "Second encoded value does not match." + assert output[2] == 0, "Third encoded value does not match." -class TestIntRankEncoder: - """Test suite for IntRankEncoder.""" - - @staticmethod - @pytest.fixture - def int_rank_encoder(): - """Fixture to instantiate the IntRankEncoder.""" - return IntRankEncoder() - - def test_encode_raises_not_implemented(self, int_rank_encoder): - """Test that encoding a single integer raises NotImplementedError.""" - with pytest.raises(NotImplementedError) as exc_info: - int_rank_encoder.encode(3) - assert "Encoding a single integer does not make sense. Use encode_all instead." in str(exc_info.value) - - def test_encode_all_with_valid_integers(self, int_rank_encoder): - """Test encoding a list of integer values.""" - input_vals = [3, 1, 2] - output = int_rank_encoder.encode_all(input_vals) + def test_encode_all_with_valid_scaled_rank(self, scaled_encoder): + """Test encoding a list of float values.""" + input_vals = [3.14, 2.71, 1.41] + output = scaled_encoder.encode_all(input_vals) assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor." - assert output.dtype == torch.int32, "Tensor dtype should be int32." assert output.numel() == 3, "Tensor should have exactly three elements." - assert torch.allclose(output, torch.tensor([2, 0, 1], dtype=torch.int32)), ( - "Encoded values do not match expected ranks." - ) + assert output[0] == pytest.approx(1), "First encoded value does not match." + assert output[1] == pytest.approx(0.5), "Second encoded value does not match." + assert output[2] == pytest.approx(0), "Third encoded value does not match." - def test_encode_all_with_non_int_raises(self, int_rank_encoder): - """Test that encoding a non-integer raises a ValueError.""" + @pytest.mark.parametrize("fixture", ["rank_encoder", "scaled_encoder"]) + def test_encode_all_with_non_numeric_raises(self, request, fixture): + """Test that encoding a non-float raises a ValueError.""" + encoder = request.getfixturevalue(fixture) with pytest.raises(ValueError) as exc_info: - int_rank_encoder.encode_all(["not_an_int"]) - assert "Expected input data to be a int" in str(exc_info.value), ( + encoder.encode_all(["not_numeric"]) + assert "Expected input data to be a float or int" in str(exc_info.value), ( "Expected ValueError with specific error message." ) - def test_decode_raises_not_implemented(self, int_rank_encoder): + @pytest.mark.parametrize("fixture", ["rank_encoder", "scaled_encoder"]) + def test_decode_raises_not_implemented(self, request, fixture): """Test that decoding raises NotImplementedError.""" + encoder = request.getfixturevalue(fixture) with pytest.raises(NotImplementedError) as exc_info: - int_rank_encoder.decode(torch.tensor([0])) - assert "Decoding is not yet supported for IntRank." in str(exc_info.value) + encoder.decode(torch.tensor([0.0])) + assert "Decoding is not yet supported for NumericRank." in str(exc_info.value)