From 343751cde7fbfef377efb72f143a0ef9550a8c2b Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Thu, 18 Jul 2024 10:13:32 -0400 Subject: [PATCH] Fix tests, add validation of dp value --- pinecone/control/pinecone.py | 12 ++++++++---- pinecone/models/index_list.py | 3 ++- .../serverless/test_deletion_protection.py | 16 ++++++++++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/pinecone/control/pinecone.py b/pinecone/control/pinecone.py index b4dbdf15..b04e0243 100644 --- a/pinecone/control/pinecone.py +++ b/pinecone/control/pinecone.py @@ -327,7 +327,10 @@ def create_index( def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]: return {arg_name: val for arg_name, val in args if val is not None} - dp = DeletionProtection(deletion_protection) + if deletion_protection in ["enabled", "disabled"]: + dp = DeletionProtection(deletion_protection) + else: + raise ValueError("deletion_protection must be either 'enabled' or 'disabled'") if isinstance(spec, dict): if "serverless" in spec: @@ -573,12 +576,13 @@ def configure_index( """ api_instance = self.index_api - description = self.describe_index(name=name) - if deletion_protection is None: + description = self.describe_index(name=name) dp = DeletionProtection(description.deletion_protection) - else: + elif deletion_protection in ["enabled", "disabled"]: dp = DeletionProtection(deletion_protection) + else: + raise ValueError("deletion_protection must be either 'enabled' or 'disabled'") pod_config_args: Dict[str, Any] = {} if pod_type: diff --git a/pinecone/models/index_list.py b/pinecone/models/index_list.py index d9823e21..093d0a92 100644 --- a/pinecone/models/index_list.py +++ b/pinecone/models/index_list.py @@ -1,4 +1,5 @@ from pinecone.core.openapi.control.models import IndexList as OpenAPIIndexList +from .index_model import IndexModel class IndexList: @@ -10,7 +11,7 @@ def names(self): return [i["name"] for i in self.index_list.indexes] def __getitem__(self, key): - return self.index_list.indexes[key] + return IndexModel(self.index_list.indexes[key]) def __len__(self): return len(self.index_list.indexes) diff --git a/tests/integration/control/serverless/test_deletion_protection.py b/tests/integration/control/serverless/test_deletion_protection.py index eaddb75d..0a656e8e 100644 --- a/tests/integration/control/serverless/test_deletion_protection.py +++ b/tests/integration/control/serverless/test_deletion_protection.py @@ -4,7 +4,7 @@ class TestDeletionProtection: def test_deletion_protection(self, client, create_sl_index_params): name = create_sl_index_params["name"] - client.create_index(**create_sl_index_params, deletion_protection=True) + client.create_index(**create_sl_index_params, deletion_protection="enabled") desc = client.describe_index(name) assert desc.deletion_protection == "enabled" @@ -12,8 +12,20 @@ def test_deletion_protection(self, client, create_sl_index_params): client.delete_index(name) assert "Deletion protection is enabled for this index" in str(e.value) - client.configure_index(name, deletion_protection=False) + client.configure_index(name, deletion_protection="disabled") desc = client.describe_index(name) assert desc.deletion_protection == "disabled" client.delete_index(name) + + @pytest.mark.parametrize("deletion_protection", ["invalid", None]) + def test_deletion_protection_invalid_options(self, client, create_sl_index_params, deletion_protection): + with pytest.raises(Exception) as e: + client.create_index(**create_sl_index_params, deletion_protection=deletion_protection) + assert "deletion_protection must be either 'enabled' or 'disabled'" in str(e.value) + + @pytest.mark.parametrize("deletion_protection", ["invalid"]) + def test_configure_deletion_protection_invalid_options(self, client, create_sl_index_params, deletion_protection): + with pytest.raises(Exception) as e: + client.create_index(**create_sl_index_params, deletion_protection=deletion_protection) + assert "deletion_protection must be either 'enabled' or 'disabled'" in str(e.value)