Skip to content

Commit

Permalink
Merge pull request #45 from tiran/entry_point
Browse files Browse the repository at this point in the history
Add entry points for evaluator classes
  • Loading branch information
alimaredia authored Jul 3, 2024
2 parents 684d142 + f6ac656 commit 6d7dc6f
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ homepage = "https://instructlab.ai"
source = "https://github.com/instructlab/eval"
issues = "https://github.com/instructlab/eval/issues"

[project.entry-points."instructlab.eval.evaluator"]
"mmlu" = "instructlab.eval.mmlu:MMLUEvaluator"
"mmlu_branch" = "instructlab.eval.mmlu:MMLUBranchEvaluator"
"mt_bench" = "instructlab.eval.mt_bench:MTBenchEvaluator"
"mt_bench_branch" = "instructlab.eval.mt_bench:MTBenchBranchEvaluator"

[tool.setuptools_scm]
version_file = "src/instructlab/eval/_version.py"
# do not include +gREV local version, required for Test PyPI upload
Expand Down
2 changes: 2 additions & 0 deletions src/instructlab/eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,7 @@ class Evaluator:
Parent class for Evaluators
"""

name: str

def __init__(self) -> None:
    """Initialise the evaluator; this implementation intentionally does nothing."""
    pass
4 changes: 4 additions & 0 deletions src/instructlab/eval/mmlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class MMLUEvaluator(Evaluator):
batch_size number of GPUs
"""

name = "mmlu"

def __init__(
self,
model_path,
Expand Down Expand Up @@ -147,6 +149,8 @@ class MMLUBranchEvaluator(Evaluator):
batch_size number of GPUs
"""

name = "mmlu_branch"

def __init__(
self,
model_path,
Expand Down
4 changes: 4 additions & 0 deletions src/instructlab/eval/mt_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class MTBenchEvaluator(Evaluator):
max_workers Max parallel workers to run the evaluation with
"""

name = "mt_bench"

def __init__(
self,
model_name: str,
Expand Down Expand Up @@ -82,6 +84,8 @@ class MTBenchBranchEvaluator(Evaluator):
max_workers Max parallel workers to run the evaluation with
"""

name = "mt_bench_branch"

def __init__(
self,
model_name: str,
Expand Down
29 changes: 29 additions & 0 deletions tests/test_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from importlib.metadata import entry_points

# First Party
from instructlab.eval.evaluator import Evaluator
from instructlab.eval.mmlu import MMLUBranchEvaluator, MMLUEvaluator
from instructlab.eval.mt_bench import MTBenchBranchEvaluator, MTBenchEvaluator


def test_evaluator_eps():
    """Verify the evaluator entry points registered under ``instructlab.eval``.

    Each entry point in the ``instructlab.eval.evaluator`` group whose module
    path starts with ``instructlab.eval`` must load an ``Evaluator`` subclass
    whose ``name`` attribute equals the entry point name. The loaded classes,
    keyed by entry point name, must match the expected mapping exactly.
    """
    expected = {
        "mmlu": MMLUEvaluator,
        "mmlu_branch": MMLUBranchEvaluator,
        "mt_bench": MTBenchEvaluator,
        "mt_bench_branch": MTBenchBranchEvaluator,
    }

    # Other installed distributions may register into the same group; only
    # entry points whose module lives under this package are checked.
    own_entry_points = (
        candidate
        for candidate in entry_points(group="instructlab.eval.evaluator")
        if candidate.module.startswith("instructlab.eval")
    )

    found = {}
    for candidate in own_entry_points:
        loaded_cls = candidate.load()
        assert issubclass(loaded_cls, Evaluator)
        assert loaded_cls.name == candidate.name
        found[candidate.name] = loaded_cls

    assert found == expected

0 comments on commit 6d7dc6f

Please sign in to comment.