Skip to content
This repository has been archived by the owner on May 10, 2024. It is now read-only.

Commit

Permalink
Merge pull request #6 from Trainy-ai/patchfix
Browse files (browse the repository at this point in the history)
Patchfix
  • Loading branch information
asaiacai authored Sep 25, 2023
2 parents bc2d1d0 + c859d14 commit 2eff234
Show file tree
Hide file tree
Showing 11 changed files with 76 additions and 36 deletions.
2 changes: 1 addition & 1 deletion docs/source/quickstart/finetuning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ To do a vicuna finetune of your first model through LLM-ATC, run the following
.. code-block:: console
# start training
$ llm-atc train --model_type vicuna --finetune_data ./vicuna_test.json --name myvicuna --description "This is a finetuned model that just says its name is vicuna" -c mycluster --cloud gcp --envs "MODEL_SIZE=7 WANDB_API_KEY=<my wandb key>" --accelerator A100-80G:4
$ llm-atc train --model_type vicuna --finetune_data ./vicuna_test.json --name myvicuna --checkpoint_bucket my-trainy-bucket --checkpoint_path ~/test_vicuna --checkpoint_store S3 --description "This is a finetuned model that just says its name is vicuna" -c mycluster --cloud gcp --envs "MODEL_SIZE=7 WANDB_API_KEY=<my wandb key>" --accelerator A100-80G:4
# Once training is done, shutdown the cluster
$ sky down
Expand Down
2 changes: 1 addition & 1 deletion docs/source/quickstart/serving.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ by using the :code:`llm-atc/` prefix.
.. code-block:: console
# serve an llm-atc finetuned model, requires `llm-atc/` prefix and grabs model checkpoint from object store
$ llm-atc serve --name llm-atc/myvicuna --accelerator A100:1 -c servecluster --cloud gcp --region asia-southeast1 --envs "HF_TOKEN=<HuggingFace_token>"
$ llm-atc serve --name llm-atc/myvicuna --source s3://my-bucket/my_vicuna/ --accelerator A100:1 -c servecluster --cloud gcp --region asia-southeast1 --envs "HF_TOKEN=<HuggingFace_token>"
# serve a HuggingFace model, e.g. `lmsys/vicuna-13b-v1.3`
$ llm-atc serve --name lmsys/vicuna-13b-v1.3 --accelerator A100:1 -c servecluster --cloud gcp --region asia-southeast1 --envs "HF_TOKEN=<HuggingFace_token>"
Expand Down
35 changes: 29 additions & 6 deletions llm_atc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ def cli():
required=True,
help="local/cloud URI to finetuning data. (e.g ~/mychat.json, s3://my_bucket/my_chat.json)",
)
@click.option(
"--checkpoint_bucket", type=str, required=True, help="object store bucket name"
)
@click.option(
"--checkpoint_path",
type=str,
required=True,
help="object store path for fine tuned checkpoints, e.g. ~/datasets",
)
@click.option(
"--checkpoint_store",
type=str,
required=True,
help="object store type ['S3', 'GCS', 'AZURE', 'R2', 'IBM']",
)
@click.option("-n", "--name", type=str, help="Name of this model run.", required=True)
@click.option(
"--description", type=str, default="", help="description of this model run"
Expand Down Expand Up @@ -100,6 +115,9 @@ def cli():
def train(
model_type: str,
finetune_data: str,
checkpoint_bucket: str,
checkpoint_path: str,
checkpoint_store: Optional[str],
name: str,
description: str,
cluster: Optional[str],
Expand All @@ -118,12 +136,11 @@ def train(
event="training launched",
timestamp=datetime.utcnow(),
)
if RunTracker.run_exists(name):
raise ValueError(
f"Task with name {name} already exists in {llm_atc.constants.LLM_ATC_PATH}. Try again with a different name"
)
task = train_task(
model_type,
checkpoint_bucket=checkpoint_bucket,
checkpoint_path=checkpoint_path,
checkpoint_store=checkpoint_store,
finetune_data=finetune_data,
name=name,
cloud=cloud,
Expand All @@ -146,7 +163,11 @@ def train(
"--name",
help="name of model to serve",
required=True,
multiple=True,
)
@click.option(
"--source",
help="object store path for llm-atc finetuned model checkpoints."
"e.g. s3://<bucket-name>/<path>/<to>/<checkpoints>",
)
@click.option(
"-e",
Expand Down Expand Up @@ -189,7 +210,8 @@ def train(
help="Don't connect to this session",
)
def serve(
name: List[str],
name: str,
source: Optional[str],
accelerator: Optional[str],
envs: Optional[str],
cluster: Optional[str],
Expand All @@ -209,6 +231,7 @@ def serve(
)
serve_task = serve_route(
name,
source=source,
accelerator=accelerator,
envs=envs,
cloud=cloud,
Expand Down
10 changes: 2 additions & 8 deletions llm_atc/config/serve/serve.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ resources:
ports:
- 8000

file_mounts:
/llm-atc:
name: llm-atc # Make sure it is unique or you own this bucket name
mode: MOUNT # MOUNT or COPY. Defaults to MOUNT if not specified

setup: |
conda activate chatbot
if [ $? -ne 0 ]; then
Expand All @@ -22,7 +17,7 @@ setup: |
conda install -y -c conda-forge accelerate
pip install sentencepiece
pip install vllm
pip install git+https://github.com/lm-sys/FastChat.git
pip install git+https://github.com/lm-sys/FastChat.git@v0.2.28
pip install --upgrade openai
if [[ "$HF_TOKEN" != "" ]];
then
Expand All @@ -49,7 +44,6 @@ run: |
master_addr=`echo "$SKYPILOT_NODE_IPS" | head -n1`
let x='SKYPILOT_NODE_RANK + 1'
this_addr=`echo "$SKYPILOT_NODE_IPS" | sed -n "${x}p"`
MODEL_NAME=`echo "$MODELS_LIST" | sed -n "${x}p"`
echo "The ip address of this machine is ${this_addr}"
echo "The head address is ${master_addr}"
Expand Down Expand Up @@ -82,5 +76,5 @@ run: |
envs:
MODELS_LIST: lmsys/vicuna-7b-v1.3
MODEL_NAME: lmsys/vicuna-7b-v1.3
HF_TOKEN: ""
10 changes: 6 additions & 4 deletions llm_atc/config/train/vicuna.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@ num_nodes: 1

file_mounts:
/artifacts:
name: llm-atc
name: ${MY_BUCKET} # Name of the bucket.
mode: MOUNT

workdir: .
store: ${BUCKET_TYPE} # s3, gcs, r2, ibm

setup: |
# Setup the environment
Expand Down Expand Up @@ -64,7 +63,7 @@ run: |
# the training for saving checkpoints.
mkdir -p ~/.checkpoints
LOCAL_CKPT_PATH=~/.checkpoints
CKPT_PATH=/artifacts/${MODEL_NAME}
CKPT_PATH=/artifacts/${BUCKET_PATH}/${MODEL_NAME}
mkdir -p $CKPT_PATH
last_ckpt=$(ls ${CKPT_PATH} | grep -E '[0-9]+' | sort -t'-' -k1,1 -k2,2n | tail -1)
mkdir -p ~/.checkpoints/${last_ckpt}
Expand Down Expand Up @@ -127,3 +126,6 @@ envs:
WANDB_API_KEY: ""
MODEL_NAME: "vicuna_test"
HF_TOKEN: ""
MY_BUCKET: "llm-atc"
BUCKET_PATH: "my_vicuna" # object store path.
BUCKET_TYPE: "S3"
15 changes: 14 additions & 1 deletion llm_atc/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import sky
from sky.data.storage import Storage

from omegaconf import OmegaConf
from typing import Any, Dict, Optional
Expand All @@ -10,7 +11,7 @@
SUPPORTED_MODELS = ("vicuna",)


def train_task(model_type: str, **launcher_kwargs) -> sky.Task:
def train_task(model_type: str, *args, **launcher_kwargs) -> sky.Task:
"""
Dispatch train launch to corresponding task default config
Expand Down Expand Up @@ -39,6 +40,9 @@ class Launcher:
def __init__(
self,
finetune_data: str,
checkpoint_bucket: str = "llm-atc",
checkpoint_path: str = "my_vicuna",
checkpoint_store: str = "S3",
name: Optional[str] = None,
cloud: Optional[str] = None,
region: Optional[str] = None,
Expand All @@ -47,6 +51,9 @@ def __init__(
envs: Optional[str] = "",
):
self.finetune_data: str = finetune_data
self.checkpoint_bucket: str = checkpoint_bucket
self.checkpoint_path: str = checkpoint_path
self.checkpoint_store: str = checkpoint_store
self.name: Optional[str] = name
self.cloud: Optional[str] = cloud
self.region: Optional[str] = region
Expand Down Expand Up @@ -85,8 +92,14 @@ def launch(self) -> sky.Task:
logging.warning(
"No huggingface token provided. You will not be able to finetune starting from private or gated models"
)
self.envs["MY_BUCKET"] = self.checkpoint_bucket
self.envs["BUCKET_PATH"] = self.checkpoint_path
self.envs["BUCKET_TYPE"] = self.checkpoint_store
task.update_envs(self.envs)
task.update_file_mounts({"/data/mydata.json": self.finetune_data})
storage = Storage(name=self.checkpoint_bucket)
storage.add_store(self.checkpoint_store)
task.update_storage_mounts({"/artifacts": storage})
resource = list(task.get_resources())[0]
resource._set_accelerators(self.accelerator, None)
resource._cloud = sky.clouds.CLOUD_REGISTRY.from_str(self.cloud)
Expand Down
27 changes: 16 additions & 11 deletions llm_atc/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Dict, List, Optional


def serve_route(model_names: List[str], **serve_kwargs):
def serve_route(model_name: str, source: Optional[str] = None, **serve_kwargs):
"""Routes model serve requests to the corresponding model serve config
Args:
Expand All @@ -17,26 +17,30 @@ def serve_route(model_names: List[str], **serve_kwargs):
Raises:
ValueError: requested non-existent model from llm-atc
"""
model_names = list(model_names)
for i, name in enumerate(model_names):
if name.startswith("llm-atc/") and not RunTracker.run_exists(
name.split("/")[-1]
):
raise ValueError(f"model = {name} does not exist within llm-atc.")
return Serve(model_names, **serve_kwargs).serve()
if model_name.startswith("llm-atc/") and source is None:
raise ValueError(
"Attempting to use a finetuned model without a corresponding object store location"
)
elif not source is None and not model_name.startswith("llm-atc/"):
logging.warning(
"Specified object store mount but model is not an llm-atc model. Skipping mounting."
)
return Serve(model_name, source, **serve_kwargs).serve()


class Serve:
def __init__(
self,
names: List[str],
names: str,
source: Optional[str],
accelerator: Optional[str] = None,
cloud: Optional[str] = None,
region: Optional[str] = None,
zone: Optional[str] = None,
envs: str = "",
):
self.names = names
self.source = source
self.num_models = len(names)
self.accelerator = accelerator
self.envs: Dict[Any, Any] = (
Expand Down Expand Up @@ -65,7 +69,7 @@ def default_serve_task(self) -> sky.Task:
def serve(self) -> sky.Task:
"""Deploy fastchat.serve.openai_api_server with vllm_worker"""
serve_task = self.default_serve_task
self.envs["MODELS_LIST"] = "\n".join(self.names)
self.envs["MODEL_NAME"] = self.names
if "HF_TOKEN" not in self.envs:
logging.warning(
"No huggingface token provided. You will not be able to access private or gated models"
Expand All @@ -76,5 +80,6 @@ def serve(self) -> sky.Task:
resource._cloud = sky.clouds.CLOUD_REGISTRY.from_str(self.cloud)
resource._set_region_zone(self.region, self.zone)
serve_task.set_resources(resource)
serve_task.num_noded = self.num_models
if self.source and self.names.startswith("llm-atc/"):
serve_task.update_file_mounts({"/" + self.names: self.source})
return serve_task
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "llm_atc"
version = "0.1.3"
version = "0.1.4"
description = "Tools for fine tuning and serving LLMs"
authors = ["Andrew Aikawa <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_train_vicuna():
test = Test(
"train_vicuna",
[
f"llm-atc train --model_type vicuna --finetune_data {test_chat} --name {name} --description 'test case vicuna fine tune' -c mycluster --cloud gcp --envs 'MODEL_SIZE=7' --accelerator A100-80G:4",
f"llm-atc train --checkpoint_bucket llm-atc --checkpoint_path ~/test_vicuna --checkpoint_store S3 --model_type vicuna --finetune_data {test_chat} --name {name} --description 'test case vicuna fine tune' -c mycluster --cloud gcp --envs 'MODEL_SIZE=7' --accelerator A100-80G:4",
],
f"sky down --purge -y {name}",
timeout=10 * 60,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
def test_train():
task = train_task(
"vicuna",
checkpoint_bucket="llm-atc",
checkpoint_path="myvicuna",
checkpoint_store="S3",
finetune_data="./vicuna_test.json",
name="myvicuna",
cloud="aws",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

def test_serve():
serve_task = serve_route(
["lmsys/vicuna-7b-1.3"],
"lmsys/vicuna-7b-1.3",
accelerator="V100:1",
envs="HF_TOKEN=mytoken",
cloud="aws",
region="us-east-2",
)
assert serve_task.envs["MODELS_LIST"] == "lmsys/vicuna-7b-1.3"
assert serve_task.envs["MODEL_NAME"] == "lmsys/vicuna-7b-1.3"
assert serve_task.envs["HF_TOKEN"] == "mytoken"

sky.launch(serve_task, cluster_name="dummycluster", dryrun=True)

0 comments on commit 2eff234

Please sign in to comment.