Skip to content

Commit

Permalink
sdk/python: Move bucket providers to an Enum
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Wilson <[email protected]>
  • Loading branch information
aaronnw committed Oct 14, 2024
1 parent a9c7c2e commit 917c4c2
Show file tree
Hide file tree
Showing 33 changed files with 244 additions and 182 deletions.
6 changes: 4 additions & 2 deletions python/aistore/sdk/authn/role_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# pylint: disable=too-many-arguments, duplicate-code

from typing import List

from aistore.sdk.provider import Provider
from aistore.sdk.request_client import RequestClient
from aistore.sdk.authn.access_attr import AccessAttr
from aistore.sdk.authn.cluster_manager import ClusterManager
Expand Down Expand Up @@ -122,7 +124,7 @@ def create(
BucketPermission(
bck=BucketModel(
name=bucket_name,
provider="ais",
provider=Provider.AIS.value,
namespace=Namespace(uuid=cluster_uuid),
),
perm=perm_value,
Expand Down Expand Up @@ -206,7 +208,7 @@ def update(
BucketPermission(
bck=BucketModel(
name=bucket_name,
provider="ais",
provider=Provider.AIS.value,
namespace=Namespace(uuid=cluster_uuid),
),
perm=perm_value,
Expand Down
22 changes: 11 additions & 11 deletions python/aistore/sdk/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
from pathlib import Path
import time
from typing import Dict, List, NewType, Iterable
from typing import Dict, List, NewType, Iterable, Union
import requests
from requests import structures

Expand All @@ -34,7 +34,6 @@
HTTP_METHOD_HEAD,
HTTP_METHOD_POST,
MSGPACK_CONTENT_TYPE,
PROVIDER_AIS,
QPARAM_BCK_TO,
QPARAM_BSUMM_REMOTE,
QPARAM_FLT_PRESENCE,
Expand All @@ -49,6 +48,7 @@
DEFAULT_JOB_POLL_TIME,
)
from aistore.sdk.enums import FLTPresence
from aistore.sdk.provider import Provider
from aistore.sdk.dataset.dataset_config import DatasetConfig

from aistore.sdk.errors import (
Expand Down Expand Up @@ -87,22 +87,22 @@ class Bucket(AISSource):
Args:
client (RequestClient): Client for interfacing with AIS cluster
name (str): name of bucket
provider (str, optional): Provider of bucket (one of "ais", "aws", "gcp", ...), defaults to "ais"
provider (str or Provider, optional): Provider of bucket (one of "ais", "aws", "gcp", ...), defaults to "ais"
namespace (Namespace, optional): Namespace of bucket, defaults to None
"""

def __init__(
self,
name: str,
client: RequestClient = None,
provider: str = PROVIDER_AIS,
provider: Union[Provider, str] = Provider.AIS,
namespace: Namespace = None,
):
self._client = client
self._name = name
self._provider = provider
self._provider = Provider.parse(provider)
self._namespace = namespace
self._qparam = {QPARAM_PROVIDER: provider}
self._qparam = {QPARAM_PROVIDER: self.provider.value}
if self.namespace:
self._qparam[QPARAM_NAMESPACE] = namespace.get_path()

Expand All @@ -122,7 +122,7 @@ def qparam(self) -> Dict:
return self._qparam

@property
def provider(self) -> str:
def provider(self) -> Provider:
"""The provider for this bucket."""
return self._provider

Expand Down Expand Up @@ -898,22 +898,22 @@ def _verify_ais_bucket(self):
"""
Verify the bucket provider is AIS
"""
if self.provider is not PROVIDER_AIS:
if self.provider is not Provider.AIS:
raise InvalidBckProvider(self.provider)

def verify_cloud_bucket(self):
"""
Verify the bucket provider is a cloud provider
"""
if self.provider is PROVIDER_AIS:
if self.provider is Provider.AIS:
raise InvalidBckProvider(self.provider)

def get_path(self) -> str:
"""
Get the path representation of this bucket
"""
namespace_path = self.namespace.get_path() if self.namespace else "@#"
return f"{ self.provider }/{ namespace_path }/{ self.name }/"
return f"{ self.provider.value }/{ namespace_path }/{ self.name }/"

def as_model(self) -> BucketModel:
"""
Expand All @@ -923,7 +923,7 @@ def as_model(self) -> BucketModel:
BucketModel representation
"""
return BucketModel(
name=self.name, namespace=self.namespace, provider=self.provider
name=self.name, namespace=self.namespace, provider=self.provider.value
)

def write_dataset(
Expand Down
14 changes: 8 additions & 6 deletions python/aistore/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from urllib3 import Retry

from aistore.sdk.bucket import Bucket
from aistore.sdk.const import (
PROVIDER_AIS,
AIS_AUTHN_TOKEN,
)
from aistore.sdk.provider import Provider
from aistore.sdk.const import AIS_AUTHN_TOKEN
from aistore.sdk.cluster import Cluster
from aistore.sdk.dsort import Dsort
from aistore.sdk.request_client import RequestClient
Expand Down Expand Up @@ -67,15 +65,19 @@ def __init__(
)

def bucket(
self, bck_name: str, provider: str = PROVIDER_AIS, namespace: Namespace = None
self,
bck_name: str,
provider: Union[Provider, str] = Provider.AIS,
namespace: Namespace = None,
):
"""
Factory constructor for bucket object.
Does not make any HTTP request, only instantiates a bucket object.
Args:
bck_name (str): Name of bucket
provider (str): Provider of bucket, one of "ais", "aws", "gcp", ... (optional, defaults to ais)
provider (str or Provider): Provider of bucket, one of "ais", "aws", "gcp", ...
(optional, defaults to ais)
namespace (Namespace): Namespace of bucket (optional, defaults to None)
Returns:
Expand Down
10 changes: 5 additions & 5 deletions python/aistore/sdk/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from __future__ import annotations # pylint: disable=unused-variable

import logging
from typing import List, Optional
from typing import List, Optional, Union

from aistore.sdk.const import (
HTTP_METHOD_GET,
ACT_LIST,
PROVIDER_AIS,
QPARAM_WHAT,
QPARAM_PRIMARY_READY_REB,
QPARAM_PROVIDER,
Expand All @@ -27,6 +26,7 @@
WHAT_NODE_STATS_AND_STATUS,
WHAT_NODE_STATS_AND_STATUS_V322,
)
from aistore.sdk.provider import Provider

from aistore.sdk.types import (
BucketModel,
Expand Down Expand Up @@ -84,12 +84,12 @@ def get_primary_url(self) -> str:
"""
return self._get_smap().proxy_si.public_net.direct_url

def list_buckets(self, provider: str = PROVIDER_AIS):
def list_buckets(self, provider: Union[str, Provider] = Provider.AIS):
"""
Returns list of buckets in AIStore cluster.
Args:
provider (str, optional): Name of bucket provider, one of "ais", "aws", "gcp", "az" or "ht".
provider (str or Provider, optional): Name of bucket provider, one of "ais", "aws", "gcp", "az" or "ht".
Defaults to "ais". Empty provider returns buckets of all providers.
Returns:
Expand All @@ -101,7 +101,7 @@ def list_buckets(self, provider: str = PROVIDER_AIS):
requests.ConnectionTimeout: Timed out connecting to AIStore
requests.ReadTimeout: Timed out waiting response from AIStore
"""
params = {QPARAM_PROVIDER: provider}
params = {QPARAM_PROVIDER: Provider.parse(provider).value}
action = ActionMsg(action=ACT_LIST).dict()

return self.client.request_deserialize(
Expand Down
9 changes: 0 additions & 9 deletions python/aistore/sdk/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,6 @@
URL_PATH_AUTHN_ROLES = "roles"
URL_PATH_AUTHN_TOKENS = "tokens"

# Bucket providers
# See api/apc/provider.go
PROVIDER_AIS = "ais"
PROVIDER_AMAZON = "aws"
PROVIDER_S3 = "s3"
PROVIDER_AZURE = "azure"
PROVIDER_GOOGLE = "gcp"
PROVIDER_HTTP = "ht"

# HTTP Methods
HTTP_METHOD_GET = "get"
HTTP_METHOD_POST = "post"
Expand Down
3 changes: 2 additions & 1 deletion python/aistore/sdk/dataset/data_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dataclasses import dataclass
from aistore.sdk.client import Client
from aistore.sdk.provider import Provider


@dataclass
Expand All @@ -21,7 +22,7 @@ class DataShard:

client_url: str
bucket_name: str
provider: str = "ais"
provider: str = Provider.AIS
prefix: str = ""
etl_name: str = None

Expand Down
10 changes: 5 additions & 5 deletions python/aistore/sdk/dsort/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ The `DsortFramework` class in the Python SDK enables you to define and manage dS
2. **Creating a DsortFramework Directly:**

```python
from aistore.sdk import Client, BucketModel
from aistore.sdk import Client
from aistore.sdk.multiobj import ObjectNames, ObjectRange
from aistore.sdk.dsort import DsortFramework, DsortShardsGroup, ExternalKeyMap

# Initialize the AIStore client
client = Client("http://your-aistore-url:8080")

# Define the input bucket
input_bucket = BucketModel(name="input-bucket", provider="ais")
input_bucket = client.bucket(bck_name="input-bucket")

# Define the output bucket
output_bucket = BucketModel(name="output-bucket", provider="ais")
output_bucket = client.bucket(bck_name="output-bucket")

# Define the input format as ObjectRange
input_format = ObjectRange(
Expand All @@ -50,15 +50,15 @@ The `DsortFramework` class in the Python SDK enables you to define and manage dS

# Define the input shards group
input_shards_group = DsortShardsGroup(
bck=input_bucket,
bck=input_bucket.as_model(),
role="input",
format=input_format,
extension=".tar",
)

# Define the output shards group
output_shards_group = DsortShardsGroup(
bck=output_bucket,
bck=output_bucket.as_model(),
role="output",
format=output_format,
extension=".tar",
Expand Down
3 changes: 2 additions & 1 deletion python/aistore/sdk/dsort/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from aistore.sdk.dsort.types import JobInfo
from aistore.sdk.bucket import Bucket
from aistore.sdk.errors import Timeout
from aistore.sdk.request_client import RequestClient
from aistore.sdk.utils import validate_file, probing_frequency


Expand All @@ -26,7 +27,7 @@ class Dsort:
Class for managing jobs for the dSort extension: https://github.com/NVIDIA/aistore/blob/main/docs/cli/dsort.md
"""

def __init__(self, client: "Client", dsort_id: str = ""):
def __init__(self, client: RequestClient, dsort_id: str = ""):
self._client = client
self._dsort_id = dsort_id

Expand Down
1 change: 0 additions & 1 deletion python/aistore/sdk/enums.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#
# Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
#

from enum import IntEnum


Expand Down
2 changes: 1 addition & 1 deletion python/aistore/sdk/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class InvalidBckProvider(Exception):
"""

def __init__(self, provider):
super().__init__(f"Invalid bucket provider {provider}")
super().__init__(f"Invalid bucket provider: '{provider}'")


# pylint: disable=unused-variable
Expand Down
2 changes: 1 addition & 1 deletion python/aistore/sdk/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def start(
if buckets and len(buckets) > 0:
bucket_models = [
BucketModel(
name=bck.name, provider=bck.provider, namespace=bck.namespace
name=bck.name, provider=bck.provider.value, namespace=bck.namespace
)
for bck in buckets
]
Expand Down
7 changes: 4 additions & 3 deletions python/aistore/sdk/obj/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
HEADER_OBJECT_BLOB_WORKERS,
HEADER_OBJECT_BLOB_CHUNK_SIZE,
)
from aistore.sdk.provider import Provider
from aistore.sdk.obj.object_client import ObjectClient
from aistore.sdk.obj.object_reader import ObjectReader
from aistore.sdk.request_client import RequestClient
Expand All @@ -54,7 +55,7 @@ class BucketDetails:
"""

name: str
provider: str
provider: Provider
qparams: Dict[str, str]


Expand Down Expand Up @@ -89,7 +90,7 @@ def bucket_name(self) -> str:
return self._bck_details.name

@property
def bucket_provider(self):
def bucket_provider(self) -> Provider:
"""Provider of the bucket where this object resides (e.g. ais, s3, gcp)."""
return self._bck_details.provider

Expand Down Expand Up @@ -213,7 +214,7 @@ def get_semantic_url(self) -> str:
Semantic URL to get object
"""

return f"{self.bucket_provider}://{self.bucket_name}/{self._name}"
return f"{self.bucket_provider.value}://{self.bucket_name}/{self._name}"

def get_url(self, archpath: str = "", etl_name: str = None) -> str:
"""
Expand Down
44 changes: 44 additions & 0 deletions python/aistore/sdk/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from enum import Enum
from typing import Union

from aistore.sdk.errors import InvalidBckProvider

ALIAS_S3 = "s3"
ALIAS_GS = "gs"


class Provider(Enum):
    """
    Represent the providers used for a bucket. See https://aistore.nvidia.com/docs/providers and api/apc/provider.go

    Each member's value is the canonical string name of that provider.
    """

    AIS = "ais"
    AMAZON = "aws"
    AZURE = "azure"
    GOOGLE = "gcp"
    HTTP = "ht"

    @staticmethod
    def parse(provider: Union[Provider, str]) -> Provider:
        """
        Parse a provider Enum instance from a given value.

        Args:
            provider: A Provider, a canonical provider string (e.g. "ais", "aws"),
                or a supported alias ("s3", "gs").

        Returns:
            The given Provider unchanged, or the Provider member matching the given string.

        Raises:
            InvalidBckProvider: If the value is neither a Provider nor a string naming
                a valid provider or alias.
        """
        if isinstance(provider, Provider):
            return provider
        try:
            # Use the provider alias for the given string if one exists
            resolved = provider_aliases.get(provider, provider)
            return Provider(resolved)
        except (ValueError, TypeError) as exc:
            # ValueError: no member has this value.
            # TypeError: the value is unhashable (e.g. a list), so the alias lookup
            # itself failed; report it the same way as any other invalid provider.
            raise InvalidBckProvider(provider) from exc


# Alternate provider names accepted by Provider.parse, mapped to canonical provider values.
# Defined after the class because the mapping is built from Provider members.
provider_aliases = {ALIAS_GS: Provider.GOOGLE.value, ALIAS_S3: Provider.AMAZON.value}
Loading

0 comments on commit 917c4c2

Please sign in to comment.