Commit

Merge pull request instructlab#166 from derekhiggins/load-pipes
Load custom pipelines from shared data dir
russellb authored Jul 18, 2024
2 parents 263372b + a78525e commit e4765b9
Showing 2 changed files with 36 additions and 14 deletions.
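The net effect of the change: before falling back to the packaged "simple" and "full" pipelines or an absolute path, _sdg_init now searches the platformdirs data directories for an instructlab/sdg subdirectory matching the pipeline name. Below is a minimal sketch of that lookup, reusing the same platformdirs call as the diff; the alias "my_pipeline" is a placeholder, and the Linux locations mentioned in the comment depend on the platform and environment.

# Sketch: list the candidate locations the new code searches for a custom pipeline.
# On Linux these are typically ~/.local/share/instructlab/sdg and
# /usr/local/share/instructlab/sdg (plus /usr/share/...), but platformdirs decides.
import os

import platformdirs

pd = platformdirs.PlatformDirs(
    appname=os.path.join("instructlab", "sdg"), multipath=True
)

pipeline_name = "my_pipeline"  # placeholder alias, not part of this PR
for d in pd.iter_data_dirs():
    candidate = os.path.join(d, pipeline_name)
    print(candidate, "(found)" if os.path.isdir(candidate) else "(not present)")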
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,6 +3,7 @@ click>=8.1.7,<9.0.0
 httpx>=0.25.0,<1.0.0
 langchain-text-splitters
 openai>=1.13.3,<2.0.0
+platformdirs>=4.2
 # Note: this dependency goes along with langchain-text-splitters and may be
 # removed once that one is removed.
 # do not use 8.4.0 due to a bug in the library
49 changes: 35 additions & 14 deletions src/instructlab/sdg/generate_data.py
@@ -14,6 +14,7 @@
 from datasets import Dataset
 import httpx
 import openai
+import platformdirs

 # First Party
 # pylint: disable=ungrouped-imports
@@ -164,29 +165,48 @@ def _gen_test_data(
         outfile.write("\n")


+def _check_pipeline_dir(pipeline):
+    for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
+        if not os.path.exists(os.path.join(pipeline, file)):
+            raise GenerateException(
+                f"Error: pipeline directory ({pipeline}) does not contain {file}."
+            )
+
+
 def _sdg_init(pipeline, client, model_family, model_id, num_instructions_to_generate):
     pipeline_pkg = None
-    if pipeline == "full":
-        pipeline_pkg = FULL_PIPELINES_PACKAGE
-    elif pipeline == "simple":
-        pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
+
+    # Search for the pipeline in User and Site data directories
+    # then for a package defined pipeline
+    # and finally pipelines referenced by absolute path
+    pd = platformdirs.PlatformDirs(
+        appname=os.path.join("instructlab", "sdg"), multipath=True
+    )
+    for d in pd.iter_data_dirs():
+        if os.path.exists(os.path.join(d, pipeline)):
+            pipeline = os.path.join(d, pipeline)
+            _check_pipeline_dir(pipeline)
+            break
     else:
-        # Validate that pipeline is a valid directory and that it contains the required files
-        if not os.path.exists(pipeline):
-            raise GenerateException(
-                f"Error: pipeline directory ({pipeline}) does not exist."
-            )
-        for file in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
-            if not os.path.exists(os.path.join(pipeline, file)):
+        if pipeline == "full":
+            pipeline_pkg = FULL_PIPELINES_PACKAGE
+        elif pipeline == "simple":
+            pipeline_pkg = SIMPLE_PIPELINES_PACKAGE
+        else:
+            # Validate that pipeline is a valid directory and that it contains the required files
+            if not os.path.exists(pipeline):
                 raise GenerateException(
-                    f"Error: pipeline directory ({pipeline}) does not contain {file}."
+                    f"Error: pipeline directory ({pipeline}) does not exist."
                 )
+            _check_pipeline_dir(pipeline)

     ctx = PipelineContext(client, model_family, model_id, num_instructions_to_generate)

     def load_pipeline(yaml_basename):
         if pipeline_pkg:
-            with resources.path(pipeline_pkg, yaml_basename) as yaml_path:
+            with resources.as_file(
+                resources.files(pipeline_pkg).joinpath(yaml_basename)
+            ) as yaml_path:
                 return Pipeline.from_file(ctx, yaml_path)
         else:
             return Pipeline.from_file(ctx, os.path.join(pipeline, yaml_basename))
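Two details of the hunk above are worth spelling out. First, the else: after the for d in pd.iter_data_dirs(): loop is Python's for/else: that branch runs only when the loop completes without hitting break, that is, only when no data directory contained the named pipeline, and only then are the packaged aliases or an absolute path considered. Second, load_pipeline switches from the legacy importlib.resources.path() helper (deprecated in recent Python releases) to the files()/as_file() pair. Here is a rough, self-contained illustration of that newer pattern; "some_package" and "config.yaml" are placeholders, not names taken from this repository.

# Sketch of the files()/as_file() pattern adopted above. The package and
# resource names are placeholders for illustration only.
from importlib import resources


def read_packaged_text(package="some_package", name="config.yaml"):
    # files() returns a Traversable for the package; as_file() exposes it as a
    # real filesystem path for the duration of the with-block, extracting it
    # to a temporary file first if the package ships inside a zip or wheel.
    with resources.as_file(resources.files(package).joinpath(name)) as path:
        return path.read_text()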
@@ -236,7 +256,8 @@ def generate_data(
     use the SDG library constructs directly, and this function will likely be removed.

     Args:
-        pipeline: This argument may be either an alias defined by the sdg library ("simple", "full"),
+        pipeline: This argument may be either an alias defined in a user or site "data directory",
+            or an alias defined by the sdg library ("simple", "full") if the data directory has no matches,
             or an absolute path to a directory containing the pipeline YAML files.
             We expect three files to be present in this directory: "knowledge.yaml",
             "freeform_skills.yaml", and "grounded_skills.yaml".
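Taken together, the two hunks mean a custom pipeline is simply a directory named after the alias passed as pipeline, placed in one of the data directories and containing the three required YAML files. The following is a hedged end-to-end sketch; the alias "my_pipeline" and the stub file contents are illustrative only, and a real pipeline would start from copies of the packaged YAML definitions.

# Sketch: stage a custom pipeline directory where the new lookup will find it.
# "my_pipeline" and the stub YAML bodies are placeholders, not part of this PR.
import os

import platformdirs

base = platformdirs.user_data_dir(os.path.join("instructlab", "sdg"))
pipeline_dir = os.path.join(base, "my_pipeline")
os.makedirs(pipeline_dir, exist_ok=True)

for name in ["knowledge.yaml", "freeform_skills.yaml", "grounded_skills.yaml"]:
    with open(os.path.join(pipeline_dir, name), "w", encoding="utf-8") as f:
        f.write("# real pipeline definition goes here\n")

# Passing pipeline="my_pipeline" to generate_data() should now resolve to
# pipeline_dir before the packaged "simple"/"full" aliases are consulted.
print("custom pipeline staged at", pipeline_dir)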
