Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aws batch agent #1809

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ great-expectations==0.17.6
# via -r doc-requirements.in
greenlet==2.0.2
# via sqlalchemy
grpcio==1.56.2
grpcio==1.51.1
# via
# -r doc-requirements.in
# flytekit
Expand All @@ -342,7 +342,7 @@ grpcio==1.56.2
# ray
# tensorboard
# tensorflow
grpcio-status==1.56.2
grpcio-status==1.51.1
# via
# flytekit
# google-api-core
Expand Down
17 changes: 15 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import pathlib
import tempfile
import typing
from dataclasses import dataclass
from typing import cast
Expand Down Expand Up @@ -582,6 +583,12 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
type=JsonParamType(),
help="Environment variables to set in the container",
),
click.Option(
param_decls=["--output-prefix", "output_prefix"],
required=False,
type=str,
help="Where to store the task output",
),
]


Expand Down Expand Up @@ -662,12 +669,18 @@ def _run(*args, **kwargs):

run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY]
project, domain = run_level_params.get("project"), run_level_params.get("domain")
output_prefix = run_level_params.get("output_prefix")
inputs = {}
for input_name, _ in entity.python_interface.inputs.items():
inputs[input_name] = kwargs.get(input_name)

if not ctx.obj[REMOTE_FLAG_KEY]:
output = entity(**inputs)
output_prefix = output_prefix if output_prefix else tempfile.mkdtemp(prefix="raw")
file_access = FileAccessProvider(
local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), raw_output_prefix=output_prefix
)
with FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access)):
output = entity(**inputs)
click.echo(output)
if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME):
os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME))
Expand Down Expand Up @@ -695,7 +708,7 @@ def _run(*args, **kwargs):
if service_account:
# options are only passed for the execution. This is to prevent errors when registering a duplicate workflow
# It is assumed that the users expectations is to override the service account only for the execution
options = Options.default_from(k8s_service_account=service_account)
options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=output_prefix)

execution = remote.execute(
remote_entity,
Expand Down
2 changes: 0 additions & 2 deletions flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
PythonInstanceTask

"""


from abc import ABC
from collections import OrderedDict
from enum import Enum
Expand Down
50 changes: 46 additions & 4 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import grpc
from flyteidl.admin.agent_pb2 import (
PENDING,
PERMANENT_FAILURE,
RETRYABLE_FAILURE,
RUNNING,
Expand All @@ -17,10 +18,12 @@
from flyteidl.core.tasks_pb2 import TaskTemplate
from rich.progress import Progress

from flytekit import FlyteContext, logger
from flytekit import FlyteContext, PythonFunctionTask, logger
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core import utils
from flytekit.core.base_task import PythonTask
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import scopes as exception_scopes
from flytekit.models.literals import LiteralMap


Expand Down Expand Up @@ -111,6 +114,8 @@ def convert_to_flyte_state(state: str) -> State:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
elif state in ["submitted", "pending", "starting", "runnable"]:
return PENDING
raise ValueError(f"Unrecognized state: {state}")


Expand All @@ -128,6 +133,20 @@ class AsyncAgentExecutorMixin:
"""

def execute(self, **kwargs) -> typing.Any:
ctx = FlyteContext.current_context()
output_prefix = ctx.file_access.get_random_remote_directory()
print(output_prefix)

# If the task is a PythonFunctionTask, we can run it locally or remotely (e.g. AWS batch, ECS).
# If the output location is remote, we will use the agent to run the task, and
# the agent will write intermediate outputs to the blob store.
if getattr(self, "_task_function", None) and not ctx.file_access.is_remote(output_prefix):
entity = typing.cast(PythonFunctionTask, self)
if entity.execution_mode == entity.ExecutionBehavior.DEFAULT:
return exception_scopes.user_entry_point(entity.task_function)(**kwargs)
elif entity.execution_mode == entity.ExecutionBehavior.DYNAMIC:
return entity.dynamic_execute(entity.task_function, **kwargs)

from unittest.mock import MagicMock

from flytekit.tools.translator import get_serializable
Expand All @@ -140,13 +159,19 @@ def execute(self, **kwargs) -> typing.Any:

if agent is None:
raise Exception("Cannot run the task locally, please mock.")

literals = {}
ctx = FlyteContext.current_context()
for k, v in kwargs.items():
literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type)
inputs = LiteralMap(literals) if literals else None
output_prefix = ctx.file_access.get_random_local_directory()
cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity)

if inputs:
print("Writing inputs to file")
path = ctx.file_access.get_random_local_path()
utils.write_proto_to_file(inputs.to_flyte_idl(), path)
# ctx.file_access.put_data(path, f"{file_prefix}/inputs.pb")
cp_entity._template = render_task_template(cp_entity.template, output_prefix)

res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs)
state = RUNNING
metadata = res.resource_meta
Expand All @@ -163,4 +188,21 @@ def execute(self, **kwargs) -> typing.Any:
if state != SUCCEEDED:
raise Exception(f"Failed to run the task {entity.name}")

if res.resource.outputs is None:
local_outputs_file = ctx.file_access.get_random_local_path()
# ctx.file_access.get_data(f"{output_prefix}/outputs.pb", local_outputs_file)
# output_proto = utils.load_proto_from_file(literals_pb2.LiteralMap, local_outputs_file)
# return LiteralMap.from_flyte_idl(output_proto)

return LiteralMap.from_flyte_idl(res.resource.outputs)


def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate:
args = tt.container.args
for i in range(len(args)):
tt.container.args[i] = args[i].replace("{{.input}}", f"{file_prefix}/inputs.pb")
tt.container.args[i] = args[i].replace("{{.outputPrefix}}", f"{file_prefix}/output")
tt.container.args[i] = args[i].replace("{{.rawOutputDataPrefix}}", f"{file_prefix}/raw_output")
tt.container.args[i] = args[i].replace("{{.checkpointOutputPrefix}}", f"{file_prefix}/checkpoint_output")
tt.container.args[i] = args[i].replace("{{.prevCheckpointPrefix}}", f"{file_prefix}/prev_checkpoint")
return tt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
AWSBatchConfig
"""

from .agent import AWSBatchAgent
from .task import AWSBatchConfig
66 changes: 66 additions & 0 deletions plugins/flytekit-aws-batch/flytekitplugins/awsbatch/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from dataclasses import dataclass
from typing import Optional

import cloudpickle
import grpc
from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource
import boto3
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class Metadata:
job_id: str


class AWSBatchAgent(AgentBase):
def __init__(self):
super().__init__(task_type="aws-batch")

def _get_client(self):
"""
Get a boto3 client for AWS Batch
:rtype: boto3.client
"""
return boto3.client('batch', region_name='us-west-2')

def create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
) -> CreateTaskResponse:
client = self._get_client()
resources = task_template.container.resources
print(resources.requests)
container_properties = {
'image': 'pingsutw/flyte-app:65316e88460657fa28f46b75067de5b3',
'vcpus': 1,
'memory': 512,
# ... other container properties ...
}
# response = client.register_job_definition(jobDefinitionName="flyte-batch",
# type="container",
# containerProperties=container_properties)
# response = client.submit_job(jobName="test", jobQueue="flyte-test", jobDefinition=response['jobDefinitionName'])
return CreateTaskResponse(resource_meta=cloudpickle.dumps("b9b5bbfb-a85c-416f-b157-491189ce8f7c"))

def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
client = self._get_client()
job_id = cloudpickle.loads(resource_meta)
response = client.describe_jobs(jobs=[job_id])
status = response['jobs'][0]['status']
cur_state = convert_to_flyte_state(status)
return GetTaskResponse(resource=Resource(state=cur_state))

def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
client = self._get_client()
job_id = cloudpickle.loads(resource_meta)
client.terminate_job(jobId=job_id, reason='Cancelling job.')
return DeleteTaskResponse()


AgentRegistry.register(AWSBatchAgent())
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit import PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin


@dataclass_json
Expand All @@ -31,7 +32,7 @@ def to_dict(self):
return json_format.MessageToDict(s)


class AWSBatchFunctionTask(PythonFunctionTask):
class AWSBatchFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask):
"""
Actual Plugin that transforms the local python code for execution within AWS batch job
"""
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-aws-batch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
2 changes: 2 additions & 0 deletions plugins/flytekit-envd/flytekitplugins/envd/image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def build():
envd_config += f' install.cuda(version="{image_spec.cuda}", cudnn="{cudnn}")\n'

if image_spec.source_root:
print(ctx.execution_state)
print(ctx.compilation_state)
shutil.copytree(image_spec.source_root, pathlib.Path(cfg_path).parent, dirs_exist_ok=True)
# Indentation is required by envd
envd_config += ' io.copy(host_path="./", envd_path="/root")'
Expand Down