Skip to content

Commit

Permalink
fix: fixed type checker.
Browse files Browse the repository at this point in the history
  • Loading branch information
mathysgrapotte committed Feb 19, 2025
1 parent dfb7fca commit 48e5ef8
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/stimulus/learner/raytune_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def setup(self, config: dict[Any, Any]) -> None:
seed_filename = os.path.join(debug_dir, "seeds.txt")

# save the initialized model weights
self.export_model(export_dir=debug_dir)
self.export_model(export_formats=[], export_dir=debug_dir)

# save the seeds
with open(seed_filename, "a") as seed_f:
Expand Down Expand Up @@ -318,11 +318,11 @@ def objective(self) -> dict[str, float]:
**{"train_" + metric: value for metric, value in predict_train.compute_metrics(metrics).items()},
}

# type: ignore[override]
def export_model(self, export_dir: str | None = None) -> None:
def export_model(self, export_formats: list[str] | str, export_dir: str | None = None) -> Any:
"""Export model to safetensors format."""
if export_dir is None:
return
_ = export_formats # Unused parameter required for compatibility with Trainable
safe_save_model(self.model, os.path.join(export_dir, "model.safetensors"))

def load_checkpoint(self, checkpoint: dict[Any, Any] | None) -> None:
Expand Down
7 changes: 1 addition & 6 deletions src/stimulus/utils/yaml_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,7 @@ def write_line_break(self, _data: Any = None) -> None:
if len(self.indents) <= 1: # At root level
super().write_line_break(_data)

def increase_indent(
self,
*,
flow: bool = False,
indentless: bool = False,
) -> None: # type: ignore[override]
def increase_indent(self, flow: bool = False, indentless: bool = False) -> None: # noqa: FBT001, FBT002
"""Ensure consistent indentation by preventing indentless sequences."""
return super().increase_indent(
flow=flow,
Expand Down
25 changes: 8 additions & 17 deletions tests/cli/test_shuffle_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
from typing import Any, Callable

import pytest
import yaml

from src.stimulus.cli.shuffle_csv import main
from src.stimulus.utils.yaml_data import YamlSplitTransformDict


# Fixtures
Expand Down Expand Up @@ -44,18 +42,11 @@ def test_shuffle_csv(
csv_path = request.getfixturevalue(csv_type)
yaml_path = request.getfixturevalue(yaml_type)
tmpdir = pathlib.Path(tempfile.gettempdir())
with open(yaml_path) as f:
if error:
config_dict: YamlSplitTransformDict = YamlSplitTransformDict(
**yaml.safe_load(f),
)
with pytest.raises(error): # type: ignore[call-overload]
main(csv_path, config_dict, str(tmpdir / "test.csv"))
else:
config_dict: YamlSplitTransformDict = YamlSplitTransformDict(
**yaml.safe_load(f),
)
main(csv_path, config_dict, str(tmpdir / "test.csv"))
with open(tmpdir / "test.csv") as file:
hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324
assert hash == snapshot
if error:
with pytest.raises(error): # type: ignore[call-overload]
main(csv_path, yaml_path, str(tmpdir / "test.csv"))
else:
main(csv_path, yaml_path, str(tmpdir / "test.csv"))
with open(tmpdir / "test.csv") as file:
hash = hashlib.md5(file.read().encode()).hexdigest() # noqa: S324
assert hash == snapshot
8 changes: 4 additions & 4 deletions tests/cli/test_split_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def test_split_split(
tmpdir = tmp_path
if error:
with pytest.raises(error): # type: ignore[call-overload]
split_split.main(yaml_path, tmpdir)
split_split.main(yaml_path, str(tmpdir))
else:
split_split.main(yaml_path, tmpdir) # main() returns None, no need to assert
files = os.listdir(tmpdir)
split_split.main(yaml_path, str(tmpdir)) # main() returns None, no need to assert
files = os.listdir(str(tmpdir))
test_out = [f for f in files if f.startswith("test_")]
hashes = []
for f in test_out:
with open(os.path.join(tmpdir, f)) as file:
with open(os.path.join(str(tmpdir), f)) as file:
hashes.append(hashlib.md5(file.read().encode()).hexdigest()) # noqa: S324
assert sorted(hashes) == snapshot # sorted ensures that the order of the hashes does not matter
10 changes: 5 additions & 5 deletions tests/cli/test_split_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def test_split_transforms(
yaml_path: str = request.getfixturevalue(yaml_type)
tmpdir = tmp_path
if error:
with pytest.raises(error):
main(yaml_path, tmpdir)
with pytest.raises(type(error)):
main(yaml_path, str(tmpdir))
else:
main(yaml_path, tmpdir)
files = os.listdir(tmpdir)
main(yaml_path, str(tmpdir))
files = os.listdir(str(tmpdir))
test_out = [f for f in files if f.startswith("test_")]
hashes = []
for f in test_out:
with open(os.path.join(tmpdir, f)) as file:
with open(os.path.join(str(tmpdir), f)) as file:
hashes.append(hashlib.sha256(file.read().encode()).hexdigest())
assert sorted(hashes) == snapshot # Sorted ensures that the order of the hashes does not matter
2 changes: 1 addition & 1 deletion tests/utils/test_data_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_split_config_validation(load_titanic_yaml_from_file: YamlConfigDict) ->


def test_sub_config_validation(
load_split_config_yaml_from_file: YamlConfigDict,
load_split_config_yaml_from_file: YamlSplitConfigDict,
) -> None:
"""Test sub-config validation."""
split_config = generate_split_transform_configs(
Expand Down

0 comments on commit 48e5ef8

Please sign in to comment.