[feature] add reuse gateway #933

Open · wants to merge 1 commit into base: main
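Summary (reconstructed from the diff, since the PR body is not shown here): the change threads a reuse-gateways flag from the CLI down through Dataplane.provision() to the Provisioner. When the flag is set, gateway VMs are left running after a transfer instead of being deprovisioned, and a later transfer run with the same flag adopts matching running instances rather than provisioning new ones. Assuming typer's default option naming (the options below are declared without explicit names), the flag would surface as --reuse-gateways on both cp and sync, for example: skyplane cp -r --reuse-gateways s3://bucket/prefix gs://bucket/prefix.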
57 changes: 43 additions & 14 deletions skyplane/api/dataplane.py
@@ -28,13 +28,15 @@
 
 
 class DataplaneAutoDeprovision:
-    def __init__(self, dataplane: "Dataplane"):
+    def __init__(self, dataplane: "Dataplane", reuse_instances: bool = False):
         self.dataplane = dataplane
+        self.reuse_instances = reuse_instances
 
     def __enter__(self):
         return self.dataplane
 
     def __exit__(self, exc_type, exc_value, exc_tb):
+        if self.reuse_instances: return
         logger.fs.warning("Deprovisioning dataplane")
         self.dataplane.deprovision()
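A minimal usage sketch of the new context-manager semantics (dp stands for a Dataplane obtained from pipeline.create_dataplane(), and jobs for queued transfer jobs; both are assumed here): with reuse_instances=True, __exit__ returns before calling deprovision(), so the gateways outlive the block whether it exits normally or via an exception.

with dp.auto_deprovision(reuse_instances=True):
    dp.provision(reuse_instances=True, spinner=True)
    dp.run(jobs)
# Gateways are still running here; a later provision(reuse_instances=True)
# can adopt them instead of creating fresh VMs.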
@@ -129,6 +131,7 @@ def _start_gateway(
     def provision(
         self,
         allow_firewall: bool = True,
+        reuse_instances: bool = False,
         gateway_docker_image: str = os.environ.get("SKYPLANE_DOCKER_IMAGE", gateway_docker_image()),
         authorize_ssh_pub_key: Optional[str] = None,
         max_jobs: int = 16,
@@ -160,28 +163,54 @@ def provision(
         is_scp_used = any(n.region_tag.startswith("scp:") for n in self.topology.get_gateways())
 
         # create VMs from the topology
+        uuids = []
+        reuse_nodes_tmp = {}
+        if reuse_instances:
+            for node in self.topology.get_gateways():
+                cloud_provider, region = node.region_tag.split(":")
+                if reuse_nodes_tmp.get(f"{cloud_provider} {region}", []): continue
+                if cloud_provider == "gcp":
+                    reuse_nodes_tmp[f"{cloud_provider} {region}"] = self.provisioner.gcp.get_instance_list(region=region)
+                if cloud_provider == "azure":
+                    reuse_nodes_tmp[f"{cloud_provider} {region}"] = self.provisioner.azure.get_instance_list(region=region)
+                if cloud_provider == "aws":
+                    reuse_nodes_tmp[f"{cloud_provider} {region}"] = self.provisioner.aws.get_instance_list(region=region)
+                if cloud_provider == "ibmcloud":
+                    reuse_nodes_tmp[f"{cloud_provider} {region}"] = self.provisioner.ibmcloud.get_instance_list(region=region)
+                if cloud_provider == "scp":
+                    reuse_nodes_tmp[f"{cloud_provider} {region}"] = self.provisioner.scp.get_instance_list(region=region)
         for node in self.topology.get_gateways():
             cloud_provider, region = node.region_tag.split(":")
-            assert (
-                cloud_provider != "cloudflare"
-            ), f"Cannot create VMs in certain cloud providers: check planner output {self.topology.to_dict()}"
-            self.provisioner.add_task(
-                cloud_provider=cloud_provider,
-                region=region,
-                vm_type=node.vm_type or getattr(self.transfer_config, f"{cloud_provider}_instance_class"),
-                spot=getattr(self.transfer_config, f"{cloud_provider}_use_spot_instances"),
-                autoterminate_minutes=self.transfer_config.autoterminate_minutes,
-            )
+            if reuse_instances:
+                reuse_nodes = reuse_nodes_tmp.get(f"{cloud_provider} {region}", [])
+
+            if reuse_instances and reuse_nodes:
+                reuse_server = reuse_nodes[0]
+                reuse_nodes.remove(reuse_server)
+                uid = self.provisioner.add_node(reuse_server)
+                uuids.append(uid)
+            else:
+                assert (
+                    cloud_provider != "cloudflare"
+                ), f"Cannot create VMs in certain cloud providers: check planner output {self.topology.to_dict()}"
+                self.provisioner.add_task(
+                    cloud_provider=cloud_provider,
+                    region=region,
+                    vm_type=node.vm_type or getattr(self.transfer_config, f"{cloud_provider}_instance_class"),
+                    spot=getattr(self.transfer_config, f"{cloud_provider}_use_spot_instances"),
+                    autoterminate_minutes=self.transfer_config.autoterminate_minutes,
+                )
 
         # initialize clouds
         self.provisioner.init_global(aws=is_aws_used, azure=is_azure_used, gcp=is_gcp_used, ibmcloud=is_ibmcloud_used, scp=is_scp_used)
 
         # provision VMs
-        uuids = self.provisioner.provision(
+        uuids_tmp = self.provisioner.provision(
             authorize_firewall=allow_firewall,
             max_jobs=max_jobs,
             spinner=spinner,
         )
+        uuids.extend(uuids_tmp)
 
         # bind VMs to nodes
         servers = [self.provisioner.get_node(u) for u in uuids]
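The reuse logic above, distilled into a self-contained sketch (the callables are stand-ins for the Provisioner methods named in the comments; this restates the control flow rather than copying the PR's code, and it fetches pools lazily where the diff prefetches them in a separate pass):

def assign_gateways(nodes, reuse_instances, list_running, add_node, add_task, provision):
    # Build a pool of running instances per "provider region" key, then
    # satisfy each gateway node from the pool, falling back to queueing a
    # fresh VM. The cloudflare assert is omitted for brevity.
    pools = {}   # "provider region" -> list of reusable servers
    uuids = []   # ids of both adopted and freshly provisioned VMs
    for node in nodes:
        provider, region = node.region_tag.split(":")
        key = f"{provider} {region}"
        if reuse_instances and key not in pools:
            pools[key] = list_running(provider, region)  # like gcp/aws/azure.get_instance_list
        pool = pools.get(key, [])
        if reuse_instances and pool:
            uuids.append(add_node(pool.pop(0)))          # like Provisioner.add_node
        else:
            add_task(provider, region)                   # like Provisioner.add_task
    uuids.extend(provision())                            # like Provisioner.provision
    return uuids

One subtlety kept from the diff: pools are keyed per provider and region, so two gateways in the same region consume distinct instances from the same pool.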
@@ -291,9 +320,9 @@ def get_error_logs(args):
             errors[instance] = result
         return errors
 
-    def auto_deprovision(self) -> DataplaneAutoDeprovision:
+    def auto_deprovision(self, reuse_instances: bool = False) -> DataplaneAutoDeprovision:
        """Returns a context manager that will automatically call deprovision upon exit."""
-        return DataplaneAutoDeprovision(self)
+        return DataplaneAutoDeprovision(self, reuse_instances)
 
     def source_gateways(self) -> List[compute.Server]:
        """Returns a list of source gateway nodes"""
7 changes: 6 additions & 1 deletion skyplane/api/provisioner.py
@@ -161,7 +161,12 @@ def get_node(self, uuid: str) -> compute.Server:
         :type uuid: str
         """
         return self.provisioned_vms[uuid]
-
+
+    def add_node(self, server) -> str:
+        uid = str(uuid.uuid4())
+        self.provisioned_vms[uid] = server
+        return uid
+
     def _provision_task(self, task: ProvisionerTask):
         with Timer() as t:
             if task.cloud_provider == "aws":
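add_node registers an already-running server under a fresh identifier so that the rest of the pipeline (get_node, the uuids list in Dataplane.provision) treats adopted and freshly provisioned VMs uniformly. Keying the map with str(uuid.uuid4()), rather than the raw .int form, keeps the keys consistent with get_node, whose uuid parameter is documented as str. A small illustration of the intended contract (provisioner and server are assumed to exist):

uid = provisioner.add_node(server)   # adopt an externally discovered VM
assert provisioner.get_node(uid) is server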
15 changes: 11 additions & 4 deletions skyplane/cli/cli_transfer.py
@@ -310,6 +310,7 @@ def run_transfer(
     src: str,
     dst: str,
     recursive: bool,
+    reuse_gateways: bool,
     debug: bool,
     multipart: bool,
     confirm: bool,
@@ -392,15 +393,17 @@
 
     # dataplane must be created after transfers are queued
     dp = pipeline.create_dataplane(debug=debug)
-    with dp.auto_deprovision():
+    with dp.auto_deprovision(reuse_instances=reuse_gateways):
         try:
-            dp.provision(spinner=True)
+            dp.provision(reuse_instances=reuse_gateways, spinner=True)
             dp.run(pipeline.jobs_to_dispatch, hooks=ProgressBarTransferHook(dp.topology.dest_region_tags))
+            if reuse_gateways: return 0
         except KeyboardInterrupt:
             logger.fs.warning("Transfer cancelled by user (KeyboardInterrupt).")
             console.print("\n[red]Transfer cancelled by user. Copying gateway logs and exiting.[/red]")
             try:
                 dp.copy_gateway_logs()
+                if reuse_gateways: return 0
                 force_deprovision(dp)
             except Exception as e:
                 logger.fs.exception(e)
@@ -414,19 +417,22 @@
             console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
             console.print(e.pretty_print_str())
             UsageClient.log_exception("cli_query_objstore", e, args, cli.src_region_tag, [cli.dst_region_tag])
+            if reuse_gateways: return 0
             force_deprovision(dp)
         except Exception as e:
             logger.fs.exception(e)
             console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
             console.print(e)
             UsageClient.log_exception("cli_query_objstore", e, args, cli.src_region_tag, [cli.dst_region_tag])
+            if reuse_gateways: return 0
             force_deprovision(dp)
 
 
 def cp(
     src: str,
     dst: str,
     recursive: bool = typer.Option(False, "--recursive", "-r", help="If true, will copy objects at folder prefix recursively"),
+    reuse_gateways: bool = typer.Option(False, help="If true, will leave provisioned instances running to be reused"),
     debug: bool = typer.Option(False, help="If true, will write debug information to debug directory."),
     multipart: bool = typer.Option(cloud_config.get_flag("multipart_enabled"), help="If true, will use multipart uploads."),
     # transfer flags
@@ -468,12 +474,13 @@ def cp(
     :param solver: The solver to use for the transfer (default: direct)
     :type solver: str
     """
-    return run_transfer(src, dst, recursive, debug, multipart, confirm, max_instances, max_connections, solver, "cp")
+    return run_transfer(src, dst, recursive, reuse_gateways, debug, multipart, confirm, max_instances, max_connections, solver, "cp")
 
 
 def sync(
     src: str,
     dst: str,
+    reuse_gateways: bool = typer.Option(False, help="If true, will leave provisioned instances running to be reused"),
     debug: bool = typer.Option(False, help="If true, will write debug information to debug directory."),
     multipart: bool = typer.Option(cloud_config.get_flag("multipart_enabled"), help="If true, will use multipart uploads."),
     # transfer flags
@@ -517,4 +524,4 @@
     :param solver: The solver to use for the transfer (default: direct)
     :type solver: str
     """
-    return run_transfer(src, dst, False, debug, multipart, confirm, max_instances, max_connections, solver, "sync")
+    return run_transfer(src, dst, False, reuse_gateways, debug, multipart, confirm, max_instances, max_connections, solver, "sync")
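Taken together, the Python-API flow that the CLI path exercises looks roughly like this (the SkyplaneClient and pipeline construction are assumptions based on Skyplane's public API; the dataplane calls are the ones touched by this diff):

import skyplane

client = skyplane.SkyplaneClient()
pipeline = client.pipeline()                          # assumed constructor
pipeline.queue_copy("s3://src-bucket/", "gs://dst-bucket/", recursive=True)

dp = pipeline.create_dataplane(debug=False)
with dp.auto_deprovision(reuse_instances=True):       # skip teardown on exit
    dp.provision(reuse_instances=True, spinner=True)  # adopt running gateways if any
    dp.run(pipeline.jobs_to_dispatch)
# Gateways stay up for the next transfer; rerun without reuse_instances
# (or deprovision explicitly) to tear them down.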