Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update assets cli after drivers breaking changes #216

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions modelkit/assets/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@
from rich.table import Table
from rich.tree import Tree

from modelkit.assets.drivers.abc import StorageDriverSettings

try:
from modelkit.assets.drivers.gcs import GCSStorageDriver
from modelkit.assets.drivers.gcs import GCSStorageDriver, GCSStorageDriverSettings

has_gcs = True
except ModuleNotFoundError:
has_gcs = False
try:
from modelkit.assets.drivers.s3 import S3StorageDriver
from modelkit.assets.drivers.s3 import S3StorageDriver, S3StorageDriverSettings

has_s3 = True
except ModuleNotFoundError:
Expand Down Expand Up @@ -132,20 +130,23 @@ def new_(asset_path, asset_spec, storage_prefix, dry_run):
with tempfile.TemporaryDirectory() as tmp_dir:
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
driver_settings = StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
if parsed_path["storage_prefix"] == "gs":
if not has_gcs:
raise DriverNotInstalledError(
"GCS driver not installed, install modelkit[assets-gcs]"
)
driver_settings = GCSStorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = GCSStorageDriver(driver_settings)
elif parsed_path["storage_prefix"] == "s3":
if not has_s3:
raise DriverNotInstalledError(
"S3 driver not installed, install modelkit[assets-s3]"
)
driver_settings = S3StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = S3StorageDriver(driver_settings)
else:
raise ValueError(
Expand Down Expand Up @@ -234,20 +235,23 @@ def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
with tempfile.TemporaryDirectory() as tmp_dir:
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
driver_settings = StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
if parsed_path["storage_prefix"] == "gs":
if not has_gcs:
raise DriverNotInstalledError(
"GCS driver not installed, install modelkit[assets-gcs]"
)
driver_settings = GCSStorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = GCSStorageDriver(driver_settings)
elif parsed_path["storage_prefix"] == "s3":
if not has_s3:
raise DriverNotInstalledError(
"S3 driver not installed, install modelkit[assets-s3]"
)
driver_settings = S3StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = S3StorageDriver(driver_settings)
else:
raise ValueError(
Expand Down
Loading