diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml
index 01d9db29a9..9cdb3f302e 100644
--- a/.github/workflows/pythonbuild.yml
+++ b/.github/workflows/pythonbuild.yml
@@ -74,7 +74,6 @@ jobs:
- name: Build and push to local registry
run: |
docker build . -f Dockerfile.dev -t localhost:30000/flytekit:dev --build-arg PYTHON_VERSION=${{ matrix.python-version }}
- docker push localhost:30000/flytekit:dev
- name: Integration Test with coverage
env:
FLYTEKIT_IMAGE: localhost:30000/flytekit:dev
diff --git a/flytekit/experimental/eager_function.py b/flytekit/experimental/eager_function.py
index 2ade725c4f..264d0d641a 100644
--- a/flytekit/experimental/eager_function.py
+++ b/flytekit/experimental/eager_function.py
@@ -34,7 +34,7 @@
}}
-
{entity_type}: `{entity_name}`
+{entity_type}: {entity_name}
Execution:
@@ -99,13 +99,14 @@ def __init__(
async_stack: "AsyncStack",
timeout: Optional[timedelta] = None,
poll_interval: Optional[timedelta] = None,
+ local_entrypoint: bool = False,
):
self.entity = entity
self.ctx = ctx
self.async_stack = async_stack
self.execution_state = self.ctx.execution_state.mode
- # TODO: move this out into the eager wrapper.
self.remote = remote
+ self.local_entrypoint = local_entrypoint
if self.remote is not None:
logger.debug(f"Using remote config: {self.remote.config}")
else:
@@ -127,7 +128,7 @@ async def __call__(self, **kwargs):
"If you need to use a subworkflow, use a static @workflow or nested @eager workflow."
)
- if self.ctx.execution_state.is_local_execution():
+ if not self.local_entrypoint and self.ctx.execution_state.is_local_execution():
# If running as a local workflow execution, just execute the python function
try:
if isinstance(self.entity, WorkflowBase):
@@ -137,7 +138,8 @@ async def __call__(self, **kwargs):
out = await out
return out
elif isinstance(self.entity, PythonTask):
- out = self.entity._task_function(**kwargs)
+ # invoke the task-decorated entity
+ out = self.entity(**kwargs)
if inspect.iscoroutine(out):
out = await out
return out
@@ -167,9 +169,14 @@ async def __call__(self, **kwargs):
self._execution = execution
url = self.remote.generate_console_url(execution)
+ msg = f"Running flyte {type(self.entity)} {entity_name} on remote cluster: {url}"
+ if self.local_entrypoint:
+ logger.info(msg)
+ else:
+ logger.debug(msg)
+
node = AsyncNode(self, entity_name, execution, url)
self.async_stack.set_node(node)
- logger.debug(url)
poll_interval = self._poll_interval or timedelta(seconds=30)
time_to_give_up = datetime.max if self._timeout is None else datetime.utcnow() + self._timeout
@@ -219,7 +226,16 @@ def __init__(self, async_entity, entity_name, execution=None, url=None):
self.entity_name = entity_name
self.async_entity = async_entity
self.execution = execution
- self.url = url
+ self._url = url
+
+ @property
+ def url(self) -> str:
+ # make sure that internal flyte sandbox endpoint is replaced with localhost endpoint when rendering the urls
+ # for flyte decks
+ endpoint_root = FLYTE_SANDBOX_INTERNAL_ENDPOINT.replace("http://", "")
+ if endpoint_root in self._url:
+ return self._url.replace(endpoint_root, "localhost:30080")
+ return self._url
@property
def entity_type(self) -> str:
@@ -273,7 +289,7 @@ def get_io(dict_like):
except Exception:
return dict_like
- output = "# Nodes\n\n
"
+ output = "Nodes
"
for node in async_stack.call_stack:
node_inputs = get_io(node.execution.inputs)
if node.execution.closure.phase in {WorkflowExecutionPhase.FAILED}:
@@ -296,11 +312,12 @@ def get_io(dict_like):
@asynccontextmanager
async def eager_context(
fn,
- remote: FlyteRemote,
+ remote: Optional[FlyteRemote],
ctx: FlyteContext,
async_stack: AsyncStack,
timeout: Optional[timedelta] = None,
poll_interval: Optional[timedelta] = None,
+ local_entrypoint: bool = False,
):
"""This context manager overrides all tasks in the global namespace with async versions."""
@@ -310,7 +327,7 @@ async def eager_context(
for k, v in fn.__globals__.items():
if isinstance(v, (PythonTask, WorkflowBase)):
_original_cache[k] = v
- fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval)
+ fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval, local_entrypoint)
try:
yield
@@ -356,6 +373,7 @@ def eager(
client_secret_key: Optional[str] = None,
timeout: Optional[timedelta] = None,
poll_interval: Optional[timedelta] = None,
+ local_entrypoint: bool = False,
**kwargs,
):
"""Eager workflow decorator.
@@ -367,6 +385,9 @@ def eager(
workflow to complete or terminate. By default, the eager workflow will wait indefinitely until complete.
:param poll_interval: The poll interval for checking if a task/workflow execution within the eager workflow has
finished. If not specified, the default poll interval is 6 seconds.
+ :param local_entrypoint: If True, the eager workflow will can be executed locally but use the provided
+ :py:func:`~flytekit.remote.FlyteRemote` object to create task/workflow executions. This is useful for local
+ testing against a remote Flyte cluster.
:param kwargs: keyword-arguments forwarded to :py:func:`~flytekit.task`.
This type of workflow will execute all flyte entities within it eagerly, meaning that all python constructs can be
@@ -430,12 +451,7 @@ async def eager_workflow(x: int) -> int:
client_secret_key="my_client_secret_key",
)
async def eager_workflow(x: int) -> int:
- try:
- out = await add_one(x)
- except EagerException:
- # The ValueError error is caught
- # and raised as an EagerException
- raise
+ out = await add_one(x)
return await double(one)
Where ``config.yaml`` contains is a flytectl-compatible config file.
@@ -446,10 +462,16 @@ async def eager_workflow(x: int) -> int:
.. code-block:: python
- @eager(remote=FlyteRemote(config=Config.from_sandbox()))
+ @eager(remote=FlyteRemote(config=Config.for_sandbox()))
async def eager_workflow(x: int) -> int:
...
+ .. important::
+
+ When using ``local_entrypoint=True`` you also need to specify the ``remote`` argument. In this case, the eager
+ workflow runtime will be local, but all task/subworkflow invocations will occur on the specified Flyte cluster.
+ This argument is primarily used for testing and debugging eager workflow logic locally.
+
"""
if _fn is None:
@@ -458,9 +480,13 @@ async def eager_workflow(x: int) -> int:
remote=remote,
client_secret_group=client_secret_group,
client_secret_key=client_secret_key,
+ local_entrypoint=local_entrypoint,
**kwargs,
)
+ if local_entrypoint and remote is None:
+ raise ValueError("Must specify remote argument if local_entrypoint is True")
+
@wraps(_fn)
async def wrapper(*args, **kws):
# grab the "async_ctx" argument injected by PythonFunctionTask.execute
@@ -476,7 +502,7 @@ async def wrapper(*args, **kws):
execution_id = exec_params.execution_id
async_stack = AsyncStack(task_id, execution_id)
- _remote = _prepare_remote(_remote, ctx, client_secret_group, client_secret_key)
+ _remote = _prepare_remote(_remote, ctx, client_secret_group, client_secret_key, local_entrypoint)
# make sure sub-nodes as cleaned up on termination signal
loop = asyncio.get_event_loop()
@@ -484,9 +510,13 @@ async def wrapper(*args, **kws):
cleanup_fn = partial(asyncio.ensure_future, node_cleanup_partial(signal.SIGTERM, loop))
signal.signal(signal.SIGTERM, partial(node_cleanup, loop=loop, async_stack=async_stack))
- async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval):
+ async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval, local_entrypoint):
try:
- out = await _fn(*args, **kws)
+ if _remote is not None:
+ with _remote.remote_context():
+ out = await _fn(*args, **kws)
+ else:
+ out = await _fn(*args, **kws)
# need to await for _fn to complete, then invoke the deck
await render_deck(async_stack)
return out
@@ -512,13 +542,23 @@ def _prepare_remote(
ctx: FlyteContext,
client_secret_group: Optional[str] = None,
client_secret_key: Optional[str] = None,
+ local_entrypoint: bool = False,
) -> Optional[FlyteRemote]:
"""Prepare FlyteRemote object for accessing Flyte cluster in a task running on the same cluster."""
- if remote is None or ctx.execution_state.mode in {
+ is_local_execution_mode = ctx.execution_state.mode in {
ExecutionState.Mode.LOCAL_TASK_EXECUTION,
ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION,
- }:
+ }
+
+ if remote is not None and local_entrypoint and is_local_execution_mode:
+ # when running eager workflows as a local entrypoint, we don't have to modify the remote object
+ # because we can assume that the user is running this from their local machine and can do browser-based
+ # authentication.
+ logger.info("Running eager workflow as local entrypoint")
+ return remote
+
+ if remote is None or is_local_execution_mode:
# if running the "eager workflow" (which is actually task) locally, run the task as a function,
# which doesn't need a remote object
return None
diff --git a/tests/flytekit/integration/experimental/__init__.py b/tests/flytekit/integration/experimental/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/integration/experimental/eager_workflows.py b/tests/flytekit/integration/experimental/eager_workflows.py
index b8bc1c9a1b..2dbc28a640 100644
--- a/tests/flytekit/integration/experimental/eager_workflows.py
+++ b/tests/flytekit/integration/experimental/eager_workflows.py
@@ -1,11 +1,18 @@
import asyncio
+import os
import typing
from functools import partial
+from pathlib import Path
+
+import pandas as pd
from flytekit import task, workflow
from flytekit.configuration import Config
from flytekit.experimental import EagerException, eager
from flytekit.remote import FlyteRemote
+from flytekit.types.directory import FlyteDirectory
+from flytekit.types.file import FlyteFile
+from flytekit.types.structured import StructuredDataset
remote = FlyteRemote(
config=Config.for_sandbox(),
@@ -39,6 +46,29 @@ def raises_exc(x: int) -> int:
return x
+@task
+def create_structured_dataset() -> StructuredDataset:
+ df = pd.DataFrame({"a": [1, 2, 3]})
+ return StructuredDataset(dataframe=df)
+
+
+@task
+def create_file() -> FlyteFile:
+ fname = "/tmp/flytekit_test_file"
+ with open(fname, "w") as fh:
+ fh.write("some data\n")
+ return FlyteFile(path=fname)
+
+
+@task
+def create_directory() -> FlyteDirectory:
+ dirname = "/tmp/flytekit_test_dir"
+ Path(dirname).mkdir(exist_ok=True, parents=True)
+ with open(os.path.join(dirname, "file"), "w") as tmp:
+ tmp.write("some data\n")
+ return FlyteDirectory(path=dirname)
+
+
@eager_partial
async def simple_eager_wf(x: int) -> int:
out = await add_one(x=x)
@@ -89,5 +119,36 @@ async def eager_wf_with_subworkflow(x: int) -> int:
return await double(x=out)
+@eager_partial
+async def eager_wf_structured_dataset() -> int:
+ dataset = await create_structured_dataset()
+ df = dataset.open(pd.DataFrame).all()
+ return int(df["a"].sum())
+
+
+@eager_partial
+async def eager_wf_flyte_file() -> str:
+ file = await create_file()
+ file.download()
+ with open(file.path) as f:
+ data = f.read().strip()
+ return data
+
+
+@eager_partial
+async def eager_wf_flyte_directory() -> str:
+ directory = await create_directory()
+ directory.download()
+ with open(os.path.join(directory.path, "file")) as f:
+ data = f.read().strip()
+ return data
+
+
+@eager(remote=remote, local_entrypoint=True)
+async def eager_wf_local_entrypoint(x: int) -> int:
+ out = await add_one(x=x)
+ return await double(x=out)
+
+
if __name__ == "__main__":
print(asyncio.run(simple_eager_wf(x=1)))
diff --git a/tests/flytekit/integration/experimental/test_eager_workflows.py b/tests/flytekit/integration/experimental/test_eager_workflows.py
index a4b2186cf5..ad1bc44112 100644
--- a/tests/flytekit/integration/experimental/test_eager_workflows.py
+++ b/tests/flytekit/integration/experimental/test_eager_workflows.py
@@ -20,7 +20,7 @@
```
"""
-
+import asyncio
import os
import subprocess
import time
@@ -31,6 +31,8 @@
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote
+from .eager_workflows import eager_wf_local_entrypoint
+
MODULE = "eager_workflows"
MODULE_PATH = Path(__file__).parent / f"{MODULE}.py"
CONFIG = os.environ.get("FLYTECTL_CONFIG", str(Path.home() / ".flyte" / "config-sandbox.yaml"))
@@ -70,6 +72,9 @@ def register():
("eager", "gather_eager_wf", 1, [2] * 10),
("eager", "nested_eager_wf", 1, 8),
("eager", "eager_wf_with_subworkflow", 1, 4),
+ ("eager", "eager_wf_structured_dataset", None, 6),
+ ("eager", "eager_wf_flyte_file", None, "some data"),
+ ("eager", "eager_wf_flyte_directory", None, "some data"),
("workflow", "wf_with_eager_wf", 1, 8),
],
)
@@ -98,5 +103,14 @@ def test_eager_workflows(register, entity_type, entity_name, input, output):
if entity is None:
raise RuntimeError("failed to fetch entity")
- execution = remote.execute(entity, inputs={"x": input}, wait=True)
+ inputs = {} if input is None else {"x": input}
+ execution = remote.execute(entity, inputs=inputs, wait=True)
assert execution.outputs["o0"] == output
+
+
+@pytest.mark.skipif(
+ os.environ.get("FLYTEKIT_CI", False), reason="Running workflows with sandbox cluster fails due to memory pressure"
+)
+def test_eager_workflow_local_entrypoint(register):
+ result = asyncio.run(eager_wf_local_entrypoint(x=1))
+ assert result == 4
diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py
index abebf6f292..f588c04db0 100644
--- a/tests/flytekit/unit/experimental/test_eager_workflows.py
+++ b/tests/flytekit/unit/experimental/test_eager_workflows.py
@@ -1,13 +1,22 @@
import asyncio
+import os
import typing
+from pathlib import Path
import hypothesis.strategies as st
+import pandas as pd
import pytest
-from hypothesis import given, infer, settings
+from hypothesis import given, settings
from flytekit import dynamic, task, workflow
from flytekit.core.type_engine import TypeTransformerFailedError
from flytekit.experimental import EagerException, eager
+from flytekit.types.directory import FlyteDirectory
+from flytekit.types.file import FlyteFile
+from flytekit.types.structured import StructuredDataset
+
+DEADLINE = 2000
+INTEGER_ST = st.integers(max_value=10_000_000)
@task
@@ -38,8 +47,8 @@ def dynamic_wf(x: int) -> int:
return double(x=out)
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_simple_eager_workflow(x_input: int):
"""Testing simple eager workflow with just tasks."""
@@ -52,8 +61,8 @@ async def eager_wf(x: int) -> int:
assert result == (x_input + 1) * 2
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_conditional_eager_workflow(x_input: int):
"""Test eager workfow with conditional logic."""
@@ -70,8 +79,8 @@ async def eager_wf(x: int) -> int:
assert result == 1
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_try_except_eager_workflow(x_input: int):
"""Test eager workflow with try/except logic."""
@@ -89,8 +98,8 @@ async def eager_wf(x: int) -> int:
assert result == x_input
-@given(x_input=infer, n_input=st.integers(min_value=1, max_value=20))
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST, n_input=st.integers(min_value=1, max_value=20))
+@settings(deadline=DEADLINE, max_examples=5)
def test_gather_eager_workflow(x_input: int, n_input: int):
"""Test eager workflow with asyncio gather."""
@@ -103,8 +112,8 @@ async def eager_wf(x: int, n: int) -> typing.List[int]:
assert results == [x_input + 1 for _ in range(n_input)]
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_eager_workflow_with_dynamic_exception(x_input: int):
"""Test eager workflow with dynamic workflow is not supported."""
@@ -121,8 +130,8 @@ async def nested_eager_wf(x: int) -> int:
return await add_one(x=x)
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_nested_eager_workflow(x_input: int):
"""Testing running nested eager workflows."""
@@ -135,8 +144,8 @@ async def eager_wf(x: int) -> int:
assert result == (x_input + 1) * 2
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_eager_workflow_within_workflow(x_input: int):
"""Testing running eager workflow within a static workflow."""
@@ -158,8 +167,8 @@ def subworkflow(x: int) -> int:
return add_one(x=x)
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_workflow_within_eager_workflow(x_input: int):
"""Testing running a static workflow within an eager workflow."""
@@ -172,10 +181,10 @@ async def eager_wf(x: int) -> int:
assert result == (x_input + 1) * 2
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
def test_local_task_eager_workflow_exception(x_input: int):
- """Testing simple eager workflow with just tasks."""
+ """Testing simple eager workflow with a local function task doesn't work."""
@task
def local_task(x: int) -> int:
@@ -189,8 +198,8 @@ async def eager_wf_with_local(x: int) -> int:
asyncio.run(eager_wf_with_local(x=x_input))
-@given(x_input=infer)
-@settings(deadline=1000, max_examples=5)
+@given(x_input=INTEGER_ST)
+@settings(deadline=DEADLINE, max_examples=5)
@pytest.mark.filterwarnings("ignore:coroutine 'AsyncEntity.__call__' was never awaited")
def test_local_workflow_within_eager_workflow_exception(x_input: int):
"""Cannot call a locally-defined workflow within an eager workflow"""
@@ -206,3 +215,59 @@ async def eager_wf(x: int) -> int:
with pytest.raises(TypeTransformerFailedError):
asyncio.run(eager_wf(x=x_input))
+
+
+@task
+def create_structured_dataset() -> StructuredDataset:
+ df = pd.DataFrame({"a": [1, 2, 3]})
+ return StructuredDataset(dataframe=df)
+
+
+@task
+def create_file() -> FlyteFile:
+ fname = "/tmp/flytekit_test_file"
+ with open(fname, "w") as fh:
+ fh.write("some data\n")
+ return FlyteFile(path=fname)
+
+
+@task
+def create_directory() -> FlyteDirectory:
+ dirname = "/tmp/flytekit_test_dir"
+ Path(dirname).mkdir(exist_ok=True, parents=True)
+ with open(os.path.join(dirname, "file"), "w") as tmp:
+ tmp.write("some data\n")
+ return FlyteDirectory(path=dirname)
+
+
+def test_eager_workflow_with_offloaded_types():
+ """Test eager workflow that eager workflows work with offloaded types."""
+
+ @eager
+ async def eager_wf_structured_dataset() -> int:
+ dataset = await create_structured_dataset()
+ df = dataset.open(pd.DataFrame).all()
+ return df["a"].sum()
+
+ @eager
+ async def eager_wf_flyte_file() -> str:
+ file = await create_file()
+ with open(file.path) as f:
+ data = f.read().strip()
+ return data
+
+ @eager
+ async def eager_wf_flyte_directory() -> str:
+ directory = await create_directory()
+ with open(os.path.join(directory.path, "file")) as f:
+ data = f.read().strip()
+ return data
+
+ result = asyncio.run(eager_wf_structured_dataset())
+ assert result == 6
+
+ result = asyncio.run(eager_wf_flyte_file())
+ assert result == "some data"
+
+ result = asyncio.run(eager_wf_flyte_directory())
+ assert result == "some data"