Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate SageMakerTrainingOperatorAsync #1463

Merged
merged 5 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 13 additions & 143 deletions astronomer/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,16 @@
from __future__ import annotations

import json
import time
import warnings
from typing import Any

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import (
LogState,
secondary_training_status_message,
)
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerProcessingOperator,
SageMakerTrainingOperator,
SageMakerTransformOperator,
)
from airflow.utils.json import AirflowJsonEncoder

from astronomer.providers.amazon.aws.triggers.sagemaker import (
SagemakerTrainingWithLogTrigger,
SagemakerTrigger,
)
from astronomer.providers.utils.typing_compat import Context


def serialize(result: dict[str, Any]) -> str:
"""Serialize any objects coming from Sagemaker API response to json string"""
Expand Down Expand Up @@ -71,137 +59,19 @@ def __init__(self, **kwargs: Any) -> None:

class SageMakerTrainingOperatorAsync(SageMakerTrainingOperator):
"""
SageMakerTrainingOperatorAsync starts a model training job and polls for the status asynchronously.
After training completes, Amazon SageMaker saves the resulting model artifacts to an Amazon S3 location
that you specify.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:``howto/operator:SageMakerTrainingOperator``

:param config: The configuration necessary to start a training job (templated).
For details of the configuration parameter see ``SageMaker.Client.create_training_job``
:param aws_conn_id: The AWS connection ID to use.
:param print_log: if the operator should print the cloudwatch log during training
:param check_interval: if wait is set to be true, this is the time interval
in seconds which the operator will check the status of the training job
:param max_ingestion_time: The operation fails if the training job
doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
the operation does not timeout.
:param check_if_job_exists: If set to true, then the operator will check whether a training job
already exists for the name in the config.
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
This is only relevant if check_if_job_exists is True.
This class is deprecated.
Please use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator`
and set the ``deferrable`` param to ``True`` instead.
"""

def execute(self, context: Context) -> dict[str, Any] | None: # type: ignore[override]
"""
Creates SageMaker training job via sync hook `create_training_job` and pass the
control to trigger and polls for the status of the training job in async
"""
self.preprocess_config()
if self.check_if_job_exists: # pragma: no cover
try:
# for apache-airflow-providers-amazon<=7.2.1
self._check_if_job_exists() # type: ignore[call-arg]
except TypeError:
# for apache-airflow-providers-amazon>=7.3.0
self.config["TrainingJobName"] = self._get_unique_job_name(
self.config["TrainingJobName"],
self.action_if_job_exists == "fail",
self.hook.describe_training_job,
)
self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"])
response = self.hook.create_training_job(
self.config,
wait_for_completion=False,
print_log=False,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
def __init__(self, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated. "
"Please use `airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker Training Job creation failed: {response}")

end_time: float | None = None
if self.max_ingestion_time is not None:
end_time = time.time() + self.max_ingestion_time

description = self.hook.describe_training_job(self.config["TrainingJobName"])
status = description["TrainingJobStatus"]
if self.print_log:
instance_count = description["ResourceConfig"]["InstanceCount"]
last_describe_job_call = time.monotonic()
job_already_completed = status not in self.hook.non_terminal_states
_, last_description, last_describe_job_call = self.hook.describe_training_job_with_log(
self.config["TrainingJobName"],
{},
[],
instance_count,
LogState.TAILING if job_already_completed else LogState.COMPLETE,
description,
last_describe_job_call,
)

self.log.info(secondary_training_status_message(description, None))

if status in self.hook.failed_states:
reason = last_description.get("FailureReason", "(No reason provided)")
raise AirflowException(f"SageMaker job failed because {reason}")
elif status == "Completed":
billable_time = (
last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
) * instance_count
self.log.info(
f"Billable seconds: {int(billable_time.total_seconds()) + 1}\n"
f"{self.task_id} completed successfully."
)
return {"Training": serialize(description)}

self.defer(
timeout=self.execution_timeout,
trigger=SagemakerTrainingWithLogTrigger(
poke_interval=self.check_interval,
end_time=end_time,
aws_conn_id=self.aws_conn_id,
job_name=self.config["TrainingJobName"],
instance_count=int(instance_count),
status=status,
),
method_name="execute_complete",
)
else:
if status in self.hook.failed_states:
raise AirflowException(f"SageMaker job failed because {description['FailureReason']}")
elif status == "Completed":
self.log.info(f"{self.task_id} completed successfully.")
return {"Training": serialize(description)}

self.defer(
timeout=self.execution_timeout,
trigger=SagemakerTrigger(
poke_interval=self.check_interval,
end_time=end_time,
aws_conn_id=self.aws_conn_id,
job_name=self.config["TrainingJobName"],
job_type="Training",
response_key="TrainingJobStatus",
),
method_name="execute_complete",
)

# for bypassing mypy missing return error
return None # pragma: no cover

def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]: # type: ignore[override]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
return {"Training": serialize(event["message"])}
if event and event["status"] == "error":
raise AirflowException(event["message"])
raise AirflowException("No event received in trigger callback")
super().__init__(deferrable=True, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we pass args too?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not needed, the upstream operator has no positional args. It only has named/keyword args.

11 changes: 11 additions & 0 deletions astronomer/providers/amazon/aws/triggers/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ class SagemakerTrainingWithLogTrigger(BaseTrigger):
"""
SagemakerTrainingWithLogTrigger is fired as deferred class with params to run the task in triggerer.

This class is deprecated and will be removed in 2.0.0.
Lee-W marked this conversation as resolved.
Show resolved Hide resolved
Use :class:`~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger` instead

:param job_name: name of the job to check status
:param instance_count: count of the instance created for running the training job
:param status: The status of the training job created.
Expand All @@ -209,6 +212,14 @@ def __init__(
end_time: float | None = None,
aws_conn_id: str = "aws_default",
):
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0. "
"Please use `airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger`"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.job_name = job_name
self.instance_count = instance_count
Expand Down
6 changes: 4 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ zip_safe = false

[options.extras_require]
amazon =
apache-airflow-providers-amazon>=8.17.0
# Update the version constraint when the RC below is released
apache-airflow-providers-amazon>=8.18.0rc2
aiobotocore>=2.1.1
apache.hive =
apache-airflow-providers-apache-hive>=6.1.5
Expand Down Expand Up @@ -118,7 +119,8 @@ mypy =
# All extras from above except 'mypy', 'docs' and 'tests'
all =
aiobotocore>=2.1.1
apache-airflow-providers-amazon>=8.17.0
# Update the version constraint when the RC below is released
apache-airflow-providers-amazon>=8.18.0rc2
apache-airflow-providers-apache-hive>=6.1.5
apache-airflow-providers-apache-livy>=3.7.1
apache-airflow-providers-cncf-kubernetes>=4
Expand Down
Loading
Loading