diff --git a/airflow/providers/amazon/aws/operators/ec2.py b/airflow/providers/amazon/aws/operators/ec2.py index b9de533378324..2dbb6986d7092 100644 --- a/airflow/providers/amazon/aws/operators/ec2.py +++ b/airflow/providers/amazon/aws/operators/ec2.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Sequence +from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook @@ -254,3 +255,128 @@ def execute(self, context: Context): "MaxAttempts": self.max_attempts, }, ) + + +class EC2RebootInstanceOperator(BaseOperator): + """ + Reboot Amazon EC2 instances. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EC2RebootInstanceOperator` + + :param instance_ids: ID of the instance(s) to be rebooted. + :param aws_conn_id: AWS connection to use + :param region_name: AWS region name associated with the client. + :param poll_interval: Number of seconds to wait before attempting to + check state of instance. Only used if wait_for_completion is True. Default is 20. + :param max_attempts: Maximum number of attempts when checking state of instance. + Only used if wait_for_completion is True. Default is 20. + :param wait_for_completion: If True, the operator will wait for the instance to be + in the `running` state before returning. + """ + + template_fields: Sequence[str] = ("instance_ids", "region_name") + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + def __init__( + self, + *, + instance_ids: str | list[str], + aws_conn_id: str = "aws_default", + region_name: str | None = None, + poll_interval: int = 20, + max_attempts: int = 20, + wait_for_completion: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.instance_ids = instance_ids + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.poll_interval = poll_interval + self.max_attempts = max_attempts + self.wait_for_completion = wait_for_completion + + def execute(self, context: Context): + if isinstance(self.instance_ids, str): + self.instance_ids = [self.instance_ids] + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type") + self.log.info("Rebooting EC2 instances %s", ", ".join(self.instance_ids)) + ec2_hook.conn.reboot_instances(InstanceIds=self.instance_ids) + + if self.wait_for_completion: + ec2_hook.get_waiter("instance_running").wait( + InstanceIds=self.instance_ids, + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": self.max_attempts, + }, + ) + + +class EC2HibernateInstanceOperator(BaseOperator): + """ + Hibernate Amazon EC2 instances. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EC2HibernateInstanceOperator` + + :param instance_ids: ID of the instance(s) to be hibernated. + :param aws_conn_id: AWS connection to use + :param region_name: AWS region name associated with the client. + :param poll_interval: Number of seconds to wait before attempting to + check state of instance. Only used if wait_for_completion is True. Default is 20. + :param max_attempts: Maximum number of attempts when checking state of instance. + Only used if wait_for_completion is True. Default is 20. + :param wait_for_completion: If True, the operator will wait for the instance to be + in the `stopped` state before returning. + """ + + template_fields: Sequence[str] = ("instance_ids", "region_name") + ui_color = "#eeaa11" + ui_fgcolor = "#ffffff" + + def __init__( + self, + *, + instance_ids: str | list[str], + aws_conn_id: str = "aws_default", + region_name: str | None = None, + poll_interval: int = 20, + max_attempts: int = 20, + wait_for_completion: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.instance_ids = instance_ids + self.aws_conn_id = aws_conn_id + self.region_name = region_name + self.poll_interval = poll_interval + self.max_attempts = max_attempts + self.wait_for_completion = wait_for_completion + + def execute(self, context: Context): + if isinstance(self.instance_ids, str): + self.instance_ids = [self.instance_ids] + ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type") + self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids)) + instances = ec2_hook.get_instances(instance_ids=self.instance_ids) + + for instance in instances: + hibernation_options = instance.get("HibernationOptions") + if not hibernation_options or not hibernation_options["Configured"]: + raise AirflowException(f"Instance {instance['InstanceId']} is not configured for hibernation") + + ec2_hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True) + + if self.wait_for_completion: + ec2_hook.get_waiter("instance_stopped").wait( + InstanceIds=self.instance_ids, + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": self.max_attempts, + }, + ) diff --git a/docs/apache-airflow-providers-amazon/operators/ec2.rst b/docs/apache-airflow-providers-amazon/operators/ec2.rst index 2018d8113fcf9..e5462b32a18f4 100644 --- a/docs/apache-airflow-providers-amazon/operators/ec2.rst +++ b/docs/apache-airflow-providers-amazon/operators/ec2.rst @@ -86,6 +86,34 @@ To terminate an Amazon EC2 instance you can use :start-after: [START howto_operator_ec2_terminate_instance] :end-before: [END howto_operator_ec2_terminate_instance] +.. _howto/operator:EC2RebootInstanceOperator: + +Reboot an Amazon EC2 instance +================================ + +To reboot an Amazon EC2 instance you can use +:class:`~airflow.providers.amazon.aws.operators.ec2.EC2RebootInstanceOperator`. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_ec2_reboot_instance] + :end-before: [END howto_operator_ec2_reboot_instance] + +.. _howto/operator:EC2HibernateInstanceOperator: + +Hibernate an Amazon EC2 instance +================================ + +To hibernate an Amazon EC2 instance you can use +:class:`~airflow.providers.amazon.aws.operators.ec2.EC2HibernateInstanceOperator`. + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_ec2_hibernate_instance] + :end-before: [END howto_operator_ec2_hibernate_instance] + Sensors ------- diff --git a/tests/providers/amazon/aws/operators/test_ec2.py b/tests/providers/amazon/aws/operators/test_ec2.py index adf3ffeb91cc0..b11d72b71445b 100644 --- a/tests/providers/amazon/aws/operators/test_ec2.py +++ b/tests/providers/amazon/aws/operators/test_ec2.py @@ -17,11 +17,15 @@ # under the License. from __future__ import annotations +import pytest from moto import mock_ec2 +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook from airflow.providers.amazon.aws.operators.ec2 import ( EC2CreateInstanceOperator, + EC2HibernateInstanceOperator, + EC2RebootInstanceOperator, EC2StartInstanceOperator, EC2StopInstanceOperator, EC2TerminateInstanceOperator, @@ -205,3 +209,166 @@ def test_stop_instance(self): stop_test.execute(None) # assert instance state is running assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped" + + +class TestEC2HibernateInstanceOperator(BaseEc2TestClass): + def test_init(self): + ec2_operator = EC2HibernateInstanceOperator( + task_id="task_test", + instance_ids="i-123abc", + ) + assert ec2_operator.task_id == "task_test" + assert ec2_operator.instance_ids == "i-123abc" + + @mock_ec2 + def test_hibernate_instance(self): + # create instance + ec2_hook = EC2Hook() + create_instance = EC2CreateInstanceOperator( + image_id=self._get_image_id(ec2_hook), + task_id="test_create_instance", + config={"HibernationOptions": {"Configured": True}}, + ) + instance_id = create_instance.execute(None) + + # hibernate instance + hibernate_test = EC2HibernateInstanceOperator( + task_id="hibernate_test", + instance_ids=instance_id[0], + ) + hibernate_test.execute(None) + # assert instance state is stopped + assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped" + + @mock_ec2 + def test_hibernate_multiple_instances(self): + ec2_hook = EC2Hook() + create_instances = EC2CreateInstanceOperator( + task_id="test_create_multiple_instances", + image_id=self._get_image_id(hook=ec2_hook), + config={"HibernationOptions": {"Configured": True}}, + min_count=5, + max_count=5, + ) + instance_ids = create_instances.execute(None) + assert len(instance_ids) == 5 + + for id in instance_ids: + assert ec2_hook.get_instance_state(instance_id=id) == "running" + + hibernate_instance = EC2HibernateInstanceOperator( + task_id="test_hibernate_instance", instance_ids=instance_ids + ) + hibernate_instance.execute(None) + for id in instance_ids: + assert ec2_hook.get_instance_state(instance_id=id) == "stopped" + + @mock_ec2 + def test_cannot_hibernate_instance(self): + # create instance + ec2_hook = EC2Hook() + create_instance = EC2CreateInstanceOperator( + image_id=self._get_image_id(ec2_hook), + task_id="test_create_instance", + ) + instance_id = create_instance.execute(None) + + # hibernate instance + hibernate_test = EC2HibernateInstanceOperator( + task_id="hibernate_test", + instance_ids=instance_id[0], + ) + + # assert hibernating an instance not configured for hibernation raises an error + with pytest.raises( + AirflowException, + match="Instance .* is not configured for hibernation", + ): + hibernate_test.execute(None) + + # assert instance state is running + assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running" + + @mock_ec2 + def test_cannot_hibernate_some_instances(self): + # create instance + ec2_hook = EC2Hook() + create_instance_hibernate = EC2CreateInstanceOperator( + image_id=self._get_image_id(ec2_hook), + task_id="test_create_instance", + config={"HibernationOptions": {"Configured": True}}, + ) + instance_id_hibernate = create_instance_hibernate.execute(None) + create_instance_cannot_hibernate = EC2CreateInstanceOperator( + image_id=self._get_image_id(ec2_hook), + task_id="test_create_instance", + ) + instance_id_cannot_hibernate = create_instance_cannot_hibernate.execute(None) + instance_ids = [instance_id_hibernate[0], instance_id_cannot_hibernate[0]] + + # hibernate instance + hibernate_test = EC2HibernateInstanceOperator( + task_id="hibernate_test", + instance_ids=instance_ids, + ) + # assert hibernating an instance not configured for hibernation raises an error + with pytest.raises( + AirflowException, + match="Instance .* is not configured for hibernation", + ): + hibernate_test.execute(None) + + # assert instance state is running + for id in instance_ids: + assert ec2_hook.get_instance_state(instance_id=id) == "running" + + +class TestEC2RebootInstanceOperator(BaseEc2TestClass): + def test_init(self): + ec2_operator = EC2RebootInstanceOperator( + task_id="task_test", + instance_ids="i-123abc", + ) + assert ec2_operator.task_id == "task_test" + assert ec2_operator.instance_ids == "i-123abc" + + @mock_ec2 + def test_reboot_instance(self): + # create instance + ec2_hook = EC2Hook() + create_instance = EC2CreateInstanceOperator( + image_id=self._get_image_id(ec2_hook), + task_id="test_create_instance", + ) + instance_id = create_instance.execute(None) + + # reboot instance + reboot_test = EC2RebootInstanceOperator( + task_id="reboot_test", + instance_ids=instance_id[0], + ) + reboot_test.execute(None) + # assert instance state is running + assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running" + + @mock_ec2 + def test_reboot_multiple_instances(self): + ec2_hook = EC2Hook() + create_instances = EC2CreateInstanceOperator( + task_id="test_create_multiple_instances", + image_id=self._get_image_id(hook=ec2_hook), + min_count=5, + max_count=5, + ) + instance_ids = create_instances.execute(None) + assert len(instance_ids) == 5 + + for id in instance_ids: + assert ec2_hook.get_instance_state(instance_id=id) == "running" + + terminate_instance = EC2RebootInstanceOperator( + task_id="test_reboot_instance", instance_ids=instance_ids + ) + terminate_instance.execute(None) + for id in instance_ids: + assert ec2_hook.get_instance_state(instance_id=id) == "running" diff --git a/tests/system/providers/amazon/aws/example_ec2.py b/tests/system/providers/amazon/aws/example_ec2.py index 506d73908b224..aeffe4ad34ff7 100644 --- a/tests/system/providers/amazon/aws/example_ec2.py +++ b/tests/system/providers/amazon/aws/example_ec2.py @@ -26,6 +26,8 @@ from airflow.models.dag import DAG from airflow.providers.amazon.aws.operators.ec2 import ( EC2CreateInstanceOperator, + EC2HibernateInstanceOperator, + EC2RebootInstanceOperator, EC2StartInstanceOperator, EC2StopInstanceOperator, EC2TerminateInstanceOperator, @@ -103,6 +105,7 @@ def parse_response(instance_ids: list): # Use IMDSv2 for greater security, see the following doc for more details: # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html "MetadataOptions": {"HttpEndpoint": "enabled", "HttpTokens": "required"}, + "HibernationOptions": {"Configured": True}, } # EC2CreateInstanceOperator creates and starts the EC2 instances. To test the EC2StartInstanceOperator, @@ -142,6 +145,20 @@ def parse_response(instance_ids: list): ) # [END howto_sensor_ec2_instance_state] + # [START howto_operator_ec2_reboot_instance] + reboot_instance = EC2RebootInstanceOperator( + task_id="reboot_instace", + instance_ids=instance_id, + ) + # [END howto_operator_ec2_reboot_instance] + + # [START howto_operator_ec2_hibernate_instance] + hibernate_instance = EC2HibernateInstanceOperator( + task_id="hibernate_instace", + instance_ids=instance_id, + ) + # [END howto_operator_ec2_hibernate_instance] + # [START howto_operator_ec2_terminate_instance] terminate_instance = EC2TerminateInstanceOperator( task_id="terminate_instance", @@ -161,6 +178,8 @@ def parse_response(instance_ids: list): stop_instance, start_instance, await_instance, + reboot_instance, + hibernate_instance, terminate_instance, # TEST TEARDOWN delete_key_pair(key_name),