Commit cdee3b4 (parent b60de6e): Databricks finetuning integration (#1770)

* Databricks finetuning; small fix; some fixes
* add examples

Showing 6 changed files with 565 additions and 43 deletions.
@@ -0,0 +1,361 @@
import logging
import os
import re
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import requests
import ujson

from dspy.clients.provider import Provider, TrainingJob
from dspy.clients.utils_finetune import DataFormat, get_finetune_directory

if TYPE_CHECKING:
    from databricks.sdk import WorkspaceClient

logger = logging.getLogger(__name__)

class TrainingJobDatabricks(TrainingJob):
    def __init__(self, finetuning_run=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetuning_run = finetuning_run
        self.launch_started = False
        self.launch_completed = False
        self.endpoint_name = None

    def status(self):
        if not self.finetuning_run:
            return None
        try:
            from databricks.model_training import foundation_model as fm
        except ImportError:
            raise ImportError(
                "To use Databricks finetuning, please install the databricks_genai package via "
                "`pip install databricks_genai`."
            )
        run = fm.get(self.finetuning_run)
        return run.status

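
# A minimal usage sketch for polling a job's status (the run name is hypothetical):
#
#     job = TrainingJobDatabricks(finetuning_run="my-finetuning-run")
#     print(job.status())  # e.g. a status object whose display_name is "Running"
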
class DatabricksProvider(Provider):
    finetunable = True
    TrainingJob = TrainingJobDatabricks

    @staticmethod
    def is_provider_model(model: str) -> bool:
        # We don't automatically infer Databricks models because Databricks is not a proprietary model provider.
        return False

    @staticmethod
    def deploy_finetuned_model(
        model: str,
        data_format: Optional[DataFormat] = None,
        databricks_host: Optional[str] = None,
        databricks_token: Optional[str] = None,
        deploy_timeout: int = 900,
    ):
        workspace_client = _get_workspace_client()
        model_version = next(workspace_client.model_versions.list(model)).version

        # Allow users to override the host and token. This is useful on Databricks hosted runtime.
        databricks_host = databricks_host or workspace_client.config.host
        databricks_token = databricks_token or workspace_client.config.token

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {databricks_token}"}

        optimizable_info = requests.get(
            url=f"{databricks_host}/api/2.0/serving-endpoints/get-model-optimization-info/{model}/{model_version}",
            headers=headers,
        ).json()
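
        # Note: the optimization-info response is assumed to carry an "optimizable" flag and a
        # "throughput_chunk_size" field, which the code below consumes.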

        if "optimizable" not in optimizable_info or not optimizable_info["optimizable"]:
            raise ValueError(f"Model is not eligible for provisioned throughput: {optimizable_info}")

        chunk_size = optimizable_info["throughput_chunk_size"]

        # Minimum desired provisioned throughput.
        min_provisioned_throughput = 0

        # Maximum desired provisioned throughput.
        max_provisioned_throughput = chunk_size

        # Databricks serving endpoint names cannot contain ".".
        model_name = model.replace(".", "_")

        get_endpoint_response = requests.get(
            url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", headers=headers
        )

        if get_endpoint_response.status_code == 200:
            logger.info(f"Serving endpoint {model_name} already exists, updating it instead of creating a new one.")
            # The serving endpoint already exists, so we update it instead of creating a new one.
            data = {
                "served_entities": [
                    {
                        "name": model_name,
                        "entity_name": model,
                        "entity_version": model_version,
                        "min_provisioned_throughput": min_provisioned_throughput,
                        "max_provisioned_throughput": max_provisioned_throughput,
                    }
                ]
            }

            response = requests.put(
                url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}/config",
                json=data,
                headers=headers,
            )
        else:
            logger.info(f"Creating serving endpoint {model_name} on Databricks model serving!")
            # Send the POST request to create the serving endpoint.
            data = {
                "name": model_name,
                "config": {
                    "served_entities": [
                        {
                            "name": model_name,
                            "entity_name": model,
                            "entity_version": model_version,
                            "min_provisioned_throughput": min_provisioned_throughput,
                            "max_provisioned_throughput": max_provisioned_throughput,
                        }
                    ]
                },
            }

            response = requests.post(url=f"{databricks_host}/api/2.0/serving-endpoints", json=data, headers=headers)

        if response.status_code == 200:
            logger.info(
                f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!"
            )
        else:
            raise ValueError(f"Failed to create serving endpoint: {response.json()}.")

        logger.info(
            f"Waiting for serving endpoint {model_name} to be ready, this might take a few minutes... You can check "
            f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}"
        )
        from openai import OpenAI

        client = OpenAI(
            api_key=databricks_token,
            base_url=f"{databricks_host}/serving-endpoints",
        )
        # Wait for the deployment to be ready by probing the endpoint once a minute.
        num_retries = deploy_timeout // 60
        for _ in range(num_retries):
            try:
                if data_format == DataFormat.chat:
                    client.chat.completions.create(
                        messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1
                    )
                elif data_format == DataFormat.completion:
                    client.completions.create(prompt="hi", model=model_name, max_tokens=1)
                logger.info(f"Databricks model serving endpoint {model_name} is ready!")
                return
            except Exception:
                time.sleep(60)

        raise ValueError(
            f"Failed to create serving endpoint {model_name} on Databricks model serving platform within "
            f"{deploy_timeout} seconds."
        )
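
    # A minimal usage sketch (names are hypothetical), assuming the model has already been
    # registered to Unity Catalog by a finetuning run:
    #
    #     DatabricksProvider.deploy_finetuned_model(
    #         model="main.my_schema.my_finetuned_model",
    #         data_format=DataFormat.chat,
    #     )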

    @staticmethod
    def finetune(
        job: TrainingJobDatabricks,
        model: str,
        train_data: List[Dict[str, Any]],
        train_kwargs: Optional[Dict[str, Any]] = None,
        data_format: Optional[Union[DataFormat, str]] = None,
    ) -> Optional[str]:
        train_kwargs = train_kwargs or {}
        if isinstance(data_format, str):
            if data_format == "chat":
                data_format = DataFormat.chat
            elif data_format == "completion":
                data_format = DataFormat.completion
            else:
                raise ValueError(
                    f"String `data_format` must be one of 'chat' or 'completion', but received: {data_format}."
                )

        if "train_data_path" not in train_kwargs:
            raise ValueError("The `train_data_path` must be provided to finetune on Databricks.")
        # Add the file name to the directory path.
        train_kwargs["train_data_path"] = DatabricksProvider.upload_data(
            train_data, train_kwargs["train_data_path"], data_format
        )

        try:
            from databricks.model_training import foundation_model as fm
        except ImportError:
            raise ImportError(
                "To use Databricks finetuning, please install the databricks_genai package via "
                "`pip install databricks_genai`."
            )

        if "register_to" not in train_kwargs:
            raise ValueError("The `register_to` must be provided to finetune on Databricks.")

        # Allow users to override the host and token. This is useful on Databricks hosted runtime.
        databricks_host = train_kwargs.pop("databricks_host", None)
        databricks_token = train_kwargs.pop("databricks_token", None)

        skip_deploy = train_kwargs.pop("skip_deploy", False)
        deploy_timeout = train_kwargs.pop("deploy_timeout", 900)

        logger.info("Starting finetuning on Databricks... this might take a few minutes to finish.")
        finetuning_run = fm.create(
            model=model,
            **train_kwargs,
        )

        job.finetuning_run = finetuning_run

        # Wait for the finetuning run to complete, polling once a minute.
        while True:
            job.finetuning_run = fm.get(job.finetuning_run)
            if job.finetuning_run.status.display_name == "Completed":
                logger.info("Finetuning run completed successfully!")
                break
            elif job.finetuning_run.status.display_name == "Failed":
                raise ValueError(
                    f"Finetuning run failed with status: {job.finetuning_run.status.display_name}. Please check the "
                    f"Databricks workspace for more details. Finetuning job's metadata: {job.finetuning_run}."
                )
            else:
                time.sleep(60)

        if skip_deploy:
            return None

        job.launch_started = True
        model_to_deploy = train_kwargs.get("register_to")
        job.endpoint_name = model_to_deploy.replace(".", "_")
        DatabricksProvider.deploy_finetuned_model(
            model_to_deploy, data_format, databricks_host, databricks_token, deploy_timeout
        )
        job.launch_completed = True
        # The finetuned model name is in the format "databricks/<endpoint_name>".
        return f"databricks/{job.endpoint_name}"

    @staticmethod
    def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: DataFormat):
        logger.info("Uploading finetuning data to Databricks Unity Catalog...")
        file_path = _save_data_to_local_file(train_data, data_format)

        w = _get_workspace_client()
        _create_directory_in_databricks_unity_catalog(w, databricks_unity_catalog_path)

        try:
            with open(file_path, "rb") as f:
                target_path = os.path.join(databricks_unity_catalog_path, os.path.basename(file_path))
                w.files.upload(target_path, f, overwrite=True)
                logger.info("Successfully uploaded finetuning data to Databricks Unity Catalog!")
                return target_path
        except Exception as e:
            raise ValueError(f"Failed to upload finetuning data to Databricks Unity Catalog: {e}")


def _get_workspace_client() -> "WorkspaceClient":
    try:
        from databricks.sdk import WorkspaceClient
    except ImportError:
        raise ImportError(
            "To use Databricks finetuning, please install the databricks-sdk package via "
            "`pip install databricks-sdk`."
        )
    return WorkspaceClient()
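
# Note: WorkspaceClient() relies on the standard databricks-sdk authentication chain,
# e.g. the DATABRICKS_HOST and DATABRICKS_TOKEN environment variables or a local
# Databricks config profile.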


def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databricks_unity_catalog_path: str):
    pattern = r"^/Volumes/(?P<catalog>[^/]+)/(?P<schema>[^/]+)/(?P<volume>[^/]+)(/[^/]+)+$"
    match = re.match(pattern, databricks_unity_catalog_path)
    if not match:
        raise ValueError(
            f"Databricks Unity Catalog path must be in the format '/Volumes/<catalog>/<schema>/<volume>/...', but "
            f"received: {databricks_unity_catalog_path}."
        )

    catalog = match.group("catalog")
    schema = match.group("schema")
    volume = match.group("volume")

    try:
        volume_path = f"{catalog}.{schema}.{volume}"
        w.volumes.read(volume_path)
    except Exception:
        raise ValueError(
            f"Databricks Unity Catalog volume does not exist: {volume_path}, please create it in the Databricks "
            "workspace."
        )

    try:
        w.files.get_directory_metadata(databricks_unity_catalog_path)
        logger.info(f"Directory {databricks_unity_catalog_path} already exists, skipping creation.")
    except Exception:
        # Create the directory if it doesn't exist; we don't raise an error because this is a common case.
        logger.info(f"Creating directory {databricks_unity_catalog_path} in Databricks Unity Catalog...")
        w.files.create_directory(databricks_unity_catalog_path)
        logger.info(f"Successfully created directory {databricks_unity_catalog_path} in Databricks Unity Catalog!")


def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: DataFormat):
    import uuid

    file_name = f"finetuning_{uuid.uuid4()}.jsonl"

    finetune_dir = get_finetune_directory()
    file_path = os.path.join(finetune_dir, file_name)
    file_path = os.path.abspath(file_path)
    with open(file_path, "w") as f:
        for item in train_data:
            if data_format == DataFormat.chat:
                _validate_chat_data(item)
            elif data_format == DataFormat.completion:
                _validate_completion_data(item)

            f.write(ujson.dumps(item) + "\n")
    return file_path


def _validate_chat_data(data: Dict[str, Any]):
    if "messages" not in data:
        raise ValueError(
            "Each finetuning data entry must be a dict with a 'messages' key when `task=CHAT_COMPLETION`, but "
            f"received: {data}"
        )

    if not isinstance(data["messages"], list):
        raise ValueError(
            "The value of the 'messages' key in each finetuning data entry must be a list of dicts with keys 'role' "
            f"and 'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}"
        )

    for message in data["messages"]:
        if "role" not in message:
            raise ValueError(f"Each message in the 'messages' list must contain a 'role' key, but received: {message}.")
        if "content" not in message:
            raise ValueError(
                f"Each message in the 'messages' list must contain a 'content' key, but received: {message}."
            )
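
# A valid chat-format record, for reference:
#
#     {"messages": [{"role": "user", "content": "What is 2 + 2?"},
#                   {"role": "assistant", "content": "4"}]}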


def _validate_completion_data(data: Dict[str, Any]):
    if "prompt" not in data:
        raise ValueError(
            "Each finetuning data entry must be a dict with a 'prompt' key when `task=INSTRUCTION_FINETUNE`, but "
            f"received: {data}"
        )
    if "response" not in data and "completion" not in data:
        raise ValueError(
            "Each finetuning data entry must be a dict with a 'response' or 'completion' key when "
            f"`task=INSTRUCTION_FINETUNE`, but received: {data}"
        )
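
# A valid completion-format record, for reference:
#
#     {"prompt": "What is 2 + 2?", "response": "4"}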