Skip to content

Commit

Permalink
Fix Monodocs build (flyteorg#2235)
Browse files Browse the repository at this point in the history
* Fix Monodocs build

Signed-off-by: Kevin Su <[email protected]>

* SETUPTOOLS_SCM_PRETEND_VERSION=2.0.0

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

* test

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Mar 4, 2024
1 parent f1b5eba commit 4a58a67
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/monodocs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ jobs:
working-directory: ${{ github.workspace }}/flyte
run: |
conda activate monodocs-env
export SETUPTOOLS_SCM_PRETEND_VERSION="2.0.0"
pip install -e ./flyteidl
- shell: bash -el {0}
working-directory: ${{ github.workspace }}/flytekit
run: |
conda activate monodocs-env
pip install -e .
conda info
conda list
conda config --show-sources
Expand Down
29 changes: 12 additions & 17 deletions flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ class Resource:
outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None


T = typing.TypeVar("T", bound=ResourceMeta)


class AgentBase(ABC):
name = "Base Agent"

Expand Down Expand Up @@ -127,7 +124,7 @@ def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs
raise NotImplementedError


class AsyncAgentBase(AgentBase, typing.Generic[T]):
class AsyncAgentBase(AgentBase):
"""
This is the base class for all async agents. It defines the interface that all agents must implement.
The agent service is responsible for invoking agents. The propeller will communicate with the agent service
Expand All @@ -139,7 +136,7 @@ class AsyncAgentBase(AgentBase, typing.Generic[T]):

name = "Base Async Agent"

def __init__(self, metadata_type: typing.Type[T], **kwargs):
def __init__(self, metadata_type: ResourceMeta, **kwargs):
super().__init__(**kwargs)
self._metadata_type = metadata_type

Expand All @@ -148,14 +145,14 @@ def metadata_type(self) -> ResourceMeta:
return self._metadata_type

@abstractmethod
def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> T:
def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> ResourceMeta:
"""
Return a resource meta that can be used to get the status of the task.
"""
raise NotImplementedError

@abstractmethod
def get(self, resource_meta: T, **kwargs) -> Resource:
def get(self, resource_meta: ResourceMeta, **kwargs) -> Resource:
"""
Return the status of the task, and return the outputs in some cases. For example, bigquery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
Expand All @@ -164,7 +161,7 @@ def get(self, resource_meta: T, **kwargs) -> Resource:
raise NotImplementedError

@abstractmethod
def delete(self, resource_meta: T, **kwargs):
def delete(self, resource_meta: ResourceMeta, **kwargs):
"""
Delete the task. This call should be idempotent. It should raise an error if fails to delete the task.
"""
Expand Down Expand Up @@ -231,9 +228,7 @@ class SyncAgentExecutorMixin:
Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system.
"""

T = typing.TypeVar("T", "SyncAgentExecutorMixin", PythonTask)

def execute(self: T, **kwargs) -> LiteralMap:
def execute(self: PythonTask, **kwargs) -> LiteralMap:
from flytekit.tools.translator import get_serializable

ctx = FlyteContext.current_context()
Expand All @@ -250,7 +245,9 @@ def execute(self: T, **kwargs) -> LiteralMap:
return TypeEngine.dict_to_literal_map(ctx, resource.outputs)
return resource.outputs

async def _do(self: T, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None) -> Resource:
async def _do(
self: PythonTask, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None
) -> Resource:
try:
ctx = FlyteContext.current_context()
literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types())
Expand All @@ -267,12 +264,10 @@ class AsyncAgentExecutorMixin:
Asynchronous tasks are tasks that take a long time to complete, such as running a query.
"""

T = typing.TypeVar("T", "AsyncAgentExecutorMixin", PythonTask)

_clean_up_task: coroutine = None
_agent: AsyncAgentBase = None

def execute(self: T, **kwargs) -> LiteralMap:
def execute(self: PythonTask, **kwargs) -> LiteralMap:
ctx = FlyteContext.current_context()
ss = ctx.serialization_settings or SerializationSettings(ImageConfig())
output_prefix = ctx.file_access.get_random_remote_directory()
Expand Down Expand Up @@ -301,7 +296,7 @@ def execute(self: T, **kwargs) -> LiteralMap:
return resource.outputs

async def _create(
self: T, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None
self: PythonTask, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None
) -> ResourceMeta:
ctx = FlyteContext.current_context()

Expand All @@ -322,7 +317,7 @@ async def _create(
signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta)) # type: ignore
return resource_meta

async def _get(self: T, resource_meta: ResourceMeta) -> Resource:
async def _get(self: PythonTask, resource_meta: ResourceMeta) -> Resource:
phase = TaskExecution.RUNNING

progress = Progress(transient=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class BigQueryMetadata(ResourceMeta):
location: str


class BigQueryAgent(AsyncAgentBase[BigQueryMetadata]):
class BigQueryAgent(AsyncAgentBase):
name = "Bigquery Agent"

def __init__(self):
Expand Down

0 comments on commit 4a58a67

Please sign in to comment.