diff --git a/tests/providers/amazon/aws/system/utils/test_helpers.py b/tests/providers/amazon/aws/system/utils/test_helpers.py index f48de1788b74c..3af3720688a09 100644 --- a/tests/providers/amazon/aws/system/utils/test_helpers.py +++ b/tests/providers/amazon/aws/system/utils/test_helpers.py @@ -24,7 +24,7 @@ import os import sys from io import StringIO -from unittest.mock import ANY, patch +from unittest.mock import patch import pytest from moto import mock_aws @@ -79,8 +79,15 @@ def test_fetch_variable_success( ) -> None: mock_getenv.return_value = env_value or ssm_value - result = utils.fetch_variable(ANY, default_value) if default_value else utils.fetch_variable(ANY_STR) + utils._fetch_from_ssm.cache_clear() + result = ( + utils.fetch_variable("some_key", default_value) + if default_value + else utils.fetch_variable(ANY_STR) + ) + + utils._fetch_from_ssm.cache_clear() assert result == expected_result def test_fetch_variable_no_value_found_raises_exception(self): diff --git a/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py b/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py index fcebc8c40a0d4..2b7bce2fecde8 100644 --- a/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py +++ b/tests/system/providers/amazon/aws/example_bedrock_retrieve_and_generate.py @@ -127,7 +127,7 @@ def create_opensearch_policies(bedrock_role_arn: str, collection_name: str, poli def _create_security_policy(name, policy_type, policy): try: - aoss_client.create_security_policy(name=name, policy=json.dumps(policy), type=policy_type) + aoss_client.conn.create_security_policy(name=name, policy=json.dumps(policy), type=policy_type) except ClientError as e: if e.response["Error"]["Code"] == "ConflictException": log.info("OpenSearch security policy %s already exists.", name) @@ -135,7 +135,7 @@ def _create_security_policy(name, policy_type, policy): def _create_access_policy(name, policy_type, policy): try: - aoss_client.create_access_policy(name=name, policy=json.dumps(policy), type=policy_type) + aoss_client.conn.create_access_policy(name=name, policy=json.dumps(policy), type=policy_type) except ClientError as e: if e.response["Error"]["Code"] == "ConflictException": log.info("OpenSearch data access policy %s already exists.", name) @@ -204,9 +204,9 @@ def create_collection(collection_name: str): :param collection_name: The name of the Collection to create. """ log.info("\nCreating collection: %s.", collection_name) - return aoss_client.create_collection(name=collection_name, type="VECTORSEARCH")["createCollectionDetail"][ - "id" - ] + return aoss_client.conn.create_collection(name=collection_name, type="VECTORSEARCH")[ + "createCollectionDetail" + ]["id"] @task @@ -317,7 +317,7 @@ def get_collection_arn(collection_id: str): """ return next( colxn["arn"] - for colxn in aoss_client.list_collections()["collectionSummaries"] + for colxn in aoss_client.conn.list_collections()["collectionSummaries"] if colxn["id"] == collection_id ) @@ -336,7 +336,9 @@ def delete_data_source(knowledge_base_id: str, data_source_id: str): :param data_source_id: The unique identifier of the data source to delete. """ log.info("Deleting data source %s from Knowledge Base %s.", data_source_id, knowledge_base_id) - bedrock_agent_client.delete_data_source(dataSourceId=data_source_id, knowledgeBaseId=knowledge_base_id) + bedrock_agent_client.conn.delete_data_source( + dataSourceId=data_source_id, knowledgeBaseId=knowledge_base_id + ) # [END howto_operator_bedrock_delete_data_source] @@ -355,7 +357,7 @@ def delete_knowledge_base(knowledge_base_id: str): :param knowledge_base_id: The unique identifier of the knowledge base to delete. """ log.info("Deleting Knowledge Base %s.", knowledge_base_id) - bedrock_agent_client.delete_knowledge_base(knowledgeBaseId=knowledge_base_id) + bedrock_agent_client.conn.delete_knowledge_base(knowledgeBaseId=knowledge_base_id) # [END howto_operator_bedrock_delete_knowledge_base] @@ -393,7 +395,7 @@ def delete_collection(collection_id: str): :param collection_id: ID of the collection to be indexed. """ log.info("Deleting collection %s.", collection_id) - aoss_client.delete_collection(id=collection_id) + aoss_client.conn.delete_collection(id=collection_id) @task(trigger_rule=TriggerRule.ALL_DONE) @@ -404,7 +406,7 @@ def delete_opensearch_policies(collection_name: str): :param collection_name: All policies in the given collection name will be deleted. """ - access_policies = aoss_client.list_access_policies( + access_policies = aoss_client.conn.list_access_policies( type="data", resource=[f"collection/{collection_name}"] )["accessPolicySummaries"] log.info("Found access policies for %s: %s", collection_name, access_policies) @@ -412,10 +414,10 @@ def delete_opensearch_policies(collection_name: str): raise Exception("No access policies found?") for policy in access_policies: log.info("Deleting access policy for %s: %s", collection_name, policy["name"]) - aoss_client.delete_access_policy(name=policy["name"], type="data") + aoss_client.conn.delete_access_policy(name=policy["name"], type="data") for policy_type in ["encryption", "network"]: - policies = aoss_client.list_security_policies( + policies = aoss_client.conn.list_security_policies( type=policy_type, resource=[f"collection/{collection_name}"] )["securityPolicySummaries"] if not policies: @@ -423,7 +425,7 @@ def delete_opensearch_policies(collection_name: str): log.info("Found %s security policies for %s: %s", policy_type, collection_name, policies) for policy in policies: log.info("Deleting %s security policy for %s: %s", policy_type, collection_name, policy["name"]) - aoss_client.delete_security_policy(name=policy["name"], type=policy_type) + aoss_client.conn.delete_security_policy(name=policy["name"], type=policy_type) with DAG( @@ -436,8 +438,8 @@ def delete_opensearch_policies(collection_name: str): test_context = sys_test_context_task() env_id = test_context["ENV_ID"] - aoss_client = OpenSearchServerlessHook(aws_conn_id=None).conn - bedrock_agent_client = BedrockAgentHook(aws_conn_id=None).conn + aoss_client = OpenSearchServerlessHook(aws_conn_id=None) + bedrock_agent_client = BedrockAgentHook(aws_conn_id=None) region_name = boto3.session.Session().region_name diff --git a/tests/system/providers/amazon/aws/utils/__init__.py b/tests/system/providers/amazon/aws/utils/__init__.py index 8b4114fc90ad0..411f92ab7bf3a 100644 --- a/tests/system/providers/amazon/aws/utils/__init__.py +++ b/tests/system/providers/amazon/aws/utils/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import functools import inspect import json import logging @@ -92,6 +93,7 @@ def _validate_env_id(env_id: str) -> str: return env_id.lower() +@functools.cache def _fetch_from_ssm(key: str, test_name: str | None = None) -> str: """ Test values are stored in the SSM Value as a JSON-encoded dict of key/value pairs.