[Core] Add Azure ML Compute Instance Support #3905
Open
cblmemo wants to merge 9 commits into master from az-ml
Commits (9)
36ca924  wip (cblmemo)
e55f14b  wip (cblmemo)
8c85b98  minor (cblmemo)
ebeeccb  upd (cblmemo)
2836bae  fix resources version (cblmemo)
097e6ca  support docker image (cblmemo)
e1ee720  upd not supported features (cblmemo)
4426b41  add az ml catalog (cblmemo)
ae85f87  Merge remote-tracking branch 'origin/master' into az-ml (cblmemo)
```diff
@@ -36,6 +36,7 @@
 from sky import resources as resources_lib
 from sky import serve as serve_lib
 from sky import sky_logging
+from sky import skypilot_config
 from sky import status_lib
 from sky import task as task_lib
 from sky.backends import backend_utils
```
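The newly imported `skypilot_config` is used below to read `azure.use_az_ml` via `skypilot_config.get_nested(('azure', 'use_az_ml'), False)`, i.e. a nested key lookup with a default. A minimal sketch of that lookup pattern (this is an illustration, not SkyPilot's actual implementation):

```python
from typing import Any, Tuple


def get_nested(config: dict, keys: Tuple[str, ...], default: Any) -> Any:
    """Walk nested dicts along a key path; return `default` if any level
    is missing. Sketch of the get_nested pattern used in the diff."""
    cur: Any = config
    for key in keys:
        if not isinstance(cur, dict) or key not in cur:
            return default
        cur = cur[key]
    return cur


# Hypothetical user config mirroring ~/.sky/config.yaml contents.
config = {'azure': {'use_az_ml': True}}
print(get_nested(config, ('azure', 'use_az_ml'), False))  # → True
print(get_nested(config, ('azure', 'missing_key'), False))  # → False
```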
```diff
@@ -2616,6 +2617,21 @@ def check_resources_fit_cluster(
         # was handled by ResourceHandle._update_cluster_region.
         assert launched_resources.region is not None, handle

+        # Check whether Azure cluster uses Azure ML API
+        if launched_resources.cloud.is_same_cloud(clouds.Azure()):
+            task_use_az_ml = skypilot_config.get_nested(('azure', 'use_az_ml'),
+                                                        False)
+            cluster_use_az_ml = launched_resources.use_az_ml
+            if cluster_use_az_ml != task_use_az_ml:
+                task_str = 'uses' if task_use_az_ml else 'does not use'
+                cluster_str = 'uses' if cluster_use_az_ml else 'does not use'
+                with ux_utils.print_exception_no_traceback():
+                    raise exceptions.ResourcesMismatchError(
+                        f'Task requirements {task_str} Azure ML API, but the '
+                        f'specified cluster {cluster_name} {cluster_str} it. '
+                        f'Please set azure.use_az_ml to {cluster_use_az_ml} in '
+                        '~/.sky/config.yaml.')
+
         mismatch_str = (f'To fix: specify a new cluster name, or down the '
                         f'existing cluster first: sky down {cluster_name}')
         valid_resource = None
```
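The check reads `azure.use_az_ml` from the user config. A minimal `~/.sky/config.yaml` enabling the Azure ML path would look like this (the field name is taken from the diff; the surrounding schema is assumed):

```yaml
azure:
  use_az_ml: true
```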
```diff
@@ -3510,6 +3526,20 @@ def _teardown(self,
             else:
                 raise

+        if handle.launched_resources.cloud.is_same_cloud(clouds.Azure()):
+            task_use_az_ml = skypilot_config.get_nested(('azure', 'use_az_ml'),
+                                                        False)
+            cluster_use_az_ml = handle.launched_resources.use_az_ml
+            if cluster_use_az_ml != task_use_az_ml:
+                task_str = 'uses' if task_use_az_ml else 'does not use'
+                cluster_str = 'uses' if cluster_use_az_ml else 'does not use'
+                with ux_utils.print_exception_no_traceback():
+                    raise exceptions.ResourcesMismatchError(
+                        f'Current setup {task_str} Azure ML API, but the '
+                        f'specified cluster {cluster_name} to terminate '
+                        f'{cluster_str} it. Please set azure.use_az_ml '
+                        f'to {cluster_use_az_ml} in ~/.sky/config.yaml.')
+
         lock_path = os.path.expanduser(
             backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name))
```

Review comment on lines +3529 to +3542: Same as above, we should just allow the teardown using the non-ML API if the existing cluster was launched using the non-ML API.
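Both hunks apply the same mismatch test before raising. As a standalone sketch (a plain `ValueError` stands in for `exceptions.ResourcesMismatchError`, which the real code raises under `ux_utils.print_exception_no_traceback()`):

```python
def check_az_ml_mismatch(task_use_az_ml: bool, cluster_use_az_ml: bool,
                         cluster_name: str) -> None:
    """Raise if the task config and the existing cluster disagree on
    whether the Azure ML API is used. Mirrors the logic in the diff."""
    if cluster_use_az_ml == task_use_az_ml:
        return  # Settings agree: nothing to do.
    task_str = 'uses' if task_use_az_ml else 'does not use'
    cluster_str = 'uses' if cluster_use_az_ml else 'does not use'
    # Stand-in for exceptions.ResourcesMismatchError.
    raise ValueError(
        f'Task requirements {task_str} Azure ML API, but the specified '
        f'cluster {cluster_name} {cluster_str} it. Please set '
        f'azure.use_az_ml to {cluster_use_az_ml} in ~/.sky/config.yaml.')
```

For example, `check_az_ml_mismatch(True, False, 'my-cluster')` raises with a message telling the user to set `azure.use_az_ml` to `False`, while matching settings pass silently.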
sky/clouds/service_catalog/data_fetchers/fetch_azure_ml.py (120 additions, 0 deletions)

```python
"""A script to fetch Azure ML pricing data.

Requires running fetch_azure.py first to get the pricing data.
"""
from multiprocessing import pool as mp_pool
import os
import typing
from typing import Dict, Set

from azure.ai import ml
from azure.ai.ml import entities

from sky.adaptors import azure
from sky.adaptors import common as adaptors_common
from sky.clouds.service_catalog import common
from sky.clouds.service_catalog.data_fetchers import fetch_azure

if typing.TYPE_CHECKING:
    import pandas as pd
else:
    pd = adaptors_common.LazyImport('pandas')

SUBSCRIPTION_ID = azure.get_subscription_id()

SINGLE_THREADED = False

az_df = common.read_catalog('azure/vms.csv')


def init_ml_client(region: str) -> ml.MLClient:
    resource_client = azure.get_client('resource', SUBSCRIPTION_ID)
    resource_group_name = f'az-ml-fetcher-{region}'
    workspace_name = f'az-ml-fetcher-{region}-ws'
    resource_client.resource_groups.create_or_update(resource_group_name,
                                                     {'location': region})
    ml_client: ml.MLClient = azure.get_client(
        'ml',
        SUBSCRIPTION_ID,
        resource_group=resource_group_name,
        workspace_name=workspace_name)
    try:
        ml_client.workspaces.get(workspace_name)
    except azure.exceptions().ResourceNotFoundError:
        print(f'Creating workspace {workspace_name} in {region}')
        ws = ml_client.workspaces.begin_create(
            entities.Workspace(name=workspace_name, location=region)).result()
        print(f'Created workspace {ws.name} in {ws.location}.')
    return ml_client


def get_supported_instance_type(region: str) -> Dict[str, bool]:
    ml_client = init_ml_client(region)
    supported_instance_types = {}
    for sz in ml_client.compute.list_sizes():
        if sz.supported_compute_types is None:
            continue
        if 'ComputeInstance' not in sz.supported_compute_types:
            continue
        supported_instance_types[sz.name] = sz.low_priority_capable
    return supported_instance_types


def get_instance_type_df(region: str) -> 'pd.DataFrame':
    supported_instance_type = get_supported_instance_type(region)
    df_filtered = az_df[az_df['Region'] == region].copy()
    df_filtered = df_filtered[df_filtered['InstanceType'].isin(
        supported_instance_type.keys())]

    def _get_spot_price(row):
        ins_type = row['InstanceType']
        assert ins_type in supported_instance_type, (
            f'Instance type {ins_type} not in supported_instance_type')
        if supported_instance_type[ins_type]:
            return row['SpotPrice']
        return None

    df_filtered['SpotPrice'] = df_filtered.apply(_get_spot_price, axis=1)

    supported_set = set(supported_instance_type.keys())
    df_set = set(az_df[az_df['Region'] == region]['InstanceType'])
    missing_instance_types = supported_set - df_set
    missing_str = ', '.join(missing_instance_types)
    if missing_instance_types:
        print(f'Missing instance types for {region}: {missing_str}')
    else:
        print(f'All supported instance types for {region} are in the catalog.')

    return df_filtered


def get_all_regions_instance_types_df(region_set: Set[str]) -> 'pd.DataFrame':
    if SINGLE_THREADED:
        dfs = [get_instance_type_df(region) for region in region_set]
    else:
        with mp_pool.Pool() as pool:
            dfs_result = pool.map_async(get_instance_type_df, region_set)
            dfs = dfs_result.get()
    df = pd.concat(dfs, ignore_index=True)
    df = df.sort_values(by='InstanceType').reset_index(drop=True)
    return df


if __name__ == '__main__':
    az_ml_parser = fetch_azure.get_arg_parser('Fetch Azure ML pricing data.')
    # TODO(tian): Support cleanup after fetching the data.
    az_ml_parser.add_argument(
        '--cleanup',
        action='store_true',
        help='Cleanup the resource group and workspace after '
        'fetching the data.')
    args = az_ml_parser.parse_args()

    SINGLE_THREADED = args.single_threaded

    instance_df = get_all_regions_instance_types_df(
        fetch_azure.get_region_filter(args.all_regions, args.regions,
                                      args.exclude))
    os.makedirs('azure', exist_ok=True)
    instance_df.to_csv('azure/az_ml_vms.csv', index=False)
    print('Azure ML Service Catalog saved to azure/az_ml_vms.csv')
```
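In `get_instance_type_df`, the `_get_spot_price` helper keeps a spot price only for sizes whose Azure ML compute size reports `low_priority_capable`. The same masking in plain Python, without pandas (toy rows; the instance names and prices here are hypothetical):

```python
# Toy stand-in for the pandas apply() in _get_spot_price: blank the
# SpotPrice of instance types that are not low-priority (spot) capable.
supported = {'Standard_D2_v3': True, 'Standard_NC6': False}  # hypothetical

rows = [
    {'InstanceType': 'Standard_D2_v3', 'SpotPrice': 0.02},
    {'InstanceType': 'Standard_NC6', 'SpotPrice': 0.18},
]

for row in rows:
    if not supported[row['InstanceType']]:
        row['SpotPrice'] = None  # Not spot-capable: drop its spot price.

print(rows)
```

After the loop, only `Standard_D2_v3` retains a spot price; `Standard_NC6` ends up with `SpotPrice` set to `None`, matching what the catalog row would show.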
Review comment: I think we should use the existing cluster's setting even if the config has changed to use the ML API. This is consistent with other parts of the code; e.g., if a user specifies GCP DWS and has an existing cluster not using DWS, we should still allow tasks to be submitted to that cluster.