Skip to content

Commit

Permalink
merge encoders with repeated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
suzannejin committed Jan 17, 2025
1 parent 8340693 commit f404c10
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 258 deletions.
221 changes: 94 additions & 127 deletions src/stimulus/data/encoding/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,43 +297,45 @@ def decode(self, data: torch.Tensor) -> Union[str, List[str]]:
raise ValueError(f"Expected 2D or 3D tensor, got {data.dim()}D")


class FloatEncoder(AbstractEncoder):
"""Encoder for float data."""
class NumericEncoder(AbstractEncoder):
"""Encoder for float/int data."""

def __init__(self, dtype: torch.dtype = torch.float) -> None:
"""Initialize the FloatEncoder class.
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
"""Initialize the NumericEncoder class.
Args:
dtype (torch.dtype): the data type of the encoded data. Default = torch.float (32-bit floating point)
"""
self.dtype = dtype

def encode(self, data: float) -> torch.Tensor:
def encode(self, data: Union[float, int]) -> torch.Tensor:
"""Encodes the data.
This method takes as input a single data point, should be mappable to a single output.
Args:
data (float): a single data point
data (float or int): a single data point
Returns:
encoded_data_point (torch.Tensor): the encoded data point
"""
self._check_input_dtype(data)
return self.encode_all(data) # there is no difference in this case

def encode_all(self, data: Union[float, List[float]]) -> torch.Tensor:
def encode_all(self, data: Union[float, int, List[float], List[int]]) -> torch.Tensor:
"""Encodes the data.
This method takes as input a list of data points, or a single float or int, and returns a torch.Tensor.
Args:
data (Union[float, List[float]]): a list of data points or a single data point
data (Union[float, int, List[float], List[int]]): a list of data points or a single data point
Returns:
encoded_data (torch.Tensor): the encoded data
"""
self._check_input_dtype(data)
if not isinstance(data, list):
data = [data]

self._check_input_dtype(data)
self._warn_float_is_converted_to_int(data)

return torch.tensor(data, dtype=self.dtype)

def decode(self, data: torch.Tensor) -> List[float]:
Expand All @@ -347,56 +349,64 @@ def decode(self, data: torch.Tensor) -> List[float]:
"""
return data.cpu().numpy().tolist()

def _check_input_dtype(self, data: Union[float, List[int], List[float]]) -> None:
def _check_input_dtype(self, data: Union[List[float], List[int]]) -> None:
"""Check if the input data is int or float data.
Args:
data (int or float): a single data point or a list of data points
data (List[float] or List[int]): a list of float or integer data points
Raises:
ValueError: If the input data is not a float
ValueError: If the input data contains a data point that is neither an int nor a float
"""
if (isinstance(data, list) and not all(isinstance(d, (int, float)) for d in data)) or (
not isinstance(data, list) and not isinstance(data, (int, float))
):
err_msg = f"Expected input data to be a float, got {type(data).__name__}"
if not all(isinstance(d, (int, float)) for d in data):
err_msg = f"Expected input data to be a float or int"
logger.error(err_msg)
raise ValueError(err_msg)


class IntEncoder(FloatEncoder):
"""Encoder for integer data."""

def __init__(self, dtype: torch.dtype = torch.int) -> None:
"""Initialize the IntEncoder class.
def _warn_float_is_converted_to_int(self, data: Union[List[float], List[int]]) -> None:
"""Warn if float data is encoded into int data.
Args:
dtype (torch.dtype): the data type of the encoded data. Default = torch.int (32-bit integer)
data (List[float] or List[int]): a list of float or integer data points
"""
super().__init__(dtype)
if any(isinstance(d, float) for d in data) and (self.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64]):
logger.warning("Encoding float data to torch.int data type.")

def _check_input_dtype(self, data: Union[int, List[int]]) -> None:
"""Check if the input data is int data.

Args:
data (int): a single data point or a list of data points
class StrClassificationEncoder(AbstractEncoder):
"""
A string classification encoder that converts lists of strings into numeric labels using scikit-learn's
LabelEncoder. When scale is set to True, the labels are scaled to be between 0 and 1.
Raises:
ValueError: If the input data is not a int
"""
if (isinstance(data, list) and not all(isinstance(d, int) for d in data)) or (
not isinstance(data, list) and not isinstance(data, int)
):
err_msg = f"Expected input data to be a int, got {type(data).__name__}"
logger.error(err_msg)
raise ValueError(err_msg)
Attributes:
None
Methods:
encode(data: str) -> int:
Raises a NotImplementedError, as encoding a single string is not meaningful in this context.
encode_all(data: List[str]) -> torch.tensor:
Encodes an entire list of string data into a numeric representation using LabelEncoder and
returns a torch tensor. Ensures that the provided data items are valid strings prior to encoding.
decode(data: Any) -> Any:
Raises a NotImplementedError, as decoding is not supported with the current design.
_check_dtype(data: List[str]) -> None:
Validates that all items in the data list are strings, raising a ValueError otherwise.
"""

def __init__(self, scale: bool = False) -> None:
"""Initialize the StrClassificationEncoder class.
class StrClassificationIntEncoder(AbstractEncoder):
"""Considering a ensemble of strings, this encoder encodes them into integers from 0 to (n-1) where n is the number of unique strings."""
Args:
scale (bool): whether to scale the labels to be between 0 and 1. Default = False
"""
self.scale = scale

def encode(self, data: str) -> int:
"""Returns an error since encoding a single string does not make sense."""
"""Raises NotImplementedError, since encoding a single string does not make sense.
Args:
data (str): a single string
"""
raise NotImplementedError("Encoding a single string does not make sense. Use encode_all instead.")

def encode_all(self, data: List[str]) -> torch.tensor:
Expand All @@ -412,13 +422,19 @@ def encode_all(self, data: List[str]) -> torch.tensor:
"""
if not isinstance(data, list):
data = [data]

self._check_dtype(data)

encoder = preprocessing.LabelEncoder()
return torch.tensor(encoder.fit_transform(data))
encoded_data = torch.tensor(encoder.fit_transform(data))
if self.scale:
encoded_data = encoded_data / max(len(encoded_data) - 1, 1)

return encoded_data

def decode(self, data: Any) -> Any:
"""Raises NotImplementedError, since decoding does not make sense without encoder information, which is not yet supported."""
raise NotImplementedError("Decoding is not yet supported.")
raise NotImplementedError("Decoding is not yet supported for StrClassification.")

def _check_dtype(self, data: List[str]) -> None:
"""Check if the input data is string data.
Expand All @@ -435,62 +451,61 @@ def _check_dtype(self, data: List[str]) -> None:
raise ValueError(err_msg)


class StrClassificationScaledEncoder(StrClassificationIntEncoder):
"""Considering a ensemble of strings, this encoder encodes them into floats from 0 to 1 (essentially scaling the integer encoding)."""
class NumericRankEncoder(AbstractEncoder):
"""Encoder for float/int data that encodes the data based on their rank.
def encode_all(self, data: List[str]) -> torch.Tensor:
"""Encodes the data.
This method takes as input a list of data points, should be mappable to a single output, using LabelEncoder from scikit learn and returning a numpy array.
For more info visit : https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html
Args:
data (List[str]): a list of strings
Attributes:
scale (bool): whether to scale the ranks to be between 0 and 1. Default = False
Returns:
encoded_data (torch.Tensor): the encoded data
"""
encoded_data = super().encode_all(data)
return encoded_data / max(len(encoded_data) - 1, 1)
Methods:
encode: not supported for a single data point; raises NotImplementedError
encode_all: encodes a list of data points into a torch.tensor
decode: not supported; raises NotImplementedError
_check_input_dtype: checks if the input data is int or float data
_warn_float_is_converted_to_int: warns if float data is encoded into int data
"""

def __init__(self, scale: bool = False) -> None:
"""Initialize the NumericRankEncoder class.
class FloatRankEncoder(AbstractEncoder):
"""Considering an ensemble of float values, this encoder encodes them into floats from 0 to 1, where 1 is the maximum value and 0 is the minimum value."""
Args:
scale (bool): whether to scale the ranks to be between 0 and 1. Default = False
"""
self.scale = scale

def encode(self, data: float, dtype: torch.dtype = torch.float) -> torch.Tensor:
def encode(self, data: Any) -> torch.Tensor:
"""Raises NotImplementedError, since encoding a single numeric value does not make sense."""
raise NotImplementedError("Encoding a single float does not make sense. Use encode_all instead.")

def encode_all(self, data: List[float], dtype: torch.dtype = torch.float) -> torch.Tensor:
def encode_all(self, data: Union[List[float], List[int]]) -> torch.Tensor:
"""Encodes the data.
This method takes as input a list of data points, converts them to numpy array, and returns the ranks of the data points.
The ranks are normalized to be between 0 and 1.
This method takes as input a list of data points, and returns the ranks of the data points.
The ranks are normalized to be between 0 and 1, when scale is set to True.
Args:
data (List[float]): a list of float values
data (Union[List[float], List[int]]): a list of numeric values
Note: scaling to the range 0 to 1 is controlled by the ``scale`` attribute set in the constructor (default False); it is not an argument of this method.
Returns:
encoded_data (torch.Tensor): the encoded data
"""
self._check_input_dtype(data)
if not isinstance(data, list):
data = [data]
try:
data = np.array(data, dtype=float)
# Get ranks (0 is lowest, n-1 is highest)
# and normalize to be between 0 and 1
ranks = np.argsort(np.argsort(data))
self._check_input_dtype(data)

# Get ranks (0 is lowest, n-1 is highest)
# and, only when self.scale is set, normalize them to be between 0 and 1
data = np.array(data)
ranks = np.argsort(np.argsort(data))
if self.scale:
ranks = ranks / max(len(ranks) - 1, 1)
return torch.tensor(ranks, dtype=dtype)
except Exception as e:
err_msg = f"Failed to encode data: {e}"
logger.error(err_msg)
raise RuntimeError(err_msg) from e
return torch.tensor(ranks)

def decode(self, data: Any) -> Any:
"""Raises NotImplementedError, since decoding does not make sense without encoder information, which is not yet supported."""
raise NotImplementedError("Decoding is not yet supported for FloatRank.")
raise NotImplementedError("Decoding is not yet supported for NumericRank.")

def _check_input_dtype(self, data: Union[float, List[int], List[float]]) -> None:
def _check_input_dtype(self, data: list) -> None:
"""Check if the input data is int or float data.
Args:
Expand All @@ -499,55 +514,7 @@ def _check_input_dtype(self, data: Union[float, List[int], List[float]]) -> None
Raises:
ValueError: If the input data contains a value that is neither an int nor a float
"""
if (isinstance(data, list) and not all(isinstance(d, (int, float)) for d in data)) or (
not isinstance(data, list) and not isinstance(data, (int, float))
):
err_msg = f"Expected input data to be a float, got {type(data).__name__}"
logger.error(err_msg)
raise ValueError(err_msg)


class IntRankEncoder(FloatRankEncoder):
"""Considering an ensemble of integer values, this encoder encodes them into floats from 0 to 1, where 1 is the maximum value and 0 is the minimum value."""

def encode(self, data: int, dtype: torch.dtype = torch.int) -> torch.Tensor:
"""Returns an error since encoding a single integer does not make sense."""
raise NotImplementedError("Encoding a single integer does not make sense. Use encode_all instead.")

def encode_all(self, data: list, dtype: torch.dtype = torch.int) -> torch.Tensor:
"""Encodes the data.
This method takes as input a list of data points, should be mappable to a single output, using min-max scaling.
"""
self._check_input_dtype(data)
if not isinstance(data, list):
data = [data]
try:
data = np.array(data, dtype=int)
# Get ranks (0 is lowest, n-1 is highest)
# and normalize to be between 0 and 1
ranks = np.argsort(np.argsort(data))
return torch.tensor(ranks, dtype=dtype)
except Exception as e:
err_msg = f"Failed to encode data: {e}"
logger.error(err_msg)
raise RuntimeError(err_msg) from e

def decode(self, data: Any) -> Any:
"""Returns an error since decoding does not make sense without encoder information, which is not yet supported."""
raise NotImplementedError("Decoding is not yet supported for IntRank.")

def _check_input_dtype(self, data: Union[int, List[int]]) -> None:
"""Check if the input data is int data.
Args:
data (int): a single data point or a list of data points
Raises:
ValueError: If the input data is not a int
"""
if (isinstance(data, list) and not all(isinstance(d, int) for d in data)) or (
not isinstance(data, list) and not isinstance(data, int)
):
err_msg = f"Expected input data to be a int, got {type(data).__name__}"
if not all(isinstance(d, (int, float)) for d in data):
err_msg = f"Expected input data to be a float or int, got {type(data).__name__}"
logger.error(err_msg)
raise ValueError(err_msg)
Loading

0 comments on commit f404c10

Please sign in to comment.