Skip to content

Commit

Permalink
Bump azure-mgmt-containerinstance (apache#34738)
Browse files Browse the repository at this point in the history
* Bump azure-mgmt-containerinstance

* Apply review suggestions
  • Loading branch information
pankajastro authored Oct 10, 2023
1 parent 9a29738 commit 9ee14a0
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 43 deletions.
9 changes: 7 additions & 2 deletions airflow/providers/microsoft/azure/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ Breaking changes
~~~~~~~~~~~~~~~~

.. warning::
AzureDataFactoryHook methods and AzureDataFactoryRunPipelineOperator arguments resource_group_name and factory_name is
now required instead of kwargs
In this version of the provider, we have removed network_profile param from AzureContainerInstancesOperator and
AzureDataFactoryHook methods and AzureDataFactoryRunPipelineOperator arguments resource_group_name and factory_name
is now required instead of kwargs

* resource_group_name and factory_name is now required argument in AzureDataFactoryHook method get_factory, update_factory,
create_factory, delete_factory, get_linked_service, delete_linked_service, get_dataset, delete_dataset, get_dataflow,
Expand All @@ -44,6 +45,10 @@ Breaking changes
stop_trigger, get_adf_pipeline_run_status, cancel_pipeline_run
* resource_group_name and factory_name is now required in AzureDataFactoryRunPipelineOperator
* Remove class ``PipelineRunInfo`` from ``airflow.providers.microsoft.azure.hooks.data_factory``
* Remove ``network_profile`` param from ``AzureContainerInstancesOperator``
* Remove deprecated ``extra__azure__tenantId`` from azure_container_instance connection extras
* Remove deprecated ``extra__azure__subscriptionId`` from azure_container_instance connection extras


7.0.0
.....
Expand Down
48 changes: 20 additions & 28 deletions airflow/providers/microsoft/azure/hooks/container_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,22 @@

import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
from azure.common.credentials import ServicePrincipalCredentials
from azure.identity import DefaultAzureCredential
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook

if TYPE_CHECKING:
from azure.mgmt.containerinstance.models import ContainerGroup
from azure.mgmt.containerinstance.models import (
ContainerGroup,
ContainerPropertiesInstanceView,
ContainerState,
Event,
)


class AzureContainerInstanceHook(AzureBaseHook):
Expand Down Expand Up @@ -67,23 +71,6 @@ def get_conn(self) -> Any:
"""
conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("tenantId")
if not tenant and conn.extra_dejson.get("extra__azure__tenantId"):
warnings.warn(
"`extra__azure__tenantId` is deprecated in azure connection extra, "
"please use `tenantId` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
tenant = conn.extra_dejson.get("extra__azure__tenantId")
subscription_id = conn.extra_dejson.get("subscriptionId")
if not subscription_id and conn.extra_dejson.get("extra__azure__subscriptionId"):
warnings.warn(
"`extra__azure__subscriptionId` is deprecated in azure connection extra, "
"please use `subscriptionId` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
subscription_id = conn.extra_dejson.get("extra__azure__subscriptionId")

key_path = conn.extra_dejson.get("key_path")
if key_path:
Expand All @@ -97,16 +84,17 @@ def get_conn(self) -> Any:
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)

credential: ServicePrincipalCredentials | DefaultAzureCredential
credential: ClientSecretCredential | DefaultAzureCredential
if all([conn.login, conn.password, tenant]):
self.log.info("Getting connection using specific credentials and subscription_id.")
credential = ServicePrincipalCredentials(
client_id=conn.login, secret=conn.password, tenant=tenant
credential = ClientSecretCredential(
client_id=conn.login, client_secret=conn.password, tenant_id=cast(str, tenant)
)
else:
self.log.info("Using DefaultAzureCredential as credential")
credential = DefaultAzureCredential()

subscription_id = cast(str, conn.extra_dejson.get("subscriptionId"))
return ContainerInstanceManagementClient(
credential=credential,
subscription_id=subscription_id,
Expand Down Expand Up @@ -137,8 +125,10 @@ def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple:
stacklevel=2,
)
cg_state = self.get_state(resource_group, name)
c_state = cg_state.containers[0].instance_view.current_state
return (c_state.state, c_state.exit_code, c_state.detail_status)
container = cg_state.containers[0]
instance_view: ContainerPropertiesInstanceView = container.instance_view # type: ignore[assignment]
c_state: ContainerState = instance_view.current_state # type: ignore[assignment]
return c_state.state, c_state.exit_code, c_state.detail_status

def get_messages(self, resource_group: str, name: str) -> list:
"""
Expand All @@ -154,8 +144,10 @@ def get_messages(self, resource_group: str, name: str) -> list:
stacklevel=2,
)
cg_state = self.get_state(resource_group, name)
instance_view = cg_state.containers[0].instance_view
return [event.message for event in instance_view.events]
container = cg_state.containers[0]
instance_view: ContainerPropertiesInstanceView = container.instance_view # type: ignore[assignment]
events: list[Event] = instance_view.events # type: ignore[assignment]
return [event.message for event in events]

def get_state(self, resource_group: str, name: str) -> ContainerGroup:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
ContainerGroupNetworkProfile,
ContainerPort,
EnvironmentVariable,
IpAddress,
ResourceRequests,
ResourceRequirements,
Volume as _AzureVolume,
VolumeMount,
)
from msrestazure.azure_exceptions import CloudError
Expand All @@ -44,13 +44,11 @@
if TYPE_CHECKING:
from airflow.utils.context import Context


Volume = namedtuple(
"Volume",
["conn_id", "account_name", "share_name", "mount_path", "read_only"],
)


DEFAULT_ENVIRONMENT_VARIABLES: dict[str, str] = {}
DEFAULT_SECURED_VARIABLES: Sequence[str] = []
DEFAULT_VOLUMES: Sequence[Volume] = []
Expand Down Expand Up @@ -90,7 +88,6 @@ class AzureContainerInstancesOperator(BaseOperator):
:param restart_policy: Restart policy for all containers within the container group.
Possible values include: 'Always', 'OnFailure', 'Never'
:param ip_address: The IP address type of the container group.
:param network_profile: The network profile information for a container group.
**Example**::
Expand Down Expand Up @@ -145,7 +142,6 @@ def __init__(
restart_policy: str = "Never",
ip_address: IpAddress | None = None,
ports: list[ContainerPort] | None = None,
network_profile: ContainerGroupNetworkProfile | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -183,7 +179,6 @@ def __init__(
)
self.ip_address = ip_address
self.ports = ports
self.network_profile = network_profile

def execute(self, context: Context) -> int:
# Check name again in case it was templated.
Expand Down Expand Up @@ -212,7 +207,7 @@ def execute(self, context: Context) -> int:
e = EnvironmentVariable(name=key, value=value)
environment_variables.append(e)

volumes: list[Volume | Volume] = []
volumes: list[_AzureVolume] = []
volume_mounts: list[VolumeMount | VolumeMount] = []
for conn_id, account_name, share_name, mount_path, read_only in self.volumes:
hook = AzureContainerVolumeHook(conn_id)
Expand Down Expand Up @@ -256,7 +251,6 @@ def execute(self, context: Context) -> int:
os_type=self.os_type,
tags=self.tags,
ip_address=self.ip_address,
network_profile=self.network_profile,
)

self._ci_hook.create_or_update(self.resource_group, self.name, container_group)
Expand Down
4 changes: 1 addition & 3 deletions airflow/providers/microsoft/azure/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ dependencies:
- azure-kusto-data>=4.1.0
- azure-mgmt-datafactory>=2.0.0
- azure-mgmt-containerregistry>=8.0.0
# TODO: upgrade to newer versions of all the below libraries.
# See issue https://github.com/apache/airflow/issues/30199
- azure-mgmt-containerinstance>=7.0.0,<9.0.0
- azure-mgmt-containerinstance>=9.0.0

integrations:
- integration-name: Microsoft Azure Batch
Expand Down
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@
"azure-identity>=1.3.1",
"azure-keyvault-secrets>=4.1.0",
"azure-kusto-data>=4.1.0",
"azure-mgmt-containerinstance>=7.0.0,<9.0.0",
"azure-mgmt-containerinstance>=9.0.0",
"azure-mgmt-containerregistry>=8.0.0",
"azure-mgmt-cosmosdb",
"azure-mgmt-datafactory>=2.0.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def setup_test_cases(self, create_mock_connection):
conn_type="azure_container_instances",
login="login",
password="key",
extra={"tenantId": "tenant_id", "subscriptionId": "subscription_id"},
extra={
"tenantId": "63e85d06-62e4-11ee-8c99-0242ac120002",
"subscriptionId": "63e85d06-62e4-11ee-8c99-0242ac120003",
},
)
)
self.resources = ResourceRequirements(requests=ResourceRequests(memory_in_gb="4", cpu="1"))
Expand Down

0 comments on commit 9ee14a0

Please sign in to comment.