Skip to content

Commit

Permalink
Merge pull request #23 from Eventual-Inc/bug/fixes
Browse files Browse the repository at this point in the history
[BUG] Fix errors
  • Loading branch information
raunakab authored Oct 12, 2024
2 parents 37ff4cd + bdc5696 commit 4043d94
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 1,196 deletions.
2 changes: 1 addition & 1 deletion daft_launcher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import cli
from daft_launcher import cli


def main():
Expand Down
70 changes: 44 additions & 26 deletions daft_launcher/cli.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
"""All CLI bindings."""

from typing import Optional
import click
from pathlib import Path
from . import commands, data_definitions
from daft_launcher import commands, data_definitions, helpers
from importlib import metadata


DEFAULT_CONFIG_PATH = Path(".daft.toml")


def generate_intro_message():
def _generate_intro_message():
summary = metadata.metadata("daft-launcher").get("Summary")
info_string = "For more documentation, please visit:\n\nhttps://eventual-inc.github.io/daft-launcher"
return f"{summary}\n\n{info_string}"


def get_config_path(config: Optional[Path]) -> data_definitions.RayConfiguration:
def _get_config_bundle(config: Optional[Path]) -> data_definitions.ConfigurationBundle:
if config:
if not config.exists():
raise click.UsageError("Config file does not exist.")
Expand All @@ -27,21 +29,21 @@ def get_config_path(config: Optional[Path]) -> data_definitions.RayConfiguration
return data_definitions.build_ray_config_from_path(config)


def assert_identity_file_path(identity_file: Optional[Path]):
def _assert_identity_file_path(identity_file: Optional[Path]):
if not identity_file:
return
if not identity_file.exists():
raise click.UsageError("Identity file does not exist.")


def assert_working_dir(working_dir: Path):
def _assert_working_dir(working_dir: Path):
if not working_dir.exists():
raise click.UsageError("Working dir does not exist.")
if not working_dir.is_dir():
raise click.UsageError("Working dir must be a directory.")


def get_new_configuration_file_path(name: Optional[Path]) -> Path:
def _get_new_configuration_file_path(name: Optional[Path]) -> Path:
name = name or DEFAULT_CONFIG_PATH
if name.is_file():
raise click.UsageError(f"A configuration file at path {name} already exists.")
Expand All @@ -55,6 +57,10 @@ def get_new_configuration_file_path(name: Optional[Path]) -> Path:
return name


# Options
# ==============================================================================


def identity_file_option(func):
return click.option(
"--identity-file",
Expand All @@ -75,14 +81,6 @@ def working_dir_option(func):
)(func)


def init_config_file_name_argument(func):
return click.argument(
"name",
required=False,
type=Path,
)(func)


def config_option(func):
return click.option(
"--config",
Expand All @@ -93,12 +91,28 @@ def config_option(func):
)(func)


# Arguments
# ==============================================================================


def init_config_file_name_argument(func):
return click.argument(
"name",
required=False,
type=Path,
)(func)


def cmd_args_argument(func):
return click.argument("cmd_args", nargs=-1, type=click.UNPROCESSED, required=True)(
func
)


# Command Decorators
# ==============================================================================


def init_config_command(func):
return click.command("init-config", help="Create a new configuration file.")(func)

Expand Down Expand Up @@ -132,17 +146,21 @@ def sql_command(func):
)


# CLI Commands
# ==============================================================================


@init_config_command
@init_config_file_name_argument
def init_config(name: Optional[Path]):
name = get_new_configuration_file_path(name)
name = _get_new_configuration_file_path(name)
commands.init_config(name)


@up_command
@config_option
def up(config: Optional[Path]):
ray_config = get_config_path(config)
ray_config = _get_config_bundle(config)
commands.up(ray_config)


Expand All @@ -158,8 +176,8 @@ def connect(
config: Optional[Path],
identity_file: Optional[Path],
):
ray_config = get_config_path(config)
assert_identity_file_path(identity_file)
ray_config = _get_config_bundle(config)
_assert_identity_file_path(identity_file)
commands.connect(ray_config, identity_file)


Expand All @@ -174,9 +192,9 @@ def submit(
working_dir: Path,
cmd_args: tuple[str],
):
ray_config = get_config_path(config)
assert_identity_file_path(identity_file)
assert_working_dir(working_dir)
ray_config = _get_config_bundle(config)
_assert_identity_file_path(identity_file)
_assert_working_dir(working_dir)
cmd_args_list = [arg for arg in cmd_args]
commands.submit(ray_config, identity_file, working_dir, cmd_args_list)

Expand All @@ -190,21 +208,21 @@ def sql(
identity_file: Optional[Path],
cmd_args: tuple[str],
):
ray_config = get_config_path(config)
assert_identity_file_path(identity_file)
ray_config = _get_config_bundle(config)
_assert_identity_file_path(identity_file)
cmd_args_list = [arg for arg in cmd_args]
commands.sql(ray_config, identity_file, cmd_args_list)


@down_command
@config_option
def down(config: Optional[Path]):
ray_config = get_config_path(config)
ray_config = _get_config_bundle(config)
commands.down(ray_config)


@click.group(help=generate_intro_message())
@click.version_option(version=metadata.version("daft-launcher"))
@click.group(help=_generate_intro_message())
@click.version_option(version=helpers.daft_launcher_version())
def cli(): ...


Expand Down
33 changes: 21 additions & 12 deletions daft_launcher/commands.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""All internal implementations of each command provided by daft launcher.
# Note
All implementations should go here.
The actual bindings to the CLI commands should go in `cli.py`.
"""

import asyncio
from botocore.exceptions import TokenRetrievalError
from typing import List, Optional, Any
from pathlib import Path
import subprocess
from . import helpers, data_definitions
from daft_launcher import helpers, data_definitions
from ray.autoscaler import sdk as ray_sdk
from ray import job_submission
import click
Expand Down Expand Up @@ -36,7 +43,8 @@ def init_config(name: Path):
print(f"Successfully created a new configuration file: {name}")


def up(ray_config: data_definitions.RayConfiguration):
def up(config_bundle: data_definitions.ConfigurationBundle):
_, ray_config = config_bundle
ray_sdk.create_or_update_cluster(
ray_config,
no_restart=False,
Expand All @@ -48,7 +56,7 @@ def up(ray_config: data_definitions.RayConfiguration):


def list():
state_map = helpers.list_helper()
state_map = helpers.get_state_map()
for state_index, (state, instance_infos) in enumerate(state_map.items()):
if state_index != 0:
print()
Expand All @@ -60,27 +68,27 @@ def list():


def connect(
ray_config: data_definitions.RayConfiguration,
config_bundle: data_definitions.ConfigurationBundle,
identity_file: Optional[Path],
):
if not identity_file:
identity_file = helpers.detect_keypair(ray_config)
process = helpers.ssh_helper(ray_config, identity_file, [10001])
identity_file = helpers.detect_keypair(config_bundle)
process = helpers.ssh(config_bundle, identity_file, [10001])
print(ON_CONNECTION_MESSAGE)
process.wait()


def submit(
ray_config: data_definitions.RayConfiguration,
config_bundle: data_definitions.ConfigurationBundle,
identity_file: Optional[Path],
working_dir: Path,
cmd_args: List[str],
):
if not identity_file:
identity_file = helpers.detect_keypair(ray_config)
identity_file = helpers.detect_keypair(config_bundle)
cmd = " ".join(cmd_args)

process = helpers.ssh_helper(ray_config, identity_file)
process = helpers.ssh(config_bundle, identity_file)
try:
working_dir_path = Path(working_dir).absolute()
client = None
Expand Down Expand Up @@ -119,18 +127,19 @@ def submit(


def sql(
ray_config: data_definitions.RayConfiguration,
config_bundle: data_definitions.ConfigurationBundle,
identity_file: Optional[Path],
cmd_args: List[str],
):
submit(
ray_config,
config_bundle,
identity_file,
Path(__file__).parent / "assets",
["python", "sql.py"] + cmd_args,
)


def down(ray_config: data_definitions.RayConfiguration):
def down(config_bundle: data_definitions.ConfigurationBundle):
_, ray_config = config_bundle
ray_sdk.teardown_cluster(ray_config)
print("Successfully spun the cluster down.")
Loading

0 comments on commit 4043d94

Please sign in to comment.