Convert validate_actions into a post validator
ghickman committed May 20, 2022
1 parent f6126f1 commit 7392e4b
Showing 4 changed files with 65 additions and 46 deletions.
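The change here swaps a @root_validator(pre=True) for a plain @root_validator(): a pre validator sees the raw input dict before any field parsing, while a post validator runs afterwards and sees parsed model instances, so attribute access like config.run.raw works. A minimal sketch of that difference, assuming pydantic v1; Parent and Child are illustrative names, not this repo's models:

from typing import Any, Dict

from pydantic import BaseModel, root_validator


class Child(BaseModel):
    run: str


class Parent(BaseModel):
    actions: Dict[str, Child]

    @root_validator(pre=True)
    def check_raw(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Runs before field parsing: each action is still the plain dict the
        # caller passed in, so string checks such as config["run"] == "" have
        # to live here.
        for config in values.get("actions", {}).values():
            assert isinstance(config, dict)
        return values

    @root_validator()
    def check_parsed(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Runs after field parsing: each action is now a Child instance, so
        # attribute access works; values.get() mirrors the defensive lookup
        # used in the diff below.
        for config in values.get("actions", {}).values():
            assert isinstance(config, Child)
        return values


Parent(actions={"a": {"run": "cohortextractor:latest generate_cohort"}})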
46 changes: 29 additions & 17 deletions pipeline/models.py
@@ -34,6 +34,9 @@ class Outputs(BaseModel):
moderately_sensitive: Optional[Dict[str, str]]
minimally_sensitive: Optional[Dict[str, str]]

def __len__(self) -> int:
return len(self.dict(exclude_unset=True))

@root_validator()
def at_least_one_output(cls, outputs: Dict[str, str]) -> Dict[str, str]:
if not any(outputs.values()):
@@ -138,23 +141,20 @@ def all_actions(self) -> List[str]:
"""
return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND]

@root_validator(pre=True)
def validate_actions(cls, values: RawPipeline) -> RawPipeline:
for action_id, config in values["actions"].items():
if config["run"] == "":
# key is present but empty
raise ValueError(
f"run must have a value, {action_id} has an empty run key"
)

validators = {
"cohortextractor:latest generate_cohort": validate_cohortextractor_outputs,
"cohortextractor-v2:latest generate_cohort": validate_databuilder_outputs,
"databuilder:latest generate_dataset": validate_databuilder_outputs,
}
for cmd, validator_func in validators.items():
if config["run"].startswith(cmd):
validator_func(action_id, config)
@root_validator()
def validate_actions(
cls, values: PartiallyValidatedPipeline
) -> PartiallyValidatedPipeline:
# TODO: move to Action when we move name onto it
for action_id, config in values.get("actions", {}).items():
validators = {
"cohortextractor:latest generate_cohort": validate_cohortextractor_outputs,
"cohortextractor-v2:latest generate_cohort": validate_databuilder_outputs,
"databuilder:latest generate_dataset": validate_databuilder_outputs,
}
for cmd, validator_func in validators.items():
if config.run.raw.startswith(cmd):
validator_func(action_id, config)

return values

@@ -221,6 +221,18 @@ def validate_outputs_per_version(

return values

@root_validator(pre=True)
def validate_actions_run(cls, values: RawPipeline) -> RawPipeline:
# TODO: move to Action when we move name onto it
for action_id, config in values.get("actions", {}).items():
if config["run"] == "":
# key is present but empty
raise ValueError(
f"run must have a value, {action_id} has an empty run key"
)

return values

@validator("actions")
def validate_unique_commands(cls, actions: Dict[str, Action]) -> Dict[str, Action]:
seen: Dict[Command, List[str]] = defaultdict(list)
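The new Outputs.__len__ above is what lets the post validators write len(action.outputs) and count only the privacy levels an action actually set. A trimmed sketch of the behaviour, assuming pydantic v1 (the highly_sensitive field is assumed from context; only the other two appear in this hunk):

from typing import Dict, Optional

from pydantic import BaseModel


class Outputs(BaseModel):
    highly_sensitive: Optional[Dict[str, str]]
    moderately_sensitive: Optional[Dict[str, str]]
    minimally_sensitive: Optional[Dict[str, str]]

    def __len__(self) -> int:
        # exclude_unset=True drops any Optional level the user never supplied
        return len(self.dict(exclude_unset=True))


outputs = Outputs(highly_sensitive={"cohort": "output/input.csv"})
assert len(outputs) == 1  # only the explicitly set level is counted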
14 changes: 8 additions & 6 deletions pipeline/outputs.py
@@ -1,16 +1,18 @@
from __future__ import annotations

from pathlib import PurePosixPath
from typing import Iterator
from typing import TYPE_CHECKING, Iterator

from .types import RawOutputs

if TYPE_CHECKING: # pragma: no cover
from .models import Outputs

def get_first_output_file(output_spec: RawOutputs) -> str:

def get_first_output_file(output_spec: Outputs) -> str:
return next(iter_all_outputs(output_spec))


def get_output_dirs(output_spec: RawOutputs) -> list[PurePosixPath]:
def get_output_dirs(output_spec: Outputs) -> list[PurePosixPath]:
"""
Given the set of output files specified by an action, return a list of the
unique directory names of those outputs
@@ -20,6 +22,6 @@ def get_output_dirs(output_spec: RawOutputs) -> list[PurePosixPath]:
return list({PurePosixPath(filename).parent for filename in filenames})


def iter_all_outputs(output_spec: RawOutputs) -> Iterator[str]:
for group in output_spec.values():
def iter_all_outputs(output_spec: Outputs) -> Iterator[str]:
for group in output_spec.dict(exclude_unset=True).values():
yield from group.values()
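A short usage sketch of the updated helpers, which now take an Outputs model instead of a raw dict; the file paths are invented for illustration:

from pathlib import PurePosixPath

from pipeline.models import Outputs
from pipeline.outputs import get_first_output_file, get_output_dirs, iter_all_outputs

outputs = Outputs(
    highly_sensitive={"cohort": "output/input.csv"},
    moderately_sensitive={"report": "reports/report.html"},
)

# iter_all_outputs walks dict(exclude_unset=True), so only the set levels appear
assert list(iter_all_outputs(outputs)) == ["output/input.csv", "reports/report.html"]
assert get_first_output_file(outputs) == "output/input.csv"
assert set(get_output_dirs(outputs)) == {PurePosixPath("output"), PurePosixPath("reports")}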
29 changes: 15 additions & 14 deletions pipeline/validation.py
@@ -1,12 +1,15 @@
from __future__ import annotations

import posixpath
import shlex
from pathlib import PurePosixPath, PureWindowsPath
from typing import TYPE_CHECKING

from .exceptions import InvalidPatternError
from .outputs import get_first_output_file, get_output_dirs
from .types import RawAction


if TYPE_CHECKING: # pragma: no cover
from .models import Action


def assert_valid_glob_pattern(pattern: str) -> None:
@@ -50,35 +53,33 @@ def assert_valid_glob_pattern(pattern: str) -> None:
raise InvalidPatternError("is an absolute path")


def validate_cohortextractor_outputs(action_id: str, action: RawAction) -> None:
def validate_cohortextractor_outputs(action_id: str, action: Action) -> None:
"""
Check cohortextractor's output config is valid for this command
We can't validate outputs in the Action or Outputs models because we need
to look up other fields (eg run).
"""
# ensure we only have output level defined.
num_output_levels = len(action["outputs"])
num_output_levels = len(action.outputs)
if num_output_levels != 1:
raise ValueError(
"A `generate_cohort` action must have exactly one output; "
f"{action_id} had {num_output_levels}"
)

output_dirs = get_output_dirs(action["outputs"])
output_dirs = get_output_dirs(action.outputs)
if len(output_dirs) == 1:
return

# the calling function for this validator, Pipeline.validate_actions, has
# already checked the run key is valid for us.
run_args = shlex.split(action["run"])

# If we detect multiple output directories but the command explicitly
# specifies an output directory then we assume the user knows what
# they're doing and don't attempt to modify the output directory or
# throw an error
flag = "--output-dir"
has_output_dir = any(arg == flag or arg.startswith(f"{flag}=") for arg in run_args)
has_output_dir = any(
arg == flag or arg.startswith(f"{flag}=") for arg in action.run.parts
)
if not has_output_dir:
raise ValueError(
f"generate_cohort command should produce output in only one "
@@ -87,21 +88,21 @@ def validate_cohortextractor_outputs(action_id: str, action: RawAction) -> None:
)


def validate_databuilder_outputs(action_id: str, action: RawAction) -> None:
def validate_databuilder_outputs(action_id: str, action: Action) -> None:
"""
Check databuilder's output config is valid for this command
We can't validate outputs in the Action or Outputs models because we need
to look up other fields (eg run).
"""
# TODO: should this be checking output _paths_ instead of levels?
num_output_levels = len(action["outputs"])
num_output_levels = len(action.outputs)
if num_output_levels != 1:
raise ValueError(
"A `generate_dataset` action must have exactly one output; "
f"{action_id} had {num_output_levels}"
)

first_output_file = get_first_output_file(action["outputs"])
if first_output_file not in action["run"]:
first_output_file = get_first_output_file(action.outputs)
if first_output_file not in action.run.raw:
raise ValueError("--output in run command and outputs must match")
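Because these validators now run post-parse, they expect model objects rather than dicts: they read action.outputs (an Outputs instance) and action.run.raw (the original command string). A hypothetical stand-in (FakeRun and FakeAction are not this repo's classes) to show the shape validate_databuilder_outputs relies on:

from dataclasses import dataclass

from pipeline.models import Outputs
from pipeline.validation import validate_databuilder_outputs


@dataclass
class FakeRun:
    raw: str


@dataclass
class FakeAction:
    run: FakeRun
    outputs: Outputs


action = FakeAction(
    run=FakeRun("databuilder:latest generate_dataset --output output/dataset.csv"),
    outputs=Outputs(highly_sensitive={"dataset": "output/dataset.csv"}),
)

# Passes: exactly one output level, and the first output file appears in run.raw
validate_databuilder_outputs("generate_dataset", action)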
22 changes: 13 additions & 9 deletions tests/test_outputs.py
@@ -1,25 +1,29 @@
from pathlib import Path

from pipeline.models import Outputs
from pipeline.outputs import get_output_dirs


def test_get_output_dirs_with_duplicates():
outputs = {
"one": {"a": "output/1a.csv"},
"two": {"a": "output/2a.csv"},
}
outputs = Outputs(
highly_sensitive={
"a": "output/1a.csv",
"b": "output/2a.csv",
},
)

dirs = get_output_dirs(outputs)

assert set(dirs) == {Path("output")}


def test_get_output_dirs_without_duplicates():
outputs = {
"one": {"a": "1a/output.csv"},
"two": {"a": "2a/output.csv"},
}

outputs = Outputs(
highly_sensitive={
"a": "1a/output.csv",
"b": "2a/output.csv",
},
)
dirs = get_output_dirs(outputs)

assert set(dirs) == {Path("1a"), Path("2a")}
