Parameterize TLS configuration options #243

Merged: 4 commits, Dec 20, 2024
10 changes: 10 additions & 0 deletions README.md
@@ -90,6 +90,10 @@ metadata:
 data:
   endpoint: '<YOUR_MIXTRAL_MODEL_ENDPOINT>'
   model: mixtral
+  ca.crt: | # If using TLS
+    -----BEGIN CERTIFICATE-----
+    <TLS Certificate to Teacher Model>
+    -----END CERTIFICATE-----
 ```

 ```yaml
@@ -110,6 +114,10 @@ metadata:
 data:
   endpoint: '<YOUR_PROMETHEUS_MODEL_ENDPOINT>'
   model: prometheus
+  ca.crt: | # If using TLS
+    -----BEGIN CERTIFICATE-----
+    <TLS Certificate to Judge Model>
+    -----END CERTIFICATE-----
 ```

 ```yaml
@@ -122,6 +130,8 @@ data:
 type: Opaque
 ```

+**NOTE**: You can find and copy the certs needed for the teacher- and judge-server ConfigMaps in another ConfigMap, `kube-root-ca.crt`, found in the same namespace as the hosted model.
+

 ### Run the Pipeline

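For context on how these ConfigMap entries are consumed (a minimal sketch, not part of this diff): once the ConfigMap is mounted into a pipeline pod, the `ca.crt` key surfaces as a plain file, and the components treat a missing or empty file as TLS being disabled. The path and variable names below assume the wiring added in `pipeline.py`.

```python
# Minimal sketch, assuming the /tmp/cert mount and *_CA_CERT_PATH env vars
# that pipeline.py sets up; not part of the diff itself.
import os

cert_path = os.getenv("SDG_CA_CERT_PATH", "/tmp/cert/ca.crt")

# A missing or zero-byte ca.crt means no cert was provided: skip TLS setup.
if os.path.exists(cert_path) and os.path.getsize(cert_path) > 0:
    print(f"Using custom CA bundle at {cert_path}")
else:
    print("No CA bundle mounted; the default trust store applies")
```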
30 changes: 11 additions & 19 deletions eval/final/components.py
@@ -25,30 +25,20 @@ def run_final_eval_op(
     import os
     import subprocess
 
+    import httpx
     import torch
     from instructlab.eval.mmlu import MMLUBranchEvaluator
     from instructlab.eval.mt_bench import MTBenchBranchEvaluator
     from instructlab.model.evaluate import qa_pairs_to_qna_to_avg_scores, sort_score
 
-    if judge_ca_cert := os.getenv("JUDGE_CA_CERT_PATH"):
-        import httpx
-        import openai
-
-        # Create a custom HTTP client
-        class CustomHttpClient(httpx.Client):
-            def __init__(self, *args, **kwargs):
-                # Use the custom CA certificate
-                kwargs.setdefault("verify", judge_ca_cert)
-                super().__init__(*args, **kwargs)
-
-        # Create a new OpenAI class that uses the custom HTTP client
-        class CustomOpenAI(openai.OpenAI):
-            def __init__(self, *args, **kwargs):
-                custom_client = CustomHttpClient()
-                super().__init__(http_client=custom_client, *args, **kwargs)
-
-        # Monkey patch the OpenAI class in the openai module, so that the eval lib can use it
-        openai.OpenAI = CustomOpenAI
+    judge_api_key = os.getenv("JUDGE_API_KEY", "")
+    judge_model_name = os.getenv("JUDGE_NAME")
+    judge_endpoint = os.getenv("JUDGE_ENDPOINT")
+    judge_ca_cert_path = os.getenv("JUDGE_CA_CERT_PATH")
+    use_tls = os.path.exists(judge_ca_cert_path) and (
+        os.path.getsize(judge_ca_cert_path) > 0
+    )
+    judge_http_client = httpx.Client(verify=judge_ca_cert_path) if use_tls else None
 
     print("Starting Final Eval...")
 
@@ -408,6 +398,7 @@ def find_node_dataset_directories(base_dir: str):
             server_url=vllm_server,
             serving_gpus=gpu_count,
             max_workers=max_workers,
+            http_client=judge_http_client,
         )
 
         shutdown_vllm(vllm_process)
@@ -418,6 +409,7 @@ def find_node_dataset_directories(base_dir: str):
             api_key=judge_api_key,
             serving_gpus=gpu_count,
             max_workers=max_workers,
+            http_client=judge_http_client,
         )
 
         qa_pairs_and_errors.append((overall_score, qa_pairs, error_rate))
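The change above drops the monkey-patching of `openai.OpenAI` in favor of building a single `httpx.Client` and handing it to the evaluators via `http_client=`. A hedged illustration of the underlying mechanism (the evaluators presumably forward this argument to their internal OpenAI-compatible client; this standalone client is an assumption, not code from the PR, though the names mirror the component code):

```python
# Illustration of the http_client mechanism; names mirror the component code.
import os

import httpx
import openai

ca_path = os.getenv("JUDGE_CA_CERT_PATH", "/tmp/cert/ca.crt")
use_tls = os.path.exists(ca_path) and os.path.getsize(ca_path) > 0
http_client = httpx.Client(verify=ca_path) if use_tls else None

judge = openai.OpenAI(
    base_url=os.getenv("JUDGE_ENDPOINT"),
    api_key=os.getenv("JUDGE_API_KEY", ""),
    http_client=http_client,  # None lets openai build its default client
)
```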
36 changes: 12 additions & 24 deletions eval/mt_bench/components.py
@@ -1,6 +1,6 @@
 # type: ignore
 # pylint: disable=no-value-for-parameter,import-outside-toplevel,import-error
-from typing import List, NamedTuple, Optional
+from typing import NamedTuple, Optional
 
 from kfp.dsl import component
 
@@ -22,28 +22,18 @@ def run_mt_bench_op(
     import os
     import subprocess
 
+    import httpx
     import torch
     from instructlab.eval.mt_bench import MTBenchEvaluator
 
-    if judge_ca_cert := os.getenv("JUDGE_CA_CERT_PATH"):
-        import httpx
-        import openai
-
-        # Create a custom HTTP client
-        class CustomHttpClient(httpx.Client):
-            def __init__(self, *args, **kwargs):
-                # Use the custom CA certificate
-                kwargs.setdefault("verify", judge_ca_cert)
-                super().__init__(*args, **kwargs)
-
-        # Create a new OpenAI class that uses the custom HTTP client
-        class CustomOpenAI(openai.OpenAI):
-            def __init__(self, *args, **kwargs):
-                custom_client = CustomHttpClient()
-                super().__init__(http_client=custom_client, *args, **kwargs)
-
-        # Monkey patch the OpenAI class in the openai module, so that the eval lib can use it
-        openai.OpenAI = CustomOpenAI
+    judge_api_key = os.getenv("JUDGE_API_KEY", "")
+    judge_model_name = os.getenv("JUDGE_NAME")
+    judge_endpoint = os.getenv("JUDGE_ENDPOINT")
+    judge_ca_cert_path = os.getenv("JUDGE_CA_CERT_PATH")
+    use_tls = os.path.exists(judge_ca_cert_path) and (
+        os.path.getsize(judge_ca_cert_path) > 0
+    )
+    judge_http_client = httpx.Client(verify=judge_ca_cert_path) if use_tls else None
 
     def launch_vllm(
         model_path: str, gpu_count: int, retries: int = 120, delay: int = 10
@@ -136,10 +126,6 @@ def shutdown_vllm(process: subprocess.Popen, timeout: int = 20):
 
     models_list = os.listdir(models_folder)
 
-    judge_api_key = os.getenv("JUDGE_API_KEY", "")
-    judge_model_name = os.getenv("JUDGE_NAME")
-    judge_endpoint = os.getenv("JUDGE_ENDPOINT")
-
     scores = {}
     all_mt_bench_data = []
 
@@ -175,6 +161,7 @@ def shutdown_vllm(process: subprocess.Popen, timeout: int = 20):
             server_url=vllm_server,
             serving_gpus=gpu_count,
             max_workers=max_workers,
+            http_client=judge_http_client,
         )
 
         shutdown_vllm(vllm_process)
@@ -184,6 +171,7 @@ def shutdown_vllm(process: subprocess.Popen, timeout: int = 20):
             api_key=judge_api_key,
             serving_gpus=gpu_count,
             max_workers=max_workers,
+            http_client=judge_http_client,
         )
 
         mt_bench_data = {
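One subtlety in the shared TLS check: `os.getenv("JUDGE_CA_CERT_PATH")` can return `None`, and `os.path.exists(None)` raises `TypeError`. That cannot happen here because `pipeline.py` always sets the variable, but a defensive variant (a sketch, not part of this PR) would look like:

```python
# Hedged, defensive variant of the component's TLS check; only needed if the
# env var might be unset, which the pipeline wiring currently prevents.
import os
from typing import Optional

import httpx


def make_judge_http_client() -> Optional[httpx.Client]:
    ca_path = os.getenv("JUDGE_CA_CERT_PATH") or ""
    if ca_path and os.path.exists(ca_path) and os.path.getsize(ca_path) > 0:
        return httpx.Client(verify=ca_path)
    return None
```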
35 changes: 34 additions & 1 deletion pipeline.py
@@ -1,5 +1,6 @@
 # type: ignore
 # pylint: disable=no-value-for-parameter,import-outside-toplevel,import-error,no-member
+import os
 import typing
 from typing import List, Literal, Optional
 
@@ -9,8 +10,8 @@
     CreatePVC,
     DeletePVC,
     mount_pvc,
-    set_image_pull_policy,
     use_config_map_as_env,
+    use_config_map_as_volume,
     use_secret_as_env,
     use_secret_as_volume,
 )
@@ -26,6 +27,15 @@
 GENERATED_STANDALONE_FILE_NAME = "standalone.py"
 DEFAULT_REPO_URL = "https://github.com/instructlab/taxonomy.git"
 
+# Model Serving SSL connection
+SDG_CA_CERT_CM_KEY = "ca.crt"
+SDG_CA_CERT_ENV_VAR_NAME = "SDG_CA_CERT_PATH"
+SDG_CA_CERT_PATH = "/tmp/cert"
+
+JUDGE_CA_CERT_CM_KEY = "ca.crt"
+JUDGE_CA_CERT_ENV_VAR_NAME = "JUDGE_CA_CERT_PATH"
+JUDGE_CA_CERT_PATH = "/tmp/cert"
+
 
 def ilab_pipeline_wrapper(mock: List[Literal[MOCKED_STAGES]]):
     """Wrapper for KFP pipeline, which allows for mocking individual stages."""
@@ -187,6 +197,13 @@ def pipeline(
             sdg_task, TEACHER_CONFIG_MAP, dict(endpoint="endpoint", model="model")
         )
         use_secret_as_env(sdg_task, TEACHER_SECRET, {"api_key": "api_key"})
+        use_config_map_as_volume(
+            sdg_task, TEACHER_CONFIG_MAP, mount_path=SDG_CA_CERT_PATH
+        )
+        sdg_task.set_env_variable(
+            SDG_CA_CERT_ENV_VAR_NAME, os.path.join(SDG_CA_CERT_PATH, SDG_CA_CERT_CM_KEY)
+        )
+
         sdg_task.after(git_clone_task)
         mount_pvc(
             task=sdg_task,
@@ -349,6 +366,14 @@ def pipeline(
         )
         use_secret_as_env(run_mt_bench_task, JUDGE_SECRET, {"api_key": "JUDGE_API_KEY"})
 
+        use_config_map_as_volume(
+            run_mt_bench_task, JUDGE_CONFIG_MAP, mount_path=JUDGE_CA_CERT_PATH
+        )
+        run_mt_bench_task.set_env_variable(
+            JUDGE_CA_CERT_ENV_VAR_NAME,
+            os.path.join(JUDGE_CA_CERT_PATH, JUDGE_CA_CERT_CM_KEY),
+        )
+
         # uncomment if updating image with same tag
         # set_image_pull_policy(run_mt_bench_task, "Always")
 
@@ -391,6 +416,14 @@ def pipeline(
 
         use_secret_as_env(final_eval_task, JUDGE_SECRET, {"api_key": "JUDGE_API_KEY"})
 
+        use_config_map_as_volume(
+            final_eval_task, JUDGE_CONFIG_MAP, mount_path=JUDGE_CA_CERT_PATH
+        )
+        final_eval_task.set_env_variable(
+            JUDGE_CA_CERT_ENV_VAR_NAME,
+            os.path.join(JUDGE_CA_CERT_PATH, JUDGE_CA_CERT_CM_KEY),
+        )
+
         final_eval_task.after(run_mt_bench_task)
         final_eval_task.set_accelerator_type("nvidia.com/gpu")
         final_eval_task.set_accelerator_limit(1)
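The pipeline changes apply one pattern three times: mount the model-serving ConfigMap as a volume so its `ca.crt` key becomes a file, then point an env var at that file for the component to read. A condensed, hypothetical sketch of the pattern in isolation (the ConfigMap name `judge-server` and the `check_cert` component are illustrative, not from this repo):

```python
# Hypothetical, self-contained sketch of the mount-plus-env-var pattern this
# PR applies to sdg_task, run_mt_bench_task, and final_eval_task.
import os

from kfp import dsl
from kfp.kubernetes import use_config_map_as_volume

CA_CERT_CM_KEY = "ca.crt"
CA_CERT_MOUNT_PATH = "/tmp/cert"


@dsl.component
def check_cert() -> bool:
    import os

    path = os.getenv("JUDGE_CA_CERT_PATH", "")
    return os.path.exists(path) and os.path.getsize(path) > 0


@dsl.pipeline
def demo_pipeline():
    task = check_cert()
    # Mount the ConfigMap so its ca.crt key appears at /tmp/cert/ca.crt ...
    use_config_map_as_volume(task, "judge-server", mount_path=CA_CERT_MOUNT_PATH)
    # ... and tell the component where to find it.
    task.set_env_variable(
        "JUDGE_CA_CERT_PATH",
        os.path.join(CA_CERT_MOUNT_PATH, CA_CERT_CM_KEY),
    )
```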