Add inference engines to the catalog (#1394)
Signed-off-by: Martín Santillán Cooper <[email protected]>
martinscooper authored Nov 26, 2024
1 parent d116a0b commit 746394e
Showing 9 changed files with 71 additions and 0 deletions.
13 changes: 13 additions & 0 deletions prepare/engines/cross_provider/llama3.py
@@ -0,0 +1,13 @@
from unitxt.catalog import add_to_catalog
from unitxt.inference import CrossProviderInferenceEngine

model_list = ["meta-llama/llama-3-8b-instruct", "meta-llama/llama-3-70b-instruct"]

for model in model_list:
    model_label = model.split("/")[1].replace("-", "_").replace(".", ",").lower()
    inference_model = CrossProviderInferenceEngine(
        model=model, provider="watsonx", max_tokens=2048, seed=42
    )
    add_to_catalog(
        inference_model, f"engines.cross_provider.{model_label}", overwrite=True
    )
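Note: for the two models above, the catalog name is derived purely by string manipulation. A minimal standalone sketch (plain Python, no unitxt needed) of the names the loop registers:

model_list = ["meta-llama/llama-3-8b-instruct", "meta-llama/llama-3-70b-instruct"]
for model in model_list:
    # "meta-llama/llama-3-8b-instruct" -> "llama_3_8b_instruct"
    model_label = model.split("/")[1].replace("-", "_").replace(".", ",").lower()
    print(f"engines.cross_provider.{model_label}")
# Prints:
# engines.cross_provider.llama_3_8b_instruct
# engines.cross_provider.llama_3_70b_instruct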
7 changes: 7 additions & 0 deletions prepare/engines/openai/gpt4o.py
@@ -0,0 +1,7 @@
from unitxt.catalog import add_to_catalog
from unitxt.inference import OpenAiInferenceEngine

model_name = "gpt-4o"
model_label = model_name.replace("-", "_").lower()
inference_model = OpenAiInferenceEngine(model_name=model_name, max_tokens=2048, seed=42)
add_to_catalog(inference_model, f"engines.openai.{model_label}", overwrite=True)
13 changes: 13 additions & 0 deletions prepare/engines/rits/llama3.py
@@ -0,0 +1,13 @@
from unitxt.catalog import add_to_catalog
from unitxt.inference import RITSInferenceEngine

model_list = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/llama-3-1-70b-instruct",
    "meta-llama/llama-3-1-405b-instruct-fp8",
]

for model in model_list:
    model_label = model.split("/")[1].replace("-", "_").replace(",", "_").lower()
    inference_model = RITSInferenceEngine(model_name=model, max_tokens=2048, seed=42)
    add_to_catalog(inference_model, f"engines.rits.{model_label}", overwrite=True)
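Worked example of the label derivation above: the "." in "Llama-3.1-8B-Instruct" is not replaced, so it survives into the catalog name, which matches the nested path engines/rits/llama_3/1_8b_instruct.json added later in this commit. A minimal sketch, plain Python only:

model = "meta-llama/Llama-3.1-8B-Instruct"
# "-" -> "_", "," -> "_" (no commas here), then lowercase; the "." is kept.
model_label = model.split("/")[1].replace("-", "_").replace(",", "_").lower()
print(model_label)                    # llama_3.1_8b_instruct
print(f"engines.rits.{model_label}")  # engines.rits.llama_3.1_8b_instruct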
7 changes: 7 additions & 0 deletions src/unitxt/catalog/engines/cross_provider/llama_3_70b_instruct.json
@@ -0,0 +1,7 @@
{
    "__type__": "cross_provider_inference_engine",
    "model": "meta-llama/llama-3-70b-instruct",
    "provider": "watsonx",
    "max_tokens": 2048,
    "seed": 42
}
7 changes: 7 additions & 0 deletions src/unitxt/catalog/engines/cross_provider/llama_3_8b_instruct.json
@@ -0,0 +1,7 @@
{
    "__type__": "cross_provider_inference_engine",
    "model": "meta-llama/llama-3-8b-instruct",
    "provider": "watsonx",
    "max_tokens": 2048,
    "seed": 42
}
6 changes: 6 additions & 0 deletions src/unitxt/catalog/engines/openai/gpt_4o.json
@@ -0,0 +1,6 @@
{
    "__type__": "open_ai_inference_engine",
    "model_name": "gpt-4o",
    "max_tokens": 2048,
    "seed": 42
}
6 changes: 6 additions & 0 deletions src/unitxt/catalog/engines/rits/llama_3/1_8b_instruct.json
@@ -0,0 +1,6 @@
{
    "__type__": "rits_inference_engine",
    "model_name": "meta-llama/Llama-3.1-8B-Instruct",
    "max_tokens": 2048,
    "seed": 42
}
6 changes: 6 additions & 0 deletions src/unitxt/catalog/engines/rits/llama_3_1_405b_instruct_fp8.json
@@ -0,0 +1,6 @@
{
    "__type__": "rits_inference_engine",
    "model_name": "meta-llama/llama-3-1-405b-instruct-fp8",
    "max_tokens": 2048,
    "seed": 42
}
6 changes: 6 additions & 0 deletions src/unitxt/catalog/engines/rits/llama_3_1_70b_instruct.json
@@ -0,0 +1,6 @@
{
    "__type__": "rits_inference_engine",
    "model_name": "meta-llama/llama-3-1-70b-instruct",
    "max_tokens": 2048,
    "seed": 42
}
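Usage note: once these artifacts are in the catalog, an engine can be looked up by name instead of being constructed in code. A minimal sketch, assuming unitxt's generic fetch_artifact helper and the infer() interface of inference engines used elsewhere in the library (neither is introduced by this commit):

# Assumption: fetch_artifact resolves a catalog name and returns the artifact
# together with its source catalog; infer() expects instances with a "source" field.
from unitxt.artifact import fetch_artifact

engine, _ = fetch_artifact("engines.cross_provider.llama_3_8b_instruct")
predictions = engine.infer([{"source": "What is the capital of France?"}])
print(predictions[0])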
