Skip to content

Commit

Permalink
Update assets cli after drivers breaking changes
Browse files Browse the repository at this point in the history
  • Loading branch information
tgenin committed Mar 4, 2024
1 parent 7e2457c commit 58426fe
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions modelkit/assets/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from modelkit.assets.drivers.abc import StorageDriverSettings

try:
from modelkit.assets.drivers.gcs import GCSStorageDriver
from modelkit.assets.drivers.gcs import GCSStorageDriver, GCSStorageDriverSettings

has_gcs = True
except ModuleNotFoundError:
has_gcs = False
try:
from modelkit.assets.drivers.s3 import S3StorageDriver
from modelkit.assets.drivers.s3 import S3StorageDriver, S3StorageDriverSettings

has_s3 = True
except ModuleNotFoundError:
Expand Down Expand Up @@ -132,20 +132,23 @@ def new_(asset_path, asset_spec, storage_prefix, dry_run):
with tempfile.TemporaryDirectory() as tmp_dir:
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
driver_settings = StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
if parsed_path["storage_prefix"] == "gs":
if not has_gcs:
raise DriverNotInstalledError(
"GCS driver not installed, install modelkit[assets-gcs]"
)
driver_settings = GCSStorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = GCSStorageDriver(driver_settings)
elif parsed_path["storage_prefix"] == "s3":
if not has_s3:
raise DriverNotInstalledError(
"S3 driver not installed, install modelkit[assets-s3]"
)
driver_settings = S3StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = S3StorageDriver(driver_settings)
else:
raise ValueError(
Expand Down Expand Up @@ -234,20 +237,23 @@ def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
with tempfile.TemporaryDirectory() as tmp_dir:
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
driver_settings = StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
if parsed_path["storage_prefix"] == "gs":
if not has_gcs:
raise DriverNotInstalledError(
"GCS driver not installed, install modelkit[assets-gcs]"
)
driver_settings = GCSStorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = GCSStorageDriver(driver_settings)
elif parsed_path["storage_prefix"] == "s3":
if not has_s3:
raise DriverNotInstalledError(
"S3 driver not installed, install modelkit[assets-s3]"
)
driver_settings = S3StorageDriverSettings(
bucket=parsed_path["bucket_name"]
)
driver = S3StorageDriver(driver_settings)
else:
raise ValueError(
Expand Down

0 comments on commit 58426fe

Please sign in to comment.