fix: fix linting
mathysgrapotte committed Feb 19, 2025
1 parent 20135c3 commit dfb7fca
Showing 20 changed files with 205 additions and 270 deletions.
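The diffs below are purely mechanical formatting fixes of two kinds: a trailing comma is added to the last argument of calls that stay split across several lines, and calls that were wrapped only for line length are collapsed back onto a single line. These are the conventions an autoformatter such as black or ruff-format enforces (the exact tool is not shown on this page). A minimal before/after sketch of both patterns, using a hypothetical helper that is not part of the repository:

def build_loader(*, model=None, config=None):
    """Hypothetical helper used only to illustrate the formatting change."""
    return (model, config)

# Before the lint fix: a multi-line call without a trailing comma, and a
# short call wrapped across two lines.
loader_a = build_loader(
    model="cfg"
)
loader_b = build_loader(
    config="cols")

# After the lint fix: a trailing comma on the call that stays multi-line,
# and the short call collapsed onto one line.
loader_a = build_loader(
    model="cfg",
)
loader_b = build_loader(config="cols")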
5 changes: 2 additions & 3 deletions src/stimulus/cli/check_model.py
@@ -129,7 +129,7 @@ def main(
 
     encoder_loader = loaders.EncoderLoader()
     encoder_loader.initialize_column_encoders_from_config(
-        column_config=data_config.columns
+        column_config=data_config.columns,
     )
 
     logger.info("Dataset loaded successfully.")
@@ -138,8 +138,7 @@ def main(
 
     logger.info("Model class loaded successfully.")
 
-    ray_config_loader = yaml_model_schema.YamlRayConfigLoader(
-        model=model_config)
+    ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
     ray_config_dict = ray_config_loader.get_config().model_dump()
     ray_config_model = ray_config_loader.get_config()
 
6 changes: 5 additions & 1 deletion src/stimulus/cli/split_csv.py
@@ -50,7 +50,11 @@ def get_args() -> argparse.Namespace:
 
 
 def main(
-    data_csv: str, config_yaml: str, out_path: str, *, force: bool = False
+    data_csv: str,
+    config_yaml: str,
+    out_path: str,
+    *,
+    force: bool = False,
 ) -> None:
     """Connect CSV and YAML configuration and handle sanity checks.
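In the reformatted main signature above, the bare * keeps force keyword-only, so callers must spell out force=True explicitly. A minimal sketch of the same pattern with an illustrative function that is not part of the repository:

def split_file(data_csv: str, *, force: bool = False) -> str:
    """Illustrative only: 'force' can only be passed by keyword."""
    return f"{data_csv} (force={force})"

print(split_file("data.csv", force=True))  # OK: passed by keyword
# split_file("data.csv", True)             # TypeError: takes 1 positional argument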
4 changes: 1 addition & 3 deletions src/stimulus/cli/split_transforms.py
File mode changed: 100644 → 100755 (the script was made executable).
@@ -63,9 +63,7 @@ def main(config_yaml: str, out_dir_path: str) -> None:
     yaml_config_dict: YamlSplitConfigDict = YamlSplitConfigDict(**yaml_config)
 
     # Generate the yaml files for each transform
-    split_transform_configs: list[YamlSplitTransformDict] = (
-        generate_split_transform_configs(yaml_config_dict)
-    )
+    split_transform_configs: list[YamlSplitTransformDict] = generate_split_transform_configs(yaml_config_dict)
 
     # Dump all the YAML configs into files
     dump_yaml_list_into_files(split_transform_configs, out_dir_path, "test_transforms")
3 changes: 1 addition & 2 deletions src/stimulus/cli/transform_csv.py
@@ -13,7 +13,7 @@
 def get_args() -> argparse.Namespace:
     """Get the arguments when using from the commandline."""
     parser = argparse.ArgumentParser(
-        description="CLI for transforming CSV data files using YAML configuration."
+        description="CLI for transforming CSV data files using YAML configuration.",
     )
     parser.add_argument(
         "-c",
@@ -57,7 +57,6 @@ def main(data_csv: str, config_yaml: str, out_path: str) -> None:
     with open(config_yaml) as f:
         yaml_config = YamlSplitConfigDict(**yaml.safe_load(f))
     transform_loader = TransformLoader(seed=yaml_config.global_params.seed)
-    print(transform_config)
     transform_loader.initialize_column_data_transformers_from_config(transform_config)
     transform_manager = TransformManager(transform_loader)
 
8 changes: 3 additions & 5 deletions src/stimulus/cli/tuning.py
@@ -179,18 +179,16 @@ def main(
     """
     with open(model_config_path) as file:
         model_config_dict: dict[str, Any] = yaml.safe_load(file)
-    model_config: yaml_model_schema.Model = yaml_model_schema.Model(
-        **model_config_dict)
+    model_config: yaml_model_schema.Model = yaml_model_schema.Model(**model_config_dict)
 
     encoder_loader = loaders.EncoderLoader()
     encoder_loader.initialize_column_encoders_from_config(
-        column_config=data_config.columns
+        column_config=data_config.columns,
     )
 
     model_class = launch_utils.import_class_from_file(model_path)
 
-    ray_config_loader = yaml_model_schema.YamlRayConfigLoader(
-        model=model_config)
+    ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
     ray_config_model = ray_config_loader.get_config()
 
     tuner = raytune_learner.TuneWrapper(
74 changes: 29 additions & 45 deletions src/stimulus/data/data_handlers.py
@@ -110,7 +110,6 @@ def _load_config(self, config_path: str) -> yaml_data.YamlSplitConfigDict:
         >>> print(config["columns"][0]["column_name"])
         'hello'
         """
-
         with open(config_path) as file:
             # FIXME: this function is called for test_shuffle_csv and test_tune
             return yaml_data.YamlSplitConfigDict(**yaml.safe_load(file))
@@ -190,8 +189,7 @@ def encode_column(self, column_name: str, column_data: list) -> torch.Tensor:
         >>> print(encoded.shape)
         torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded
         """
-        encode_all_function = self.encoder_loader.get_function_encode_all(
-            column_name)
+        encode_all_function = self.encoder_loader.get_function_encode_all(column_name)
         return encode_all_function(column_data)
 
     def encode_columns(self, column_data: dict) -> dict:
@@ -213,16 +211,11 @@ def encode_columns(self, column_data: dict) -> dict:
         >>> print(encoded["dna_seq"].shape)
         torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded
         """
-        return {
-            col: self.encode_column(col, values) for col, values in column_data.items()
-        }
+        return {col: self.encode_column(col, values) for col, values in column_data.items()}
 
     def encode_dataframe(self, dataframe: pl.DataFrame) -> dict[str, torch.Tensor]:
         """Encode the dataframe using the encoders."""
-        return {
-            col: self.encode_column(col, dataframe[col].to_list())
-            for col in dataframe.columns
-        }
+        return {col: self.encode_column(col, dataframe[col].to_list()) for col in dataframe.columns}
 
 
 class TransformManager:
@@ -236,7 +229,10 @@ def __init__(
         self.transform_loader = transform_loader
 
     def transform_column(
-        self, column_name: str, transform_name: str, column_data: list
+        self,
+        column_name: str,
+        transform_name: str,
+        column_data: list,
     ) -> tuple[list, bool]:
         """Transform a column of data using the specified transformation.
@@ -249,9 +245,7 @@ def transform_column(
             list: The transformed data.
             bool: Whether the transformation added new rows to the data.
         """
-        transformer = self.transform_loader.__getattribute__(column_name)[
-            transform_name
-        ]
+        transformer = self.transform_loader.__getattribute__(column_name)[transform_name]
         return transformer.transform_all(column_data), transformer.add_row


@@ -266,7 +260,8 @@ def __init__(
         self.split_loader = split_loader
 
     def get_split_indices(
-        self, data: dict
+        self,
+        data: dict,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Get the indices for train, validation, and test splits."""
         return self.split_loader.get_function_split()(data)
@@ -370,8 +365,7 @@ def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None:
         split_input_data = self.select_columns(split_columns)
 
         # get the split indices
-        train, validation, test = split_manager.get_split_indices(
-            split_input_data)
+        train, validation, test = split_manager.get_split_indices(split_input_data)
 
         # add the split column to the data
         split_column = np.full(len(self.data), -1).astype(int)
@@ -397,12 +391,12 @@ def apply_transformation_group(self, transform_manager: TransformManager) -> None:
             )
             if add_row:
                 new_rows = self.data.with_columns(
-                    pl.Series(column_name, transformed_data)
+                    pl.Series(column_name, transformed_data),
                 )
                 self.data = pl.vstack(self.data, new_rows)
             else:
                 self.data = self.data.with_columns(
-                    pl.Series(column_name, transformed_data)
+                    pl.Series(column_name, transformed_data),
                 )
 
     def shuffle_labels(self, seed: Optional[float] = None) -> None:
@@ -413,7 +407,7 @@ def shuffle_labels(self, seed: Optional[float] = None) -> None:
         label_keys = self.dataset_manager.column_categories["label"]
         for key in label_keys:
             self.data = self.data.with_columns(
-                pl.Series(key, np.random.permutation(list(self.data[key])))
+                pl.Series(key, np.random.permutation(list(self.data[key]))),
             )


@@ -430,11 +424,7 @@ def __init__(
         """Initialize the DatasetLoader."""
         super().__init__(data_config, csv_path)
         self.encoder_manager = EncodeManager(encoder_loader)
-        self.data = (
-            self.load_csv_per_split(csv_path, split)
-            if split is not None
-            else self.load_csv(csv_path)
-        )
+        self.data = self.load_csv_per_split(csv_path, split) if split is not None else self.load_csv(csv_path)
 
     def get_all_items(self) -> tuple[dict, dict, dict]:
         """Get the full dataset as three separate dictionaries for inputs, labels and metadata.
@@ -460,10 +450,8 @@ def get_all_items(self) -> tuple[dict, dict, dict]:
             self.dataset_manager.column_categories["label"],
             self.dataset_manager.column_categories["meta"],
         )
-        input_data = self.encoder_manager.encode_dataframe(
-            self.data[input_columns])
-        label_data = self.encoder_manager.encode_dataframe(
-            self.data[label_columns])
+        input_data = self.encoder_manager.encode_dataframe(self.data[input_columns])
+        label_data = self.encoder_manager.encode_dataframe(self.data[label_columns])
         meta_data = {key: self.data[key].to_list() for key in meta_columns}
         return input_data, label_data, meta_data

@@ -481,11 +469,10 @@ def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
             we are gonna load only the relevant data for it.
         """
         if "split" not in self.columns:
-            raise ValueError(
-                "The category split is not present in the csv file")
+            raise ValueError("The category split is not present in the csv file")
         if split not in [0, 1, 2]:
             raise ValueError(
-                f"The split value should be 0, 1 or 2. The specified split value is {split}"
+                f"The split value should be 0, 1 or 2. The specified split value is {split}",
             )
         return pl.scan_csv(csv_path).filter(pl.col("split") == split).collect()

@@ -494,7 +481,8 @@ def __len__(self) -> int:
         return len(self.data)
 
     def __getitem__(
-        self, idx: Any
+        self,
+        idx: Any,
     ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]:
         """Get the data at a given index, and encodes the input and label, leaving meta as it is.
@@ -514,23 +502,20 @@
 
             # Process DataFrame
             input_data = self.encoder_manager.encode_dataframe(
-                data_at_index[input_columns]
+                data_at_index[input_columns],
             )
             label_data = self.encoder_manager.encode_dataframe(
-                data_at_index[label_columns]
+                data_at_index[label_columns],
             )
-            meta_data = {key: data_at_index[key].to_list()
-                         for key in meta_columns}
+            meta_data = {key: data_at_index[key].to_list() for key in meta_columns}
 
         elif isinstance(idx, int):
             # For single row, convert to dict with column names as keys
             row_dict = dict(zip(self.data.columns, self.data.row(idx)))
 
             # Create single-row DataFrames for encoding
-            input_df = pl.DataFrame(
-                {col: [row_dict[col]] for col in input_columns})
-            label_df = pl.DataFrame(
-                {col: [row_dict[col]] for col in label_columns})
+            input_df = pl.DataFrame({col: [row_dict[col]] for col in input_columns})
+            label_df = pl.DataFrame({col: [row_dict[col]] for col in label_columns})
 
             input_data = self.encoder_manager.encode_dataframe(input_df)
             label_data = self.encoder_manager.encode_dataframe(label_df)
@@ -541,12 +526,11 @@
 
             # Process DataFrame
             input_data = self.encoder_manager.encode_dataframe(
-                data_at_index[input_columns]
+                data_at_index[input_columns],
             )
             label_data = self.encoder_manager.encode_dataframe(
-                data_at_index[label_columns]
+                data_at_index[label_columns],
             )
-            meta_data = {key: data_at_index[key].to_list()
-                         for key in meta_columns}
+            meta_data = {key: data_at_index[key].to_list() for key in meta_columns}
 
         return input_data, label_data, meta_data
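The load_csv_per_split context above uses polars' lazy scanning so that only the rows whose split column matches the requested value are materialized. A minimal standalone sketch of that pattern, assuming a small throwaway CSV (the file name and column values are invented for illustration):

import polars as pl

# Write a tiny CSV with a 'split' column (0 = train, 1 = validation, 2 = test).
csv_path = "example_split.csv"
pl.DataFrame({"sequence": ["ACGT", "TTGA", "CCAA"], "split": [0, 1, 0]}).write_csv(csv_path)

# Lazily scan the file and collect only the split-0 rows; the predicate is
# pushed into the scan, so the other rows are never fully materialized.
train_rows = pl.scan_csv(csv_path).filter(pl.col("split") == 0).collect()
print(train_rows)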
(Diffs for the remaining changed files were not loaded on this page.)
