diff --git a/.gitignore b/.gitignore
index ceaa9f7a..b2911a08 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# IDEs
+.vscode/
diff --git a/.pylintrc b/.pylintrc
index 3b4da7a8..5cde4c71 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -90,7 +90,7 @@ persistent=yes

 # Minimum Python version to use for version dependent checks. Will default to
 # the version used to run pylint.
-py-version=3.9
+py-version=3.10

 # Discover python modules and packages in the file system subtree.
 recursive=no
@@ -379,7 +379,8 @@ int-import-graph=
 known-standard-library=

 # Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
+known-third-party=enchant,
+                  instructlab.schema,

 # Couples of modules and preferred modules, separated by a comma.
 preferred-modules=
diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt
index 0cc4115d..45868849 100644
--- a/.spellcheck-en-custom.txt
+++ b/.spellcheck-en-custom.txt
@@ -1,5 +1,12 @@
 # make spellcheck-sort
 # Please keep this file sorted:
 # SPDX-License-Identifier: Apache-2.0
+Dataset
+dataset
+datasets
+Pre
+pre
 sdg
+subfolder
 Tatsu
+yaml
diff --git a/docs/data_mixing.md b/docs/data_mixing.md
index 3721840e..7dedb446 100644
--- a/docs/data_mixing.md
+++ b/docs/data_mixing.md
@@ -2,12 +2,13 @@

 As one of the last steps in data generation, the SDG library can optionally mix multiple datasets into a single output dataset in proportions specified by a recipe yaml file. The current implementation is designed to be used with mostly static recipes that get used by default for every `ilab data generate` run. There is not yet an easy way to specify the recipe to use with each generation run, but we do make it possible to change the default recipe used for skills and/or knowledge data generation.

-The primary intended use of this is to specify an optional pregenerated dataset maintained by the InstructLab community that can improve training results when attempting to teach new skills to a model. This process is a bit manual for now, and the steps to do that are documented below.
+The primary intended use of this is to specify an optional pre-generated dataset maintained by the InstructLab community that can improve training results when attempting to teach new skills to a model. This process is a bit manual for now, and the steps to do that are documented below.

-## Using InstructLab Community Pregenerated Dataset
+## Using InstructLab Community Pre-generated Dataset

-To use the [InstructLab Community pregenerated dataset](https://huggingface.co/datasets/instructlab/InstructLabCommunity) with all skills training, we first need to create a default recipe that specifies this dataset to include when mixing generated skills data. This recipe will get automatically picked up if placed in a `default_data_recipes/skills.yaml` subfolder and file under one of several possible locations - `'/home/<user>/.local/share/instructlab/sdg'`, `'/usr/local/share/instructlab/sdg'`, or `'/usr/share/instructlab/sdg'`. The exact list of possible locations is platform-dependent, and can be enumerated by a Python command like the one below:
-```
+To use the [InstructLab Community pre-generated dataset](https://huggingface.co/datasets/instructlab/InstructLabCommunity) with all skills training, we first need to create a default recipe that specifies this dataset to include when mixing generated skills data. This recipe will get automatically picked up if placed in a `default_data_recipes/skills.yaml` subfolder and file under one of several possible locations - `'/home/<user>/.local/share/instructlab/sdg'`, `'/usr/local/share/instructlab/sdg'`, or `'/usr/share/instructlab/sdg'`. The exact list of possible locations is platform-dependent, and can be enumerated by a Python command like the one below:
+
+```python
 python3 -c '
 import os, platformdirs
 print(list(platformdirs.PlatformDirs(
@@ -18,7 +19,8 @@

 For this example, we'll assume you want to place the default data recipe under the `~/.local/share/instructlab/sdg/` platform directory. Ensure that directory exists and create the recipe yaml file:
-```
+
+```shell
 mkdir -p ~/.local/share/instructlab/sdg/default_data_recipes/
 cat <<EOF > ~/.local/share/instructlab/sdg/default_data_recipes/skills.yaml
 datasets:
@@ -27,15 +29,15 @@ datasets:
 EOF
 ```

-Next, download the instructlab_community.jsonl file from https://huggingface.co/datasets/instructlab/InstructLabCommunity/tree/main and place it in `~/.local/share/instructlab/datasets/`, where the recipe we wrote above will pick it up. If you prefer to place this pregenerated dataset in a different location, you can specify the absolute path to that different location in your recipe yaml file instead of using relative paths as shown here.
+Next, download the `instructlab_community.jsonl` file from <https://huggingface.co/datasets/instructlab/InstructLabCommunity/tree/main> and place it in `~/.local/share/instructlab/datasets/`, where the recipe we wrote above will pick it up. If you prefer to place this pre-generated dataset in a different location, you can specify the absolute path to that different location in your recipe yaml file instead of using relative paths as shown here.

 Then, during your next `ilab data generate`, you should see output near the end like:

-```
+```log
 INFO 2024-08-06 16:08:42,069 instructlab.sdg.datamixing:123: Loading dataset from /home/user/.local/share/instructlab/datasets/instructlab_community.jsonl ...
 Generating train split: 13863 examples [00:00, 185935.73 examples/s]
 INFO 2024-08-06 16:08:42,414 instructlab.sdg.datamixing:125: Dataset columns: ['messages', 'metadata', 'id']
 INFO 2024-08-06 16:08:42,414 instructlab.sdg.datamixing:126: Dataset loaded with 13863 samples
 ```

-Your resulting skills_train_*.jsonl file will now contain the additional 13k+ examples from the precomputed dataset, which should ensure your subsequent skills training doesn't regress in already-learned skills while being taught the new skill.
+Your resulting `skills_train_*.jsonl` file will now contain the additional 13k+ examples from the pre-computed dataset, which should ensure your subsequent skills training doesn't regress in already-learned skills while being taught the new skill.
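Note: the documentation above describes a lookup of `default_data_recipes/skills.yaml` across several platform-dependent data directories. A minimal sketch (not part of this patch) of that lookup in Python, assuming the same `platformdirs` multipath API the doc's snippet uses; `find_default_skills_recipe` is a hypothetical helper name:

```python
# Hypothetical sketch, not part of the patch: resolve the default skills
# recipe by probing each platform data directory in precedence order.
import os

import platformdirs


def find_default_skills_recipe() -> str | None:
    dirs = platformdirs.PlatformDirs(
        appname=os.path.join("instructlab", "sdg"), multipath=True
    )
    for data_dir in dirs.iter_data_dirs():
        candidate = os.path.join(data_dir, "default_data_recipes", "skills.yaml")
        if os.path.isfile(candidate):
            return candidate  # first match wins
    return None  # no custom default recipe installed


if __name__ == "__main__":
    print(find_default_skills_recipe())
```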
diff --git a/pyproject.toml b/pyproject.toml
index 8178ca06..88978472 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ authors = [
 description = "Synthetic Data Generation"
 readme = "README.md"
 license = {text = "Apache-2.0"}
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Environment :: Console",
@@ -22,7 +22,6 @@ classifiers = [
     "Operating System :: POSIX :: Linux",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -53,7 +52,7 @@ where = ["src"]
 include = ["instructlab.sdg"]

 [tool.ruff]
-target-version = "py39"
+target-version = "py310"
 # same as black's default line length
 line-length = 88
@@ -92,6 +91,7 @@ from-first = true
 known-local-folder = ["tuning"]

 [tool.mypy]
+python_version = "3.10"
 disable_error_code = ["import-not-found", "import-untyped"]
 exclude = [
     "^src/instructlab/sdg/generate_data\\.py$",
diff --git a/requirements.txt b/requirements.txt
index ac6d4762..0e0f4ac9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 click>=8.1.7,<9.0.0
 httpx>=0.25.0,<1.0.0
+instructlab-schema>=0.4.0
 langchain-text-splitters
 openai>=1.13.3,<2.0.0
 platformdirs>=4.2
@@ -9,4 +10,3 @@ platformdirs>=4.2
 # do not use 8.4.0 due to a bug in the library
 # https://github.com/instructlab/instructlab/issues/1389
 tenacity>=8.3.0,!=8.4.0
-instructlab-schema>=0.3.1
diff --git a/src/instructlab/sdg/utils/taxonomy.py b/src/instructlab/sdg/utils/taxonomy.py
index 5b8a0caf..9b472a48 100644
--- a/src/instructlab/sdg/utils/taxonomy.py
+++ b/src/instructlab/sdg/utils/taxonomy.py
@@ -1,57 +1,52 @@
 # SPDX-License-Identifier: Apache-2.0

 # Standard
-from functools import cache
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Dict, List, Union
 import glob
-import json
 import logging
 import os
 import re
-import subprocess
 import tempfile

 # Third Party
+from instructlab.schema.taxonomy import DEFAULT_TAXONOMY_FOLDERS as TAXONOMY_FOLDERS
+from instructlab.schema.taxonomy import (
+    TaxonomyMessageFormat,
+    TaxonomyParser,
+    TaxonomyReadingException,
+)
 import git
 import gitdb
 import yaml

-# First Party
-from instructlab.sdg.utils import chunking
+# Local
+from . import chunking

 logger = logging.getLogger(__name__)

-MIN_KNOWLEDGE_VERSION = 3
-DEFAULT_YAML_RULES = """\
-extends: relaxed
-
-rules:
-  line-length:
-    max: 120
-"""
-
-
-# This is part of the public API.
-class TaxonomyReadingException(Exception):
-    """An exception raised during reading of the taxonomy."""
-
-
-TAXONOMY_FOLDERS: List[str] = ["compositional_skills", "knowledge"]
-"""Taxonomy folders which are also the schema names"""
-
-
-def _istaxonomyfile(fn):
+def _is_taxonomy_file(fn: str) -> bool:
     path = Path(fn)
-    if path.suffix == ".yaml" and path.parts[0] in TAXONOMY_FOLDERS:
+    if path.parts[0] not in TAXONOMY_FOLDERS:
+        return False
+    if path.name == "qna.yaml":
         return True
+    if path.name.casefold() in {"qna.yml", "qna.yaml"}:
+        # warning for incorrect extension or case variants
+        logger.warning(
+            "Found a '%s' file: %s: taxonomy files must be named 'qna.yaml'. File will not be checked.",
+            path.name,
+            path,
+        )
     return False


-def _get_taxonomy_diff(repo="taxonomy", base="origin/main"):
-    repo = git.Repo(repo)
-    untracked_files = [u for u in repo.untracked_files if _istaxonomyfile(u)]
+def _get_taxonomy_diff(
+    repo_path: str | Path = "taxonomy", base: str = "origin/main"
+) -> list[str]:
+    repo = git.Repo(repo_path)
+    untracked_files = [u for u in repo.untracked_files if _is_taxonomy_file(u)]

     branches = [b.name for b in repo.branches]
@@ -90,7 +85,7 @@ def _get_taxonomy_diff(repo="taxonomy", base="origin/main"):
     modified_files = [
         d.b_path
         for d in head_commit.diff(None)
-        if not d.deleted_file and _istaxonomyfile(d.b_path)
+        if not d.deleted_file and _is_taxonomy_file(d.b_path)
     ]

     updated_taxonomy_files = list(set(untracked_files + modified_files))
@@ -103,7 +98,7 @@ def _get_taxonomy(repo="taxonomy"):
     for root, _, files in os.walk(repo):
         for file in files:
             file_path = Path(root).joinpath(file).relative_to(repo)
-            if _istaxonomyfile(file_path):
+            if _is_taxonomy_file(file_path):
                 taxonomy_file_paths.append(str(file_path))
     return taxonomy_file_paths
@@ -146,214 +141,25 @@ def _get_documents(
     raise e


-@cache
-def _load_schema(path: "importlib.resources.abc.Traversable") -> "referencing.Resource":
-    """Load the schema from the path into a Resource object.
-
-    Args:
-        path (Traversable): Path to the schema to be loaded.
-
-    Raises:
-        NoSuchResource: If the resource cannot be loaded.
-
-    Returns:
-        Resource: A Resource containing the requested schema.
-    """
-    # pylint: disable=C0415
-    # Third Party
-    from referencing import Resource
-    from referencing.exceptions import NoSuchResource
-    from referencing.jsonschema import DRAFT202012
-
-    try:
-        contents = json.loads(path.read_text(encoding="utf-8"))
-        resource = Resource.from_contents(
-            contents=contents, default_specification=DRAFT202012
-        )
-    except Exception as e:
-        raise NoSuchResource(ref=str(path)) from e
-    return resource
-
-
-def _validate_yaml(contents: Mapping[str, Any], taxonomy_path: Path) -> int:
-    """Validate the parsed yaml document using the taxonomy path to
-    determine the proper schema.
-
-    Args:
-        contents (Mapping): The parsed yaml document to validate against the schema.
-        taxonomy_path (Path): Relative path of the taxonomy yaml document where the
-            first element is the schema to use.
-
-    Returns:
-        int: The number of errors found during validation.
-            Messages for each error have been logged.
-    """
-    # pylint: disable=C0415
-    # Standard
-    from importlib import resources
-
-    # Third Party
-    from jsonschema.protocols import Validator
-    from jsonschema.validators import validator_for
-    from referencing import Registry, Resource
-    from referencing.exceptions import NoSuchResource
-    from referencing.typing import URI
-
-    errors = 0
-    version = _get_version(contents)
-    schemas_path = resources.files("instructlab.schema").joinpath(f"v{version}")
-
-    def retrieve(uri: URI) -> Resource:
-        path = schemas_path.joinpath(uri)
-        return _load_schema(path)
-
-    schema_name = taxonomy_path.parts[0]
-    if schema_name not in TAXONOMY_FOLDERS:
-        schema_name = "knowledge" if "document" in contents else "compositional_skills"
-        logger.info(
-            f"Cannot determine schema name from path {taxonomy_path}. Using {schema_name} schema."
-        )
-
-    if schema_name == "knowledge" and version < MIN_KNOWLEDGE_VERSION:
-        logger.error(
-            f"Version {version} is not supported for knowledge taxonomy. Minimum supported version is {MIN_KNOWLEDGE_VERSION}."
-        )
-        errors += 1
-        return errors
-
-    try:
-        schema_resource = retrieve(f"{schema_name}.json")
-        schema = schema_resource.contents
-        validator_cls = validator_for(schema)
-        validator: Validator = validator_cls(
-            schema, registry=Registry(retrieve=retrieve)
-        )
-
-        for validation_error in validator.iter_errors(contents):
-            errors += 1
-            yaml_path = validation_error.json_path[1:]
-            if not yaml_path:
-                yaml_path = "."
-            if validation_error.validator == "minItems":
-                # Special handling for minItems which can have a long message for seed_examples
-                message = (
-                    f"Value must have at least {validation_error.validator_value} items"
-                )
-            else:
-                message = validation_error.message[-200:]
-            logger.error(
-                f"Validation error in {taxonomy_path}: [{yaml_path}] {message}"
-            )
-    except NoSuchResource as e:
-        cause = e.__cause__ if e.__cause__ is not None else e
-        errors += 1
-        logger.error(f"Cannot load schema file {e.ref}. {cause}")
-
-    return errors
-
+# pylint: disable=broad-exception-caught
+def _read_taxonomy_file(file_path: str | Path, yamllint_config: str | None = None):
+    seed_instruction_data = []

-def _get_version(contents: Mapping) -> int:
-    version = contents.get("version", 1)
-    if not isinstance(version, int):
-        # schema validation will complain about the type
-        try:
-            version = int(version)
-        except ValueError:
-            version = 1  # fallback to version 1
-    return version
+    parser = TaxonomyParser(
+        schema_version=0,  # Use version value in yaml
+        message_format=TaxonomyMessageFormat.LOGGING,  # Report warnings and errors to the logger
+        yamllint_config=yamllint_config,
+        yamllint_strict=True,  # Report yamllint warnings as errors
+    )
+    taxonomy = parser.parse(file_path)

+    if taxonomy.warnings or taxonomy.errors:
+        return seed_instruction_data, taxonomy.warnings, taxonomy.errors

-# pylint: disable=broad-exception-caught
-def _read_taxonomy_file(file_path: str, yaml_rules: Optional[str] = None):
-    seed_instruction_data = []
-    warnings = 0
-    errors = 0
-    file_path = Path(file_path).resolve()
-    # file should end with ".yaml" explicitly
-    if file_path.suffix != ".yaml":
-        logger.warning(
-            f"Skipping {file_path}! Use lowercase '.yaml' extension instead."
-        )
-        warnings += 1
-        return None, warnings, errors
-    for i in range(len(file_path.parts) - 1, -1, -1):
-        if file_path.parts[i] in TAXONOMY_FOLDERS:
-            taxonomy_path = Path(*file_path.parts[i:])
-            break
-    else:
-        taxonomy_path = file_path
-    # read file if extension is correct
     try:
-        with open(file_path, "r", encoding="utf-8") as file:
-            contents = yaml.safe_load(file)
-        if not contents:
-            logger.warning(f"Skipping {file_path} because it is empty!")
-            warnings += 1
-            return None, warnings, errors
-        if not isinstance(contents, Mapping):
-            logger.error(
-                f"{file_path} is not valid. The top-level element is not an object with key-value pairs."
-            )
-            errors += 1
-            return None, warnings, errors
-
-        # do general YAML linting if specified
-        version = _get_version(contents)
-        if version > 1:  # no linting for version 1 yaml
-            if yaml_rules is not None:
-                is_file = os.path.isfile(yaml_rules)
-                if is_file:
-                    logger.debug(f"Using YAML rules from {yaml_rules}")
-                    yamllint_cmd = [
-                        "yamllint",
-                        "-f",
-                        "parsable",
-                        "-c",
-                        yaml_rules,
-                        file_path,
-                        "-s",
-                    ]
-                else:
-                    logger.debug(f"Cannot find {yaml_rules}. Using default rules.")
-                    yamllint_cmd = [
-                        "yamllint",
-                        "-f",
-                        "parsable",
-                        "-d",
-                        DEFAULT_YAML_RULES,
-                        file_path,
-                        "-s",
-                    ]
-            else:
-                yamllint_cmd = [
-                    "yamllint",
-                    "-f",
-                    "parsable",
-                    "-d",
-                    DEFAULT_YAML_RULES,
-                    file_path,
-                    "-s",
-                ]
-            try:
-                subprocess.check_output(yamllint_cmd, text=True)
-            except subprocess.SubprocessError as e:
-                lint_messages = [f"Problems found in file {file_path}"]
-                parsed_output = e.output.splitlines()
-                for p in parsed_output:
-                    errors += 1
-                    delim = str(file_path) + ":"
-                    parsed_p = p.split(delim)[1]
-                    lint_messages.append(parsed_p)
-                logger.error("\n".join(lint_messages))
-                return None, warnings, errors
-
-        validation_errors = _validate_yaml(contents, taxonomy_path)
-        if validation_errors:
-            errors += validation_errors
-            return None, warnings, errors
-
-        # get seed instruction data
-        tax_path = "->".join(taxonomy_path.parent.parts)
+        tax_path = "->".join(taxonomy.path.parent.parts)
+        contents = taxonomy.contents
         task_description = contents.get("task_description", None)
         domain = contents.get("domain")
         documents = contents.get("document")
@@ -391,18 +197,28 @@
             }
         )
     except Exception as e:
-        errors += 1
         raise TaxonomyReadingException(f"Exception {e} raised in {file_path}") from e
-    return seed_instruction_data, warnings, errors
+    return seed_instruction_data, 0, 0
+

+def read_taxonomy(
+    taxonomy: str | Path, taxonomy_base: str, yaml_rules: str | None = None
+):
+    yamllint_config = None  # If no custom rules file, use default config
+    if yaml_rules is not None:  # user attempted to pass custom rules file
+        yaml_rules_path = Path(yaml_rules)
+        if yaml_rules_path.is_file():  # file was found, use specified config
+            logger.debug("Using YAML rules from %s", yaml_rules)
+            yamllint_config = yaml_rules_path.read_text(encoding="utf-8")
+        else:
+            logger.debug("Cannot find %s. Using default rules.", yaml_rules)
+
-def read_taxonomy(taxonomy, taxonomy_base, yaml_rules):
     seed_instruction_data = []
     is_file = os.path.isfile(taxonomy)
     if is_file:  # taxonomy is file
         seed_instruction_data, warnings, errors = _read_taxonomy_file(
-            taxonomy, yaml_rules
+            taxonomy, yamllint_config
         )
         if warnings:
             logger.warning(
@@ -425,7 +241,7 @@ def read_taxonomy(taxonomy, taxonomy_base, yaml_rules):
                 logger.debug(f"* {e}")
         for f in taxonomy_files:
             file_path = os.path.join(taxonomy, f)
-            data, warnings, errors = _read_taxonomy_file(file_path, yaml_rules)
+            data, warnings, errors = _read_taxonomy_file(file_path, yamllint_config)
             total_warnings += warnings
             total_errors += errors
             if data:
diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py
index 98b87780..6cf16fda 100644
--- a/tests/test_generate_data.py
+++ b/tests/test_generate_data.py
@@ -25,7 +25,7 @@ TEST_TAXONOMY_BASE = "main"

-TEST_CUSTOM_YAML_RULES = b"""extends: relaxed
+TEST_CUSTOM_YAML_RULES = """extends: relaxed

 rules:
   line-length:
     max: 180
diff --git a/tests/test_taxonomy.py b/tests/test_taxonomy.py
index 8c148113..371492f3 100644
--- a/tests/test_taxonomy.py
+++ b/tests/test_taxonomy.py
@@ -14,7 +14,7 @@
 TEST_SEED_EXAMPLE = "Can you help me debug this failing unit test?"

-TEST_CUSTOM_YAML_RULES = b"""extends: relaxed
+TEST_CUSTOM_YAML_RULES = """extends: relaxed

 rules:
   line-length:
@@ -65,6 +65,7 @@ def test_read_taxonomy_leaf_nodes(
         create_tracked_file,
         create_untracked_file,
         check_leaf_node_keys,
+        tmp_path,
     ):
         tracked_file = "compositional_skills/tracked/qna.yaml"
         untracked_file = "compositional_skills/new/qna.yaml"
@@ -77,8 +78,10 @@ def test_read_taxonomy_leaf_nodes(
         if create_untracked_file:
             self.taxonomy.create_untracked(untracked_file, test_compositional_skill)

+        custom_config_yaml = tmp_path.joinpath("custom_config.yaml")
+        custom_config_yaml.write_text(TEST_CUSTOM_YAML_RULES, encoding="utf-8")
         leaf_nodes = taxonomy.read_taxonomy_leaf_nodes(
-            self.taxonomy.root, taxonomy_base, TEST_CUSTOM_YAML_RULES
+            self.taxonomy.root, taxonomy_base, str(custom_config_yaml)
        )
         assert len(leaf_nodes) == len(check_leaf_node_keys)
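For anyone wanting to exercise the new validation path outside of `ilab`: a minimal sketch of driving `TaxonomyParser` the same way the patched `_read_taxonomy_file` does. The parser arguments mirror the `+` lines in the taxonomy.py hunk above; the qna.yaml path is a placeholder, not a file in this repo:

```python
# Sketch mirroring the TaxonomyParser usage added in taxonomy.py above.
import logging

from instructlab.schema.taxonomy import TaxonomyMessageFormat, TaxonomyParser

logging.basicConfig(level=logging.INFO)

parser = TaxonomyParser(
    schema_version=0,  # 0 = honor the "version" key found in each qna.yaml
    message_format=TaxonomyMessageFormat.LOGGING,  # report findings via the logger
    yamllint_config=None,  # None = fall back to the default lint rules
    yamllint_strict=True,  # treat yamllint warnings as errors
)
taxonomy = parser.parse("compositional_skills/example/qna.yaml")  # placeholder path

# As in the patched _read_taxonomy_file: bail out on any warnings or errors;
# otherwise the parsed document is available as a plain mapping.
if taxonomy.warnings or taxonomy.errors:
    raise SystemExit(f"{taxonomy.warnings} warning(s), {taxonomy.errors} error(s)")
print(taxonomy.contents.get("task_description"))
```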