Skip to content

Commit

Permalink
refactor(data): made data_handlers more explicit and added extra function to parse various configs.
Browse files Browse the repository at this point in the history
  • Loading branch information
mathysgrapotte committed Feb 20, 2025
1 parent 3c4e9b4 commit bdac5e4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
4 changes: 0 additions & 4 deletions src/stimulus/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,6 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
"""Load the part of csv file that has the specified split value.
The split value is an integer: 0 is train, 1 is validation, 2 is test.
This is accessed through the column with category `split`. Example column name could be `split:split:int`.
NOTE: the aim of this function is to load only the data relevant to the current scenario
(training, validation or test).
"""
if "split" not in self.columns:
raise ValueError("The category split is not present in the csv file")
Expand Down
48 changes: 29 additions & 19 deletions src/stimulus/data/interface/data_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,27 @@ def create_encoders(column_config: list[Columns]) -> dict[str, encoders_module.A
}


def create_transforms(transform_config: list[TransformColumns]) -> dict[str, list[Any]]:
    """Factory for creating transforms from config.

    Args:
        transform_config: Either a list of column entries (as the annotation
            states) or an object exposing that list through a ``columns``
            attribute. Each column entry must provide ``column_name`` and
            ``transformations``; each transformation must provide ``name``
            and ``params``.

    Returns:
        Dictionary mapping each column name to the list of transform objects
        instantiated for that column, in configuration order.
    """
    # A plain list has no `columns` attribute, so reading it unconditionally
    # raises AttributeError for the input type promised by the annotation.
    # Fall back to the argument itself so both call shapes work.
    columns = getattr(transform_config, "columns", transform_config)
    return {
        column.column_name: [
            _instantiate_component(
                module=transforms_module,
                name=transformation.name,
                params=transformation.params,
            )
            for transformation in column.transformations
        ]
        for column in columns
    }
def create_transforms(transform_config: list[Transform]) -> dict[str, list[Any]]:
    """Build the per-column transform objects described by the config.

    Args:
        transform_config: List of Transform objects from the YAML config.
            Every entry's ``columns`` are visited in order.

    Returns:
        Dictionary mapping column names to lists of instantiated transform
        objects. When the same column name occurs more than once, the entry
        built last replaces the earlier one.
    """

    def _build(spec: Any) -> Any:
        # One transform object per transformation entry of a column.
        return _instantiate_component(
            module=transforms_module,
            name=spec.name,
            params=spec.params,
        )

    return {
        col.column_name: [_build(spec) for spec in col.transformations]
        for group in transform_config
        for col in group.columns
    }


def create_splitter(split_config: Split) -> splitters_module.AbstractSplitter:
Expand All @@ -83,12 +91,13 @@ def create_splitter(split_config: Split) -> splitters_module.AbstractSplitter:
)


def parse_data_config(
def parse_split_transform_config(
config: SplitTransformDict,
) -> tuple[
dict[str, encoders_module.AbstractEncoder],
dict[str, list[transforms_module.AbstractTransform]],
splitters_module.AbstractSplitter,
list[str],
list[str],
list[str],
]:
"""Parse the configuration and return a dictionary of the parsed configuration.
Expand All @@ -99,10 +108,11 @@ def parse_data_config(
A tuple of the parsed configuration.
"""
encoders = create_encoders(config.columns)
transforms = create_transforms(config.transforms)
splitter = create_splitter(config.split)
input_columns = [column.column_name for column in config.columns if column.column_type == "input"]
label_columns = [column.column_name for column in config.columns if column.column_type == "label"]
meta_columns = [column.column_name for column in config.columns if column.column_type == "meta"]

return encoders, transforms, splitter
return encoders, input_columns, label_columns, meta_columns


def extract_transform_parameters_at_index(
Expand Down

0 comments on commit bdac5e4

Please sign in to comment.