Skip to content

Commit

Permalink
Add task execution metadata to agent create (flyteorg#2282)
Browse files Browse the repository at this point in the history
Signed-off-by: noahjax <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
noahjax and pingsutw authored Mar 26, 2024
1 parent 6a63c1f commit 133e8d5
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 11 deletions.
4 changes: 3 additions & 1 deletion flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.base_agent import AgentRegistry, SyncAgentBase, mirror_async_methods
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate

metric_prefix = "flyte_agent_"
create_operation = "create"
Expand Down Expand Up @@ -115,13 +115,15 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
template = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(template.type, template.task_type_version)
task_execution_metadata = TaskExecutionMetadata.from_flyte_idl(request.task_execution_metadata)

logger.info(f"{agent.name} start creating the job")
resource_mata = await mirror_async_methods(
agent.create,
task_template=template,
inputs=inputs,
output_prefix=request.output_prefix,
task_execution_metadata=task_execution_metadata,
)
return CreateTaskResponse(resource_meta=resource_mata.encode())

Expand Down
9 changes: 7 additions & 2 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flytekit.exceptions.user import FlyteUserException
from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate


class TaskCategory:
Expand Down Expand Up @@ -146,7 +146,12 @@ def metadata_type(self) -> ResourceMeta:

@abstractmethod
def create(
self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: Optional[str], **kwargs
self,
task_template: TaskTemplate,
inputs: Optional[LiteralMap],
output_prefix: Optional[str],
task_execution_metadata: Optional[TaskExecutionMetadata],
**kwargs,
) -> ResourceMeta:
"""
Return a resource meta that can be used to get the status of the task.
Expand Down
4 changes: 2 additions & 2 deletions flytekit/extend/backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import functools
import inspect
from typing import Callable, Coroutine

Expand All @@ -11,8 +12,7 @@
def mirror_async_methods(func: Callable, **kwargs) -> Coroutine:
if inspect.iscoroutinefunction(func):
return func(**kwargs)
args = [v for _, v in kwargs.items()]
return asyncio.get_running_loop().run_in_executor(None, func, *args)
return asyncio.get_running_loop().run_in_executor(None, functools.partial(func, **kwargs))


def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
Expand Down
89 changes: 89 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json as _json
import typing

from flyteidl.admin import agent_pb2 as _admin_agent
from flyteidl.admin import task_pb2 as _admin_task
from flyteidl.core import compiler_pb2 as _compiler
from flyteidl.core import literals_pb2 as _literals_pb2
Expand Down Expand Up @@ -518,6 +519,94 @@ def from_flyte_idl(cls, pb2_object):
)


class TaskExecutionMetadata(_common.FlyteIdlEntity):
def __init__(
self,
task_execution_id,
namespace,
labels,
annotations,
k8s_service_account,
environment_variables,
):
"""
Runtime task execution metadata.
:param flytekit.models.core.identifier.TaskExecutionIdentifier task_execution_id: This is generated by the system and uniquely identifies
this execution of the task.
:param Text namespace: This is the namespace the task is executing in.
:param dict[str, str] labels: Labels to use for the execution of this task.
:param dict[str, str] annotations: Annotations to use for the execution of this task.
:param Text k8s_service_account: Service account to use for execution of this task.
:param dict[str, str] environment_variables: Environment variables for this task.
"""
self._task_execution_id = task_execution_id
self._namespace = namespace
self._labels = labels
self._annotations = annotations
self._k8s_service_account = k8s_service_account
self._environment_variables = environment_variables

@property
def task_execution_id(self):
return self._task_execution_id

@property
def namespace(self):
return self._namespace

@property
def labels(self):
return self._labels

@property
def annotations(self):
return self._annotations

@property
def k8s_service_account(self):
return self._k8s_service_account

@property
def environment_variables(self):
return self._environment_variables

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.agent_pb2.TaskExecutionMetadata
"""
task_execution_metadata = _admin_agent.TaskExecutionMetadata(
task_execution_id=self.task_execution_id.to_flyte_idl(),
namespace=self.namespace,
labels={k: v for k, v in self.labels.items()} if self.labels is not None else None,
annotations={k: v for k, v in self.annotations.items()} if self.annotations is not None else None,
k8s_service_account=self.k8s_service_account,
environment_variables={k: v for k, v in self.environment_variables.items()}
if self.labels is not None
else None,
)
return task_execution_metadata

@classmethod
def from_flyte_idl(cls, pb2_object):
"""
:param flyteidl.admin.agent_pb2.TaskExecutionMetadata pb2_object:
:rtype: TaskExecutionMetadata
"""
return cls(
task_execution_id=_identifier.TaskExecutionIdentifier.from_flyte_idl(pb2_object.task_execution_id),
namespace=pb2_object.namespace,
labels={k: v for k, v in pb2_object.labels.items()} if pb2_object.labels is not None else None,
annotations={k: v for k, v in pb2_object.annotations.items()}
if pb2_object.annotations is not None
else None,
k8s_service_account=pb2_object.k8s_service_account,
environment_variables={k: v for k, v in pb2_object.environment_variables.items()}
if pb2_object.environment_variables is not None
else None,
)


class TaskSpec(_common.FlyteIdlEntity):
def __init__(self, template: TaskTemplate, docs: typing.Optional[Documentation] = None):
"""
Expand Down
50 changes: 44 additions & 6 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TaskCategory,
)
from flyteidl.core.execution_pb2 import TaskExecution, TaskLog
from flyteidl.core.identifier_pb2 import ResourceType

from flytekit import PythonFunctionTask, task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
Expand All @@ -37,8 +38,14 @@
)
from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret
from flytekit.models import literals
from flytekit.models.core.identifier import (
Identifier,
NodeExecutionIdentifier,
TaskExecutionIdentifier,
WorkflowExecutionIdentifier,
)
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.task import TaskExecutionMetadata, TaskTemplate
from flytekit.tools.translator import get_serializable

dummy_id = "dummy_id"
Expand All @@ -48,6 +55,7 @@
class DummyMetadata(ResourceMeta):
job_id: str
output_path: typing.Optional[str] = None
task_name: typing.Optional[str] = None


class DummyAgent(AsyncAgentBase):
Expand Down Expand Up @@ -77,10 +85,12 @@ async def create(
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
output_prefix: typing.Optional[str] = None,
task_execution_metadata: typing.Optional[TaskExecutionMetadata] = None,
**kwargs,
) -> DummyMetadata:
output_path = f"{output_prefix}/{dummy_id}" if output_prefix else None
return DummyMetadata(job_id=dummy_id, output_path=output_path)
task_name = task_execution_metadata.task_execution_id.task_id.name if task_execution_metadata else "default"
return DummyMetadata(job_id=dummy_id, output_path=output_path, task_name=task_name)

async def get(self, resource_meta: DummyMetadata, **kwargs) -> Resource:
return Resource(phase=TaskExecution.SUCCEEDED, log_links=[TaskLog(name="console", uri="localhost:3000")])
Expand Down Expand Up @@ -136,6 +146,19 @@ def simple_task(i: int):
},
)

task_execution_metadata = TaskExecutionMetadata(
task_execution_id=TaskExecutionIdentifier(
task_id=Identifier(ResourceType.TASK, "project", "domain", "name", "version"),
node_execution_id=NodeExecutionIdentifier("node_id", WorkflowExecutionIdentifier("project", "domain", "name")),
retry_attempt=1,
),
namespace="namespace",
labels={"label_key": "label_val"},
annotations={"annotation_key": "annotation_val"},
k8s_service_account="k8s service account",
environment_variables={"env_var_key": "env_var_val"},
)


def test_dummy_agent():
AgentRegistry.register(DummyAgent(), override=True)
Expand All @@ -161,20 +184,35 @@ def __init__(self, **kwargs):
t.execute()


@pytest.mark.parametrize("agent", [DummyAgent(), AsyncDummyAgent()], ids=["sync", "async"])
@pytest.mark.parametrize(
"agent,consume_metadata", [(DummyAgent(), False), (AsyncDummyAgent(), True)], ids=["sync", "async"]
)
@pytest.mark.asyncio
async def test_async_agent_service(agent):
async def test_async_agent_service(agent, consume_metadata):
AgentRegistry.register(agent, override=True)
service = AsyncAgentService()
ctx = MagicMock(spec=grpc.ServicerContext)

inputs_proto = task_inputs.to_flyte_idl()
output_prefix = "/tmp"
metadata_bytes = DummyMetadata(job_id=dummy_id, output_path=f"{output_prefix}/{dummy_id}").encode()
metadata_bytes = (
DummyMetadata(
job_id=dummy_id,
output_path=f"{output_prefix}/{dummy_id}",
task_name=task_execution_metadata.task_execution_id.task_id.name,
).encode()
if consume_metadata
else DummyMetadata(job_id=dummy_id).encode()
)

tmp = get_task_template(agent.task_category.name).to_flyte_idl()
task_category = TaskCategory(name=agent.task_category.name, version=0)
req = CreateTaskRequest(inputs=inputs_proto, output_prefix=output_prefix, template=tmp)
req = CreateTaskRequest(
inputs=inputs_proto,
template=tmp,
output_prefix=output_prefix,
task_execution_metadata=task_execution_metadata.to_flyte_idl(),
)

res = await service.CreateTask(req, ctx)
assert res.resource_meta == metadata_bytes
Expand Down

0 comments on commit 133e8d5

Please sign in to comment.