diff --git a/requirements.txt b/requirements.txt
index df33d5cc..90988d8f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ click>=8.1.7,<9.0.0
 httpx>=0.25.0,<1.0.0
 langchain-text-splitters
 openai>=1.13.3,<2.0.0
+platformdirs>=4.2
 # Note: this dependency goes along with langchain-text-splitters and may be
 # removed once that one is removed.
 # do not use 8.4.0 due to a bug in the library
diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index 0e3c42eb..d654dbfd 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -14,6 +14,7 @@
 from datasets import Dataset
 import httpx
 import openai
+import platformdirs
 
 # First Party
 # pylint: disable=ungrouped-imports
@@ -164,29 +165,48 @@ def _gen_test_data(
             outfile.write("\n")
 
 
+def _check_pipeline_dir(pipeline):
+    for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
+        if not os.path.exists(os.path.join(pipeline, file)):
+            raise GenerateException(
+                f"Error: pipeline directory ({pipeline}) does not contain {file}."
+            )
+
+
 def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate):
     pipeline_pkg = None
-    if pipeline == "full":
-        pipeline_pkg = FULL_PIPELINES_PACKAGE
-    elif pipeline == "simple":
-        pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
+
+    # Search for the pipeline in user and site data directories,
+    # then for a package-defined pipeline,
+    # and finally for pipelines referenced by absolute path
+    pd = platformdirs.PlatformDirs(
+        appname=os.path.join("instructlab", "sdg"), multipath=True
+    )
+    for d in pd.iter_data_dirs():
+        if os.path.exists(os.path.join(d, pipeline)):
+            pipeline = os.path.join(d, pipeline)
+            _check_pipeline_dir(pipeline)
+            break
     else:
-        # Validate that pipeline is a valid directory and that it contains the required files
-        if not os.path.exists(pipeline):
-            raise GenerateException(
-                f"Error: pipeline directory ({pipeline}) does not exist."
-            )
-        for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
-            if not os.path.exists(os.path.join(pipeline, file)):
+        if pipeline == "full":
+            pipeline_pkg = FULL_PIPELINES_PACKAGE
+        elif pipeline == "simple":
+            pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
+        else:
+            # Validate that pipeline is a valid directory and that it contains the required files
+            if not os.path.exists(pipeline):
                 raise GenerateException(
-                    f"Error: pipeline directory ({pipeline}) does not contain {file}."
+                    f"Error: pipeline directory ({pipeline}) does not exist."
                 )
+            _check_pipeline_dir(pipeline)
 
     ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate)
 
     def load_pipeline(yaml_basename):
         if pipeline_pkg:
-            with resources.path(pipeline_pkg, yaml_basename) as yaml_path:
+            with resources.as_file(
+                resources.files(pipeline_pkg).joinpath(yaml_basename)
+            ) as yaml_path:
                 return Pipeline.from_file(ctx, yaml_path)
         else:
             return Pipeline.from_file(ctx, os.path.join(pipeline, yaml_basename))
@@ -236,7 +256,8 @@ def generate_data(
     use the SDG library constructs directly, and this function will likely be
     removed.
     Args:
-        pipeline: This argument may be either an alias defined by the sdg library ("simple", "full"),
+        pipeline: This argument may be either an alias defined in a user or site "data directory"
+            or an alias defined by the sdg library ("simple", "full") (if the data directory has no matches),
            or an absolute path to a directory containing the pipeline YAML files.
            We expect three files to be present in this directory: "knowledge.yaml",
            "freeform_skills.yaml", and "grounded_skills.yaml".