Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - Lazy Evaluation #3003

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 90 additions & 76 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,94 @@ class IgnoreOutputs(Exception):
pass


class Run(object):
def __init__(self, ctx, task, inputs):
self._ctx = ctx
self._task = task
self._inputs = inputs
self._has_run: bool = False
self._outputs: "Literal" = None

@property
def outputs(self) -> "Literal":
return self._outputs

def invoke(self, node_metadata: dict) -> _literal_models.LiteralMap:
if self._has_run:
return self._outputs
task = self._task
try:
literals = translate_inputs_to_literals(
self._ctx,
incoming_values=self._inputs,
flyte_interface_types=task.interface.inputs,
native_types=task.get_input_types(), # type: ignore
)
except TypeTransformerFailedError as exc:
exc.args = (f"Failed to convert inputs of task '{task.name}':\n {exc.args[0]}",)
raise

input_literal_map = _literal_models.LiteralMap(literals=literals)

local_config = LocalConfig.auto()
# TODO: use node metadata here
print("node_metadata", node_metadata)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use node metadata here

if task.metadata.cache and local_config.cache_enabled:
if local_config.cache_overwrite:
outputs_literal_map = None
logger.info("Cache overwrite, task will be executed now")
else:
logger.info(
f"Checking cache for task named {task.name}, cache version {task.metadata.cache_version} "
f", inputs: {self._inputs}, and ignore input vars: {task.metadata.cache_ignore_input_vars}"
)
outputs_literal_map = LocalTaskCache.get(
task.name, task.metadata.cache_version, input_literal_map, task.metadata.cache_ignore_input_vars
)
# The cache returns None iff the key does not exist in the cache
if outputs_literal_map is None:
logger.info("Cache miss, task will be executed now")
else:
logger.info("Cache hit")
if outputs_literal_map is None:
outputs_literal_map = task.sandbox_execute(self._ctx, input_literal_map)
LocalTaskCache.set(
task.name,
task.metadata.cache_version,
input_literal_map,
task.metadata.cache_ignore_input_vars,
outputs_literal_map,
)
logger.info(
f"Cache set for task named {task.name}, cache version {task.metadata.cache_version} "
f", inputs: {self._inputs}, and ignore input vars: {task.metadata.cache_ignore_input_vars}"
)
else:
# This code should mirror the call to `sandbox_execute` in the above cache case.
# Code is simpler with duplication and less metaprogramming, but introduces regressions
# if one is changed and not the other.
outputs_literal_map = task.sandbox_execute(self._ctx, input_literal_map)

self._has_run = True

if inspect.iscoroutine(outputs_literal_map):
self._outputs = outputs_literal_map
return outputs_literal_map

outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
# location, otherwise we dont really need to right? The higher level execute could just handle literalMap
# After running, we again have to wrap the outputs, if any, back into Promise objects
output_names = list(task.interface.outputs.keys()) # type: ignore
if len(output_names) != len(outputs_literals):
# Length check, clean up exception
raise AssertionError(f"Length difference {len(output_names)} {len(outputs_literals)}")

self._outputs = outputs_literal_map
return outputs_literal_map


class Task(object):
"""
The base of all Tasks in flytekit. This task is closest to the FlyteIDL TaskTemplate and captures information in
Expand Down Expand Up @@ -278,86 +366,12 @@ def get_input_types(self) -> Optional[Dict[str, type]]:
def local_execute(
self, ctx: FlyteContext, **kwargs
) -> Union[Tuple[Promise], Promise, VoidPromise, Coroutine, None]:
"""
This function is used only in the local execution path and is responsible for calling dispatch execute.
Use this function when calling a task with native values (or Promises containing Flyte literals derived from
Python native values).
"""
# Unwrap the kwargs values. After this, we essentially have a LiteralMap
# The reason why we need to do this is because the inputs during local execute can be of 2 types
# - Promises or native constants
# Promises as essentially inputs from previous task executions
# native constants are just bound to this specific task (default values for a task input)
# Also along with promises and constants, there could be dictionary or list of promises or constants
try:
literals = translate_inputs_to_literals(
ctx,
incoming_values=kwargs,
flyte_interface_types=self.interface.inputs,
native_types=self.get_input_types(), # type: ignore
)
except TypeTransformerFailedError as exc:
exc.args = (f"Failed to convert inputs of task '{self.name}':\n {exc.args[0]}",)
raise
input_literal_map = _literal_models.LiteralMap(literals=literals)

# if metadata.cache is set, check memoized version
local_config = LocalConfig.auto()
if self.metadata.cache and local_config.cache_enabled:
if local_config.cache_overwrite:
outputs_literal_map = None
logger.info("Cache overwrite, task will be executed now")
else:
logger.info(
f"Checking cache for task named {self.name}, cache version {self.metadata.cache_version} "
f", inputs: {kwargs}, and ignore input vars: {self.metadata.cache_ignore_input_vars}"
)
outputs_literal_map = LocalTaskCache.get(
self.name, self.metadata.cache_version, input_literal_map, self.metadata.cache_ignore_input_vars
)
# The cache returns None iff the key does not exist in the cache
if outputs_literal_map is None:
logger.info("Cache miss, task will be executed now")
else:
logger.info("Cache hit")
if outputs_literal_map is None:
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
# TODO: need `native_inputs`
LocalTaskCache.set(
self.name,
self.metadata.cache_version,
input_literal_map,
self.metadata.cache_ignore_input_vars,
outputs_literal_map,
)
logger.info(
f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} "
f", inputs: {kwargs}, and ignore input vars: {self.metadata.cache_ignore_input_vars}"
)
else:
# This code should mirror the call to `sandbox_execute` in the above cache case.
# Code is simpler with duplication and less metaprogramming, but introduces regressions
# if one is changed and not the other.
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)

if inspect.iscoroutine(outputs_literal_map):
return outputs_literal_map

outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
# location, otherwise we dont really need to right? The higher level execute could just handle literalMap
# After running, we again have to wrap the outputs, if any, back into Promise objects
output_names = list(self.interface.outputs.keys()) # type: ignore
if len(output_names) != len(outputs_literals):
# Length check, clean up exception
raise AssertionError(f"Length difference {len(output_names)} {len(outputs_literals)}")

output_names = list(self.interface.outputs.keys())
# Tasks that don't return anything still return a VoidPromise
if len(output_names) == 0:
return VoidPromise(self.name)

vals = [Promise(var, outputs_literals[var]) for var in output_names]
vals = [Promise(var, Run(ctx, self, kwargs)) for var in output_names]
return create_task_output(vals, self.python_interface)

def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
Expand Down
16 changes: 13 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def t1() -> (int, str): ...
def __init__(
self,
var: str,
val: Union[NodeOutput, _literals_models.Literal],
val: Union[NodeOutput, _literals_models.Literal, "Run"],
type: typing.Optional[_type_models.LiteralType] = None,
):
self._var = var
Expand All @@ -448,6 +448,8 @@ def __init__(
self._ref = None
self._attr_path: List[Union[str, int]] = []
self._type = type
self.node_metadata = None

if val and isinstance(val, NodeOutput):
self._ref = val
self._promise_ready = False
Expand Down Expand Up @@ -487,6 +489,12 @@ def val(self) -> _literals_models.Literal:
"""
If the promise is ready then this holds the actual evaluate value in Flyte's type system
"""
from flytekit.core.base_task import Run

if isinstance(self._val, Run):
# Invoke the task when using it in a workflow
lt_map = self._val.invoke(self.node_metadata)
self._val = lt_map.literals[self._var]
return self._val

@property
Expand All @@ -512,7 +520,7 @@ def attr_path(self) -> List[Union[str, int]]:
return self._attr_path

def eval(self) -> Any:
if not self._promise_ready or self._val is None:
if not self._promise_ready or self.val is None:
raise ValueError("Cannot Eval with incomplete promises")
if self.val.scalar is None or self.val.scalar.primitive is None:
raise ValueError("Eval can be invoked for primitive types only")
Expand Down Expand Up @@ -599,11 +607,13 @@ def with_overrides(
*args,
**kwargs,
)
self.node_metadata = {"cache": cache}
return self

def __repr__(self):
if self._promise_ready:
return f"Resolved({self._var}={self._val})"
return f"Resolved({self._var}={self.val})"

return f"Promise(node:{self.ref.node_id}.{self._var}.{self.attr_path})"

def __str__(self):
Expand Down