Skip to content

Commit

Permalink
Improve permission (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weisu Yin authored Aug 24, 2023
1 parent 8b866a1 commit 22a0abc
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 44 deletions.
3 changes: 2 additions & 1 deletion docs/tutorials/autogluon-cloud.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ To help you to setup the necessary permissions, you can generate trust relations
```python
from autogluon.cloud import TabularCloudPredictor # Can be other CloudPredictor as well

TabularCloudPredictor.generate_trust_relationship_and_iam_policy_file(
TabularCloudPredictor.generate_default_permission(
backend="BACKNED_YOU_WANT" # We currently support sagemaker and ray_aws
account_id="YOUR_ACCOUNT_ID", # The AWS account ID you plan to use for CloudPredictor.
cloud_output_bucket="S3_BUCKET" # S3 bucket name where intermediate artifacts will be uploaded and trained models should be saved. You need to create this bucket beforehand.
)
Expand Down
8 changes: 7 additions & 1 deletion src/autogluon/cloud/backend/backend_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .backend import Backend
from .multimodal_sagemaker_backend import MultiModalSagemakerBackend
from .ray_aws_backend import TabularRayAWSBackend
from .ray_aws_backend import RayAWSBackend, TabularRayAWSBackend
from .sagemaker_backend import SagemakerBackend
from .tabular_sagemaker_backend import TabularSagemakerBackend
from .timeseries_sagemaker_backend import TimeSeriesSagemakerBackend
Expand All @@ -12,10 +12,16 @@ class BackendFactory:
TabularSagemakerBackend,
MultiModalSagemakerBackend,
TimeSeriesSagemakerBackend,
RayAWSBackend,
TabularRayAWSBackend,
]
__name_to_backend = {cls.name: cls for cls in __supported_backend}

@staticmethod
def get_backend_cls(backend: str) -> Backend.__class__:
assert backend in BackendFactory.__name_to_backend, f"{backend} not supported"
return BackendFactory.__name_to_backend[backend]

@staticmethod
def get_backend(backend: str, **init_args) -> Backend:
"""Return the corresponding backend"""
Expand Down
3 changes: 2 additions & 1 deletion src/autogluon/cloud/backend/ray_aws_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def initialize(self, **kwargs) -> None:
self.region is not None
), "Please setup a region via `export AWS_DEFAULT_REGION=YOUR_REGION` in the terminal"

def generate_default_permission(self, **kwargs) -> Dict[str, str]:
@staticmethod
def generate_default_permission(**kwargs) -> Dict[str, str]:
"""Generate default permission file user could use to setup the corresponding entity, i.e. IAM Role in AWS"""
return RayAWSClusterManager.generate_default_permission(**kwargs)

Expand Down
3 changes: 2 additions & 1 deletion src/autogluon/cloud/backend/ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def initialize(self, **kwargs) -> None:
self._fit_job = RayFitJob()
os.makedirs(os.path.join(self.local_output_path, "job"), exist_ok=True)

def generate_default_permission(self, **kwargs) -> Dict[str, str]:
@staticmethod
def generate_default_permission(**kwargs) -> Dict[str, str]:
"""Generate default permission file user could use to setup the corresponding entity, i.e. IAM Role in AWS"""
return RayClusterManager.generate_default_permission(**kwargs)

Expand Down
3 changes: 2 additions & 1 deletion src/autogluon/cloud/backend/sagemaker_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ def initialize(self, **kwargs) -> None:
self._fit_job: SageMakerFitJob = SageMakerFitJob(session=self.sagemaker_session)
self._batch_transform_jobs = MostRecentInsertedOrderedDict()

@staticmethod
def generate_default_permission(
self, account_id: str, cloud_output_bucket: str, output_path: Optional[str] = None
account_id: str, cloud_output_bucket: str, output_path: Optional[str] = None
) -> Dict[str, str]:
"""
Generate required trust relationship and IAM policy file in json format for CloudPredictor with SageMaker backend.
Expand Down
6 changes: 4 additions & 2 deletions src/autogluon/cloud/cluster/ray_aws_cluster_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import os
from typing import Dict
from typing import Dict, Optional

from ..utils.iam import replace_iam_policy_place_holder, replace_trust_relationship_place_holder
from ..utils.ray_aws_iam import (
Expand All @@ -21,7 +21,9 @@ def __init__(self, config: str, cloud_output_bucket: str, **kwargs) -> None:
self.cloud_output_bucket = cloud_output_bucket

@staticmethod
def generate_default_permission(account_id: str, cloud_output_bucket: str, output_path: str) -> Dict[str, str]:
def generate_default_permission(
account_id: str, cloud_output_bucket: str, output_path: Optional[str] = None
) -> Dict[str, str]:
"""
Generate trust relationship and iam policy required to manage cluster
Users can use the generated files to create an IAM role for themselves.
Expand Down
14 changes: 4 additions & 10 deletions src/autogluon/cloud/predictor/cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

class CloudPredictor(ABC):
predictor_file_name = "CloudPredictor.pkl"
backend_map = {}

def __init__(
self,
Expand Down Expand Up @@ -83,14 +84,6 @@ def predictor_type(self) -> str:
"""
raise NotImplementedError

@property
@abstractmethod
def backend_map(self) -> Dict:
"""
Map between general backend to module specific backend
"""
raise NotImplementedError

@property
def is_fit(self) -> bool:
"""
Expand All @@ -107,7 +100,8 @@ def endpoint_name(self) -> Optional[str]:
return self.backend.endpoint.endpoint_name
return None

def generate_default_permission(self, **kwargs) -> Dict[str, str]:
@staticmethod
def generate_default_permission(backend: str = SAGEMAKER, **kwargs) -> Dict[str, str]:
"""
Generate required permission file in json format for CloudPredictor with your choice of backend.
Users can use the generated files to create an entity for themselves.
Expand All @@ -122,7 +116,7 @@ def generate_default_permission(self, **kwargs) -> Dict[str, str]:
------
A dict containing the trust relationship and IAM policy files paths
"""
return self.backend.generate_default_permission(**kwargs)
return BackendFactory.get_backend_cls(backend=backend).generate_default_permission(**kwargs)

def info(self) -> Dict[str, Any]:
"""
Expand Down
9 changes: 1 addition & 8 deletions src/autogluon/cloud/predictor/multimodal_cloud_predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from typing import Dict

from ..backend.constant import MULTIMODL_SAGEMAKER, SAGEMAKER
from .cloud_predictor import CloudPredictor
Expand All @@ -9,6 +8,7 @@

class MultiModalCloudPredictor(CloudPredictor):
predictor_file_name = "MultiModalCloudPredictor.pkl"
backend_map = {SAGEMAKER: MULTIMODL_SAGEMAKER}

@property
def predictor_type(self) -> str:
Expand All @@ -17,13 +17,6 @@ def predictor_type(self) -> str:
"""
return "multimodal"

@property
def backend_map(self) -> Dict:
"""
Map between general backend to module specific backend
"""
return {SAGEMAKER: MULTIMODL_SAGEMAKER}

def _get_local_predictor_cls(self):
from autogluon.multimodal import MultiModalPredictor

Expand Down
9 changes: 1 addition & 8 deletions src/autogluon/cloud/predictor/tabular_cloud_predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from typing import Dict

from ..backend.constant import RAY_AWS, SAGEMAKER, TABULAR_RAY_AWS, TABULAR_SAGEMAKER
from .cloud_predictor import CloudPredictor
Expand All @@ -9,6 +8,7 @@

class TabularCloudPredictor(CloudPredictor):
predictor_file_name = "TabularCloudPredictor.pkl"
backend_map = {SAGEMAKER: TABULAR_SAGEMAKER, RAY_AWS: TABULAR_RAY_AWS}

@property
def predictor_type(self):
Expand All @@ -17,13 +17,6 @@ def predictor_type(self):
"""
return "tabular"

@property
def backend_map(self) -> Dict:
"""
Map between general backend to module specific backend
"""
return {SAGEMAKER: TABULAR_SAGEMAKER, RAY_AWS: TABULAR_RAY_AWS}

def _get_local_predictor_cls(self):
from autogluon.tabular import TabularPredictor

Expand Down
8 changes: 1 addition & 7 deletions src/autogluon/cloud/predictor/timeseries_cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class TimeSeriesCloudPredictor(CloudPredictor):
predictor_file_name = "TimeSeriesCloudPredictor.pkl"
backend_map = {SAGEMAKER: TIMESERIES_SAGEMAKER}

@property
def predictor_type(self):
Expand All @@ -21,13 +22,6 @@ def predictor_type(self):
"""
return "timeseries"

@property
def backend_map(self) -> Dict:
"""
Map between general backend to module specific backend
"""
return {SAGEMAKER: TIMESERIES_SAGEMAKER}

def _get_local_predictor_cls(self):
from autogluon.timeseries import TimeSeriesPredictor

Expand Down
13 changes: 9 additions & 4 deletions tests/unittests/general/test_general.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import json
import tempfile

from autogluon.cloud.backend import SagemakerBackend
import pytest

from autogluon.cloud import TabularCloudPredictor
from autogluon.cloud.backend.constant import RAY_AWS, SAGEMAKER

def test_generate_trust_relationship_and_iam_policy():

@pytest.mark.parametrize("backend", [RAY_AWS, SAGEMAKER])
def test_generate_default_permission(backend):
with tempfile.TemporaryDirectory() as root:
backend = SagemakerBackend(local_output_path="dummy", cloud_output_path="dummy", predictor_type="dummy")
paths = backend.generate_default_permission(account_id="foo", cloud_output_bucket="foo", output_path=root)
paths = TabularCloudPredictor.generate_default_permission(
backend=backend, account_id="foo", cloud_output_bucket="foo", output_path=root
)
trust_relationship_path, iam_policy_path = paths["trust_relationship"], paths["iam_policy"]
for path in [trust_relationship_path, iam_policy_path]:
with open(path, "r") as file:
Expand Down

0 comments on commit 22a0abc

Please sign in to comment.