Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default to ubuntu for GCP and avoid key pair checking #1641

Closed
wants to merge 16 commits into from
Closed
191 changes: 112 additions & 79 deletions sky/authentication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module to enable a single SkyPilot key for all VMs in each cloud."""
import copy
import functools
import json
import os
import re
import socket
Expand Down Expand Up @@ -102,7 +103,7 @@ def _replace_cloud_init_ssh_info_in_config(config: Dict[str, Any],
def setup_aws_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
_, public_key_path = get_or_generate_keys()
with open(public_key_path, 'r') as f:
public_key = f.read()
public_key = f.read().strip()
config = _replace_cloud_init_ssh_info_in_config(config, public_key)
return config

Expand Down Expand Up @@ -132,6 +133,105 @@ def _wait_for_compute_global_operation(project_name: str, operation_name: str,
return result


def _maybe_add_ssh_key_to_gcp_project_if_debian(compute, project,
                                                config: Dict[str, Any],
                                                os_login_enabled: bool):
    """Add ssh key to GCP account if using Debian image without cloud-init.

    This function is for backward compatibility. It is only used when the user
    is using the old Debian image without cloud-init. In this case, we need to
    add the ssh key to the GCP account so that we can ssh into the instance.

    Args:
        compute: GCP compute API client; used to update the project-wide
            metadata via `projects().setCommonInstanceMetadata`.
        project: The GCP project resource (a dict). Its `name` and
            `commonInstanceMetadata` fields are read, and
            `commonInstanceMetadata['items']` is updated in place when a new
            key is appended.
        config: The cluster config. `config['auth']['ssh_user']` is used as
            the user name of the ssh key, and the head node's boot disk
            `sourceImage` decides whether the image is Debian-based.
        os_login_enabled: Whether OS Login is enabled for the project. If so,
            the key is registered through `gcloud compute os-login` instead of
            the project metadata.

    Returns:
        `config` (unmodified) on the OS Login path; None in every other case.
        NOTE(review): the return value is inconsistent across paths; the
        caller visible in this module ignores it.
    """
    private_key_path, public_key_path = get_or_generate_keys()
    user = config['auth']['ssh_user']

    node_config = config.get('available_node_types',
                             {}).get('ray_head_default',
                                     {}).get('node_config', {})
    # `or [{}]` also covers a `disks` entry that is present but empty, which
    # would otherwise fail on the `[0]` lookup below.
    disks = node_config.get('disks') or [{}]
    image_id = disks[0].get('initializeParams', {}).get('sourceImage')
    # image_id is None when TPU VM is used, as TPU VM does not use image.
    if image_id is not None and 'debian' not in image_id.lower():
        image_info = clouds.GCP.get_image_info(image_id)
        if 'debian' not in json.dumps(image_info).lower():
            # The non-Debian images have the ssh key setup by cloud-init.
            return
    logger.info('Adding ssh key to GCP account.')
    if os_login_enabled:
        # Add ssh key to GCP with oslogin
        subprocess.run(
            'gcloud compute os-login ssh-keys add '
            f'--key-file={public_key_path}',
            check=True,
            shell=True,
            stdout=subprocess.DEVNULL)
        # Enable ssh port for all the instances
        enable_ssh_cmd = ('gcloud compute firewall-rules create '
                          'allow-ssh-ingress-from-iap '
                          '--direction=INGRESS '
                          '--action=allow '
                          '--rules=tcp:22 '
                          '--source-ranges=0.0.0.0/0')
        proc = subprocess.run(enable_ssh_cmd,
                              check=False,
                              shell=True,
                              stdout=subprocess.DEVNULL,
                              stderr=subprocess.PIPE)
        stderr = proc.stderr.decode('utf-8')
        # An already-existing firewall rule is fine; any other failure is
        # reported through the shared return-code handler.
        if proc.returncode != 0 and 'already exists' not in stderr:
            subprocess_utils.handle_returncode(proc.returncode, enable_ssh_cmd,
                                               'Failed to enable ssh port.',
                                               stderr)
        return config

    # OS Login is not enabled for the project. Add the ssh key directly to the
    # metadata.
    project_keys: str = next(  # type: ignore
        (item for item in project['commonInstanceMetadata'].get('items', [])
         if item['key'] == 'ssh-keys'), {}).get('value', '')
    ssh_keys = project_keys.split('\n') if project_keys else []

    # Get public key from file. Strip it and split on any whitespace so that
    # a trailing newline in the key file never ends up inside the token (and
    # therefore inside the project metadata).
    with open(public_key_path, 'r') as f:
        public_key = f.read().strip()
    public_key_token = public_key.split()[1]

    # Check if ssh key in Google Project's metadata. An entry is expected to
    # have three whitespace-separated fields, with the key token second and
    # the user name last; anything else is ignored.
    key_found = False
    if os.path.exists(private_key_path):
        for key in ssh_keys:
            key_fields = key.split()
            if (len(key_fields) == 3 and key_fields[-1] == user and
                    key_fields[1] == public_key_token):
                key_found = True
                break

    if not key_found:
        new_ssh_key = '{user}:ssh-rsa {public_key_token} {user}'.format(
            user=user, public_key_token=public_key_token)
        metadata = project['commonInstanceMetadata'].get('items', [])

        ssh_key_index = [
            k for k, v in enumerate(metadata) if v['key'] == 'ssh-keys'
        ]
        assert len(ssh_key_index) <= 1

        if len(ssh_key_index) == 0:
            metadata.append({'key': 'ssh-keys', 'value': new_ssh_key})
        else:
            first_ssh_key_index = ssh_key_index[0]
            metadata[first_ssh_key_index]['value'] += '\n' + new_ssh_key

        project['commonInstanceMetadata']['items'] = metadata

        operation = compute.projects().setCommonInstanceMetadata(
            project=project['name'],
            body=project['commonInstanceMetadata']).execute()
        _wait_for_compute_global_operation(project['name'], operation['name'],
                                           compute)


# Snippets of code inspired from
# https://github.com/ray-project/ray/blob/master/python/ray/autoscaler/_private/gcp/config.py
# Takes in config, a yaml dict and outputs a postprocessed dict
Expand All @@ -140,15 +240,16 @@ def _wait_for_compute_global_operation(project_name: str, operation_name: str,
# Retry for GCP, as requests sometimes fail with a "connection reset by peer" error.
@common_utils.retry
def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
private_key_path, public_key_path = get_or_generate_keys()
_, public_key_path = get_or_generate_keys()
with open(public_key_path, 'r') as f:
public_key = f.read()
config = copy.deepcopy(config)

project_id = config['provider']['project_id']
compute = gcp.build('compute',
'v1',
credentials=None,
cache_discovery=False)
user = config['auth']['ssh_user']

try:
project = compute.projects().get(project=project_id).execute()
Expand Down Expand Up @@ -191,7 +292,8 @@ def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
(item for item in project['commonInstanceMetadata'].get('items', [])
if item['key'] == 'enable-oslogin'), {}).get('value', 'False')

if project_oslogin.lower() == 'true':
oslogin_enabled = project_oslogin.lower() == 'true'
if oslogin_enabled:
logger.info(
f'OS Login is enabled for GCP project {project_id}. Running '
'additional authentication steps.')
Expand Down Expand Up @@ -243,81 +345,12 @@ def setup_gcp_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
os_login_username = account.replace('@', '_').replace('.', '_')
config['auth']['ssh_user'] = os_login_username

# Add ssh key to GCP with oslogin
subprocess.run(
'gcloud compute os-login ssh-keys add '
f'--key-file={public_key_path}',
check=True,
shell=True,
stdout=subprocess.DEVNULL)
# Enable ssh port for all the instances
enable_ssh_cmd = ('gcloud compute firewall-rules create '
'allow-ssh-ingress-from-iap '
'--direction=INGRESS '
'--action=allow '
'--rules=tcp:22 '
'--source-ranges=0.0.0.0/0')
proc = subprocess.run(enable_ssh_cmd,
check=False,
shell=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.PIPE)
if proc.returncode != 0 and 'already exists' not in proc.stderr.decode(
'utf-8'):
subprocess_utils.handle_returncode(proc.returncode, enable_ssh_cmd,
'Failed to enable ssh port.',
proc.stderr.decode('utf-8'))
return config

# OS Login is not enabled for the project. Add the ssh key directly to the
# metadata.
# TODO(zhwu): Use cloud init to add ssh public key, to avoid the permission
# issue. A blocker is that the cloud init is not installed in the debian
# image by default.
project_keys: str = next( # type: ignore
(item for item in project['commonInstanceMetadata'].get('items', [])
if item['key'] == 'ssh-keys'), {}).get('value', '')
ssh_keys = project_keys.split('\n') if project_keys else []

# Get public key from file.
with open(public_key_path, 'r') as f:
public_key = f.read()

# Check if ssh key in Google Project's metadata
public_key_token = public_key.split(' ')[1]

key_found = False
for key in ssh_keys:
key_list = key.split(' ')
if len(key_list) != 3:
continue
if user == key_list[-1] and os.path.exists(
private_key_path) and key_list[1] == public_key.split(' ')[1]:
key_found = True

if not key_found:
new_ssh_key = '{user}:ssh-rsa {public_key_token} {user}'.format(
user=user, public_key_token=public_key_token)
metadata = project['commonInstanceMetadata'].get('items', [])

ssh_key_index = [
k for k, v in enumerate(metadata) if v['key'] == 'ssh-keys'
]
assert len(ssh_key_index) <= 1

if len(ssh_key_index) == 0:
metadata.append({'key': 'ssh-keys', 'value': new_ssh_key})
else:
first_ssh_key_index = ssh_key_index[0]
metadata[first_ssh_key_index]['value'] += '\n' + new_ssh_key

project['commonInstanceMetadata']['items'] = metadata

operation = compute.projects().setCommonInstanceMetadata(
project=project['name'],
body=project['commonInstanceMetadata']).execute()
_wait_for_compute_global_operation(project['name'], operation['name'],
compute)
config = _replace_cloud_init_ssh_info_in_config(config, public_key)
# This function is for backward compatibility, as the user using the old
# Debian-based image may not have the cloud-init enabled, and we need to
# add the ssh key to the account.
_maybe_add_ssh_key_to_gcp_project_if_debian(compute, project, config,
oslogin_enabled)
return config


Expand Down
55 changes: 31 additions & 24 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
import time
import typing
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple

from sky import clouds
from sky import exceptions
Expand Down Expand Up @@ -77,6 +77,11 @@

# TODO(zhwu): Move the default AMI size to the catalog instead.
DEFAULT_GCP_IMAGE_GB = 50
_DEFAULT_CPU_IMAGE = 'skypilot:cpu-ubuntu-2004'
# Other GPUs: CUDA driver version 510.47.03, CUDA Library 11.6.
# K80: CUDA driver version 470.103.01, CUDA Library 11.4 (we manually install
# the older CUDA driver in the gcp-ray.yaml to support K80).
_DEFAULT_GPU_IMAGE = 'skypilot:gpu-ubuntu-2004'


def _run_output(cmd):
Expand Down Expand Up @@ -243,26 +248,24 @@ def get_egress_cost(self, num_gigabytes):
def is_same_cloud(self, other):
return isinstance(other, GCP)

def get_image_size(self, image_id: str, region: Optional[str]) -> float:
del region # unused
if image_id.startswith('skypilot:'):
return DEFAULT_GCP_IMAGE_GB
@classmethod
def get_image_info(cls, image_id) -> Dict[str, Any]:
try:
compute = gcp.build('compute',
'v1',
credentials=None,
cache_discovery=False)
except gcp.credential_error_exception() as e:
return DEFAULT_GCP_IMAGE_GB
return {}
try:
image_attrs = image_id.split('/')
if len(image_attrs) == 1:
raise ValueError(f'Image {image_id!r} not found in GCP.')
project = image_attrs[1]
image_name = image_attrs[-1]
image_infos = compute.images().get(project=project,
image=image_name).execute()
return float(image_infos['diskSizeGb'])
image_info = compute.images().get(project=project,
image=image_name).execute()
return image_info
except gcp.http_error_exception() as e:
if e.resp.status == 403:
with ux_utils.print_exception_no_traceback():
Expand All @@ -274,6 +277,21 @@ def get_image_size(self, image_id: str, region: Optional[str]) -> float:
'GCP.') from None
raise

def get_image_size(self, image_id: str, region: Optional[str]) -> float:
del region # unused
if image_id.startswith('skypilot:'):
# Hack: this utilizes the knowledge that both the selected debian
# and ubuntu images on GCP have the same size of 50GB, to reduce
# the overhead for querying the image size.
return DEFAULT_GCP_IMAGE_GB
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: how do we guarantee that the ubuntu & debian tags have the same size, DEFAULT_GCP_IMAGE_GB?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The image size can be obtained with `gcloud compute images describe projects/deeplearning-platform-release/global/images/common-cu113-v20230501-ubuntu-2004-py37`, and both images appear to have the same size of 50GB. Added a comment explaining the hack.

image_info = self.get_image_info(image_id)
if 'diskSizeGb' not in image_info:
# All the images in GCP should have the diskSizeGb field, but
# just in case, we do not want to crash the program, as the image
# size check is not critical.
return DEFAULT_GCP_IMAGE_GB
return float(image_info['diskSizeGb'])

@classmethod
def get_default_instance_type(
cls,
Expand All @@ -295,10 +313,8 @@ def make_deploy_resources_variables(

# gcloud compute images list \
# --project deeplearning-platform-release \
# --no-standard-images
# We use the debian image, as the ubuntu image has some connectivity
# issue when first booted.
image_id = 'skypilot:cpu-debian-10'
# --no-standard-images | grep ubuntu-2004
image_id = _DEFAULT_CPU_IMAGE

r = resources
# Find GPU spec, if any.
Expand Down Expand Up @@ -338,17 +354,8 @@ def make_deploy_resources_variables(
resources_vars['gpu'] = 'nvidia-tesla-{}'.format(
acc.lower())
resources_vars['gpu_count'] = acc_count
if acc == 'K80':
# Though the image is called cu113, it actually has later
# versions of CUDA as noted below.
# CUDA driver version 470.57.02, CUDA Library 11.4
image_id = 'skypilot:k80-debian-10'
else:
# Though the image is called cu113, it actually has later
# versions of CUDA as noted below.
# CUDA driver version 510.47.03, CUDA Library 11.6
# Does not support torch==1.13.0 with cu117
image_id = 'skypilot:gpu-debian-10'

image_id = _DEFAULT_GPU_IMAGE

if resources.image_id is not None:
if None in resources.image_id:
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/service_catalog/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def read_catalog(filename: str,
"""
assert filename.endswith('.csv'), 'The catalog file must be a CSV file.'
assert (pull_frequency_hours is None or
pull_frequency_hours > 0), pull_frequency_hours
pull_frequency_hours >= 0), pull_frequency_hours
catalog_path = get_catalog_path(filename)
cloud = cloud_lib.CLOUD_REGISTRY.from_str(os.path.dirname(filename))

Expand Down
4 changes: 4 additions & 0 deletions sky/clouds/service_catalog/gcp_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
pull_frequency_hours=_PULL_FREQUENCY_HOURS)
_image_df = common.read_catalog('gcp/images.csv',
pull_frequency_hours=_PULL_FREQUENCY_HOURS)
if _image_df[_image_df['Tag'] == 'skypilot:cpu-ubuntu-2004'].empty:
# Update the image catalog if it does not include the updated images
# https://github.com/skypilot-org/skypilot-catalog/pull/25.
_image_df = common.read_catalog('gcp/images.csv', pull_frequency_hours=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure: after increasing our version number, we don't need this, right?


_TPU_REGIONS = [
'us-central1',
Expand Down
Loading