From 99e6ba586106b4867585f9d7a15d94d046f456d7 Mon Sep 17 00:00:00 2001 From: Thomas Genin <6623268+tgenin@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:26:11 +0100 Subject: [PATCH] Update assets cli after drivers breaking changes --- modelkit/assets/cli.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/modelkit/assets/cli.py b/modelkit/assets/cli.py index 3e03d61b..8e29aeba 100755 --- a/modelkit/assets/cli.py +++ b/modelkit/assets/cli.py @@ -10,16 +10,14 @@ from rich.table import Table from rich.tree import Tree -from modelkit.assets.drivers.abc import StorageDriverSettings - try: - from modelkit.assets.drivers.gcs import GCSStorageDriver + from modelkit.assets.drivers.gcs import GCSStorageDriver, GCSStorageDriverSettings has_gcs = True except ModuleNotFoundError: has_gcs = False try: - from modelkit.assets.drivers.s3 import S3StorageDriver + from modelkit.assets.drivers.s3 import S3StorageDriver, S3StorageDriverSettings has_s3 = True except ModuleNotFoundError: @@ -132,20 +130,23 @@ def new_(asset_path, asset_spec, storage_prefix, dry_run): with tempfile.TemporaryDirectory() as tmp_dir: if not os.path.exists(asset_path): parsed_path = parse_remote_url(asset_path) - driver_settings = StorageDriverSettings( - bucket=parsed_path["bucket_name"] - ) if parsed_path["storage_prefix"] == "gs": if not has_gcs: raise DriverNotInstalledError( "GCS driver not installed, install modelkit[assets-gcs]" ) + driver_settings = GCSStorageDriverSettings( + bucket=parsed_path["bucket_name"] + ) driver = GCSStorageDriver(driver_settings) elif parsed_path["storage_prefix"] == "s3": if not has_s3: raise DriverNotInstalledError( "S3 driver not installed, install modelkit[assets-s3]" ) + driver_settings = S3StorageDriverSettings( + bucket=parsed_path["bucket_name"] + ) driver = S3StorageDriver(driver_settings) else: raise ValueError( @@ -234,20 +235,23 @@ def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run): with tempfile.TemporaryDirectory() as tmp_dir: if not os.path.exists(asset_path): parsed_path = parse_remote_url(asset_path) - driver_settings = StorageDriverSettings( - bucket=parsed_path["bucket_name"] - ) if parsed_path["storage_prefix"] == "gs": if not has_gcs: raise DriverNotInstalledError( "GCS driver not installed, install modelkit[assets-gcs]" ) + driver_settings = GCSStorageDriverSettings( + bucket=parsed_path["bucket_name"] + ) driver = GCSStorageDriver(driver_settings) elif parsed_path["storage_prefix"] == "s3": if not has_s3: raise DriverNotInstalledError( "S3 driver not installed, install modelkit[assets-s3]" ) + driver_settings = S3StorageDriverSettings( + bucket=parsed_path["bucket_name"] + ) driver = S3StorageDriver(driver_settings) else: raise ValueError(