Skip to content

Commit

Permalink
Merge pull request instructlab#242 from danmcp/fulltax
Browse files Browse the repository at this point in the history
Add taxonomy_base='null' option to support using the full taxonomy repo contents
  • Loading branch information
n1hility authored Aug 16, 2024
2 parents b814dfe + 680cbfd commit 889c9da
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 26 deletions.
27 changes: 21 additions & 6 deletions src/instructlab/sdg/utils/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ def _get_taxonomy_diff(repo="taxonomy", base="origin/main"):
return updated_taxonomy_files


def _get_taxonomy(repo="taxonomy"):
repo = Path(repo)
taxonomy_file_paths = []
for root, _, files in os.walk(repo):
for file in files:
file_path = Path(root).joinpath(file).relative_to(repo)
if _istaxonomyfile(file_path):
taxonomy_file_paths.append(str(file_path))
return taxonomy_file_paths


def _get_documents(
source: Dict[str, Union[str, List[str]]],
skip_checkout: bool = False,
Expand Down Expand Up @@ -400,15 +411,19 @@ def read_taxonomy(taxonomy, taxonomy_base, yaml_rules):
if errors:
raise SystemExit(yaml.YAMLError("Taxonomy file with errors! Exiting."))
else: # taxonomy is dir
# Gather the new or changed YAMLs using git diff
updated_taxonomy_files = _get_taxonomy_diff(taxonomy, taxonomy_base)
if taxonomy_base == "empty":
# Gather all the yamls - equivalent to a diff against "the null tree"
taxonomy_files = _get_taxonomy(taxonomy)
else:
# Gather the new or changed YAMLs using git diff, including untracked files
taxonomy_files = _get_taxonomy_diff(taxonomy, taxonomy_base)
total_errors = 0
total_warnings = 0
if updated_taxonomy_files:
logger.debug("Found new taxonomy files:")
for e in updated_taxonomy_files:
if taxonomy_files:
logger.debug("Found taxonomy files:")
for e in taxonomy_files:
logger.debug(f"* {e}")
for f in updated_taxonomy_files:
for f in taxonomy_files:
file_path = os.path.join(taxonomy, f)
data, warnings, errors = _read_taxonomy_file(file_path, yaml_rules)
total_warnings += warnings
Expand Down
70 changes: 50 additions & 20 deletions tests/test_taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@

TEST_SEED_EXAMPLE = "Can you help me debug this failing unit test?"

TEST_TAXONOMY_BASE = "main"

TEST_CUSTOM_YAML_RULES = b"""extends: relaxed
rules:
Expand All @@ -50,26 +48,58 @@ class TestTaxonomy:
def _init_taxonomy(self, taxonomy_dir):
self.taxonomy = taxonomy_dir

def test_read_taxonomy_leaf_nodes(self):
@pytest.mark.parametrize(
"taxonomy_base, create_tracked_file, create_untracked_file, check_leaf_node_keys",
[
("main", True, True, ["compositional_skills->new"]),
("main", False, True, ["compositional_skills->new"]),
("main", True, False, []),
("main", False, False, []),
("main^", True, False, ["compositional_skills->tracked"]),
(
"main^",
True,
True,
["compositional_skills->new", "compositional_skills->tracked"],
),
("empty", True, False, ["compositional_skills->tracked"]),
(
"empty",
True,
True,
["compositional_skills->new", "compositional_skills->tracked"],
),
],
)
def test_read_taxonomy_leaf_nodes(
self,
taxonomy_base,
create_tracked_file,
create_untracked_file,
check_leaf_node_keys,
):
tracked_file = "compositional_skills/tracked/qna.yaml"
untracked_file = "compositional_skills/new/qna.yaml"
self.taxonomy.add_tracked(tracked_file, TEST_VALID_COMPOSITIONAL_SKILL_YAML)
self.taxonomy.create_untracked(
untracked_file, TEST_VALID_COMPOSITIONAL_SKILL_YAML
)
if create_tracked_file:
self.taxonomy.add_tracked(tracked_file, TEST_VALID_COMPOSITIONAL_SKILL_YAML)
if create_untracked_file:
self.taxonomy.create_untracked(
untracked_file, TEST_VALID_COMPOSITIONAL_SKILL_YAML
)

leaf_node = taxonomy.read_taxonomy_leaf_nodes(
self.taxonomy.root, TEST_TAXONOMY_BASE, TEST_CUSTOM_YAML_RULES
)
leaf_node_key = str(pathlib.Path(untracked_file).parent).replace(
os.path.sep, "->"
leaf_nodes = taxonomy.read_taxonomy_leaf_nodes(
self.taxonomy.root, taxonomy_base, TEST_CUSTOM_YAML_RULES
)
assert leaf_node_key in leaf_node

leaf_node_entries = leaf_node.get(leaf_node_key)
seed_example_exists = False
if any(
entry["instruction"] == TEST_SEED_EXAMPLE for entry in leaf_node_entries
):
seed_example_exists = True
assert seed_example_exists is True
assert len(leaf_nodes) == len(check_leaf_node_keys)

for leaf_node_key in check_leaf_node_keys:
assert leaf_node_key in leaf_nodes

leaf_node_entries = leaf_nodes.get(leaf_node_key)
seed_example_exists = False
if any(
entry["instruction"] == TEST_SEED_EXAMPLE for entry in leaf_node_entries
):
seed_example_exists = True
assert seed_example_exists is True

0 comments on commit 889c9da

Please sign in to comment.