diff --git a/.github/workflows/monodocs_build.yml b/.github/workflows/monodocs_build.yml index 7b30ef957d..7f11de452c 100644 --- a/.github/workflows/monodocs_build.yml +++ b/.github/workflows/monodocs_build.yml @@ -38,7 +38,13 @@ jobs: working-directory: ${{ github.workspace }}/flyte run: | conda activate monodocs-env + export SETUPTOOLS_SCM_PRETEND_VERSION="2.0.0" pip install -e ./flyteidl + - shell: bash -el {0} + working-directory: ${{ github.workspace }}/flytekit + run: | + conda activate monodocs-env + pip install -e . conda info conda list conda config --show-sources diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 7cbf380b16..0f1b71068d 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -90,9 +90,6 @@ class Resource: outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None -T = typing.TypeVar("T", bound=ResourceMeta) - - class AgentBase(ABC): name = "Base Agent" @@ -127,7 +124,7 @@ def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs raise NotImplementedError -class AsyncAgentBase(AgentBase, typing.Generic[T]): +class AsyncAgentBase(AgentBase): """ This is the base class for all async agents. It defines the interface that all agents must implement. The agent service is responsible for invoking agents. The propeller will communicate with the agent service @@ -139,7 +136,7 @@ class AsyncAgentBase(AgentBase, typing.Generic[T]): name = "Base Async Agent" - def __init__(self, metadata_type: typing.Type[T], **kwargs): + def __init__(self, metadata_type: ResourceMeta, **kwargs): super().__init__(**kwargs) self._metadata_type = metadata_type @@ -148,14 +145,14 @@ def metadata_type(self) -> ResourceMeta: return self._metadata_type @abstractmethod - def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> T: + def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], **kwargs) -> ResourceMeta: """ Return a resource meta that can be used to get the status of the task. """ raise NotImplementedError @abstractmethod - def get(self, resource_meta: T, **kwargs) -> Resource: + def get(self, resource_meta: ResourceMeta, **kwargs) -> Resource: """ Return the status of the task, and return the outputs in some cases. For example, bigquery job can't write the structured dataset to the output location, so it returns the output literals to the propeller, @@ -164,7 +161,7 @@ def get(self, resource_meta: T, **kwargs) -> Resource: raise NotImplementedError @abstractmethod - def delete(self, resource_meta: T, **kwargs): + def delete(self, resource_meta: ResourceMeta, **kwargs): """ Delete the task. This call should be idempotent. It should raise an error if fails to delete the task. """ @@ -231,9 +228,7 @@ class SyncAgentExecutorMixin: Sending a prompt to ChatGPT and getting a response, or retrieving some metadata from a backend system. """ - T = typing.TypeVar("T", "SyncAgentExecutorMixin", PythonTask) - - def execute(self: T, **kwargs) -> LiteralMap: + def execute(self: PythonTask, **kwargs) -> LiteralMap: from flytekit.tools.translator import get_serializable ctx = FlyteContext.current_context() @@ -250,7 +245,9 @@ def execute(self: T, **kwargs) -> LiteralMap: return TypeEngine.dict_to_literal_map(ctx, resource.outputs) return resource.outputs - async def _do(self: T, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None) -> Resource: + async def _do( + self: PythonTask, agent: SyncAgentBase, template: TaskTemplate, inputs: Dict[str, Any] = None + ) -> Resource: try: ctx = FlyteContext.current_context() literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) @@ -267,12 +264,10 @@ class AsyncAgentExecutorMixin: Asynchronous tasks are tasks that take a long time to complete, such as running a query. """ - T = typing.TypeVar("T", "AsyncAgentExecutorMixin", PythonTask) - _clean_up_task: coroutine = None _agent: AsyncAgentBase = None - def execute(self: T, **kwargs) -> LiteralMap: + def execute(self: PythonTask, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) output_prefix = ctx.file_access.get_random_remote_directory() @@ -301,7 +296,7 @@ def execute(self: T, **kwargs) -> LiteralMap: return resource.outputs async def _create( - self: T, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None + self: PythonTask, task_template: TaskTemplate, output_prefix: str, inputs: Dict[str, Any] = None ) -> ResourceMeta: ctx = FlyteContext.current_context() @@ -322,7 +317,7 @@ async def _create( signal.signal(signal.SIGINT, partial(self.signal_handler, resource_meta)) # type: ignore return resource_meta - async def _get(self: T, resource_meta: ResourceMeta) -> Resource: + async def _get(self: PythonTask, resource_meta: ResourceMeta) -> Resource: phase = TaskExecution.RUNNING progress = Progress(transient=True) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 0275162f72..f6b7cfd6e6 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -31,7 +31,7 @@ class BigQueryMetadata(ResourceMeta): location: str -class BigQueryAgent(AsyncAgentBase[BigQueryMetadata]): +class BigQueryAgent(AsyncAgentBase): name = "Bigquery Agent" def __init__(self):