Skip to content
This repository has been archived by the owner on Jul 19, 2024. It is now read-only.

Commit

Permalink
Eager local entrypoint and support for offloaded types (flyteorg#1833)
Browse files Browse the repository at this point in the history
* implement eager workflow local entrypoint, support offloaded types

Signed-off-by: Niels Bantilan <[email protected]>

* wip local entrypoint

Signed-off-by: Niels Bantilan <[email protected]>

* add tests

Signed-off-by: Niels Bantilan <[email protected]>

* add local entrypoint tests

Signed-off-by: Niels Bantilan <[email protected]>

* update eager unit tests, delete test script

Signed-off-by: Niels Bantilan <[email protected]>

* clean up tests

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* remove push step

Signed-off-by: Niels Bantilan <[email protected]>

---------

Signed-off-by: Niels Bantilan <[email protected]>
Signed-off-by: Future Outlier <[email protected]>
  • Loading branch information
cosmicBboy authored and Future Outlier committed Oct 3, 2023
1 parent 1aed4ef commit 16c73c0
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 46 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ jobs:
- name: Build and push to local registry
run: |
docker build . -f Dockerfile.dev -t localhost:30000/flytekit:dev --build-arg PYTHON_VERSION=${{ matrix.python-version }}
docker push localhost:30000/flytekit:dev
- name: Integration Test with coverage
env:
FLYTEKIT_IMAGE: localhost:30000/flytekit:dev
Expand Down
82 changes: 61 additions & 21 deletions flytekit/experimental/eager_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
}}
</style>
<h3>{entity_type}: `{entity_name}`</h3>
<h3>{entity_type}: {entity_name}</h3>
<p>
<strong>Execution:</strong>
Expand Down Expand Up @@ -99,13 +99,14 @@ def __init__(
async_stack: "AsyncStack",
timeout: Optional[timedelta] = None,
poll_interval: Optional[timedelta] = None,
local_entrypoint: bool = False,
):
self.entity = entity
self.ctx = ctx
self.async_stack = async_stack
self.execution_state = self.ctx.execution_state.mode
# TODO: move this out into the eager wrapper.
self.remote = remote
self.local_entrypoint = local_entrypoint
if self.remote is not None:
logger.debug(f"Using remote config: {self.remote.config}")
else:
Expand All @@ -127,7 +128,7 @@ async def __call__(self, **kwargs):
"If you need to use a subworkflow, use a static @workflow or nested @eager workflow."
)

if self.ctx.execution_state.is_local_execution():
if not self.local_entrypoint and self.ctx.execution_state.is_local_execution():
# If running as a local workflow execution, just execute the python function
try:
if isinstance(self.entity, WorkflowBase):
Expand All @@ -137,7 +138,8 @@ async def __call__(self, **kwargs):
out = await out
return out
elif isinstance(self.entity, PythonTask):
out = self.entity._task_function(**kwargs)
# invoke the task-decorated entity
out = self.entity(**kwargs)
if inspect.iscoroutine(out):
out = await out
return out
Expand Down Expand Up @@ -167,9 +169,14 @@ async def __call__(self, **kwargs):
self._execution = execution

url = self.remote.generate_console_url(execution)
msg = f"Running flyte {type(self.entity)} {entity_name} on remote cluster: {url}"
if self.local_entrypoint:
logger.info(msg)
else:
logger.debug(msg)

node = AsyncNode(self, entity_name, execution, url)
self.async_stack.set_node(node)
logger.debug(url)

poll_interval = self._poll_interval or timedelta(seconds=30)
time_to_give_up = datetime.max if self._timeout is None else datetime.utcnow() + self._timeout
Expand Down Expand Up @@ -219,7 +226,16 @@ def __init__(self, async_entity, entity_name, execution=None, url=None):
self.entity_name = entity_name
self.async_entity = async_entity
self.execution = execution
self.url = url
self._url = url

@property
def url(self) -> Optional[str]:
    """Return the execution URL, rewritten so it is reachable from the host.

    The in-cluster Flyte sandbox endpoint (``FLYTE_SANDBOX_INTERNAL_ENDPOINT``
    with its ``http://`` scheme removed) is replaced by ``localhost:30080``
    so that links rendered in Flyte decks can be opened from a browser.

    :return: the rewritten URL, the stored URL unchanged when it does not
        contain the sandbox endpoint, or ``None`` when no URL was supplied
        to the constructor.
    """
    # ``url`` defaults to None in ``__init__``; a substring test against None
    # would raise TypeError, so short-circuit here.
    if self._url is None:
        return None

    endpoint_root = FLYTE_SANDBOX_INTERNAL_ENDPOINT.replace("http://", "")
    # str.replace returns the string unchanged when the substring is absent,
    # so no separate membership check is needed.
    return self._url.replace(endpoint_root, "localhost:30080")

@property
def entity_type(self) -> str:
Expand Down Expand Up @@ -273,7 +289,7 @@ def get_io(dict_like):
except Exception:
return dict_like

output = "# Nodes\n\n<hr>"
output = "<h2>Nodes</h2><hr>"
for node in async_stack.call_stack:
node_inputs = get_io(node.execution.inputs)
if node.execution.closure.phase in {WorkflowExecutionPhase.FAILED}:
Expand All @@ -296,11 +312,12 @@ def get_io(dict_like):
@asynccontextmanager
async def eager_context(
fn,
remote: FlyteRemote,
remote: Optional[FlyteRemote],
ctx: FlyteContext,
async_stack: AsyncStack,
timeout: Optional[timedelta] = None,
poll_interval: Optional[timedelta] = None,
local_entrypoint: bool = False,
):
"""This context manager overrides all tasks in the global namespace with async versions."""

Expand All @@ -310,7 +327,7 @@ async def eager_context(
for k, v in fn.__globals__.items():
if isinstance(v, (PythonTask, WorkflowBase)):
_original_cache[k] = v
fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval)
fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval, local_entrypoint)

try:
yield
Expand Down Expand Up @@ -356,6 +373,7 @@ def eager(
client_secret_key: Optional[str] = None,
timeout: Optional[timedelta] = None,
poll_interval: Optional[timedelta] = None,
local_entrypoint: bool = False,
**kwargs,
):
"""Eager workflow decorator.
Expand All @@ -367,6 +385,9 @@ def eager(
workflow to complete or terminate. By default, the eager workflow will wait indefinitely until complete.
:param poll_interval: The poll interval for checking if a task/workflow execution within the eager workflow has
finished. If not specified, the default poll interval is 30 seconds.
:param local_entrypoint: If True, the eager workflow can be executed locally but uses the provided
:py:func:`~flytekit.remote.FlyteRemote` object to create task/workflow executions. This is useful for local
testing against a remote Flyte cluster.
:param kwargs: keyword-arguments forwarded to :py:func:`~flytekit.task`.
This type of workflow will execute all flyte entities within it eagerly, meaning that all python constructs can be
Expand Down Expand Up @@ -430,12 +451,7 @@ async def eager_workflow(x: int) -> int:
client_secret_key="my_client_secret_key",
)
async def eager_workflow(x: int) -> int:
try:
out = await add_one(x)
except EagerException:
# The ValueError error is caught
# and raised as an EagerException
raise
out = await add_one(x)
return await double(out)
Where ``config.yaml`` is a flytectl-compatible config file.
Expand All @@ -446,10 +462,16 @@ async def eager_workflow(x: int) -> int:
.. code-block:: python
@eager(remote=FlyteRemote(config=Config.from_sandbox()))
@eager(remote=FlyteRemote(config=Config.for_sandbox()))
async def eager_workflow(x: int) -> int:
...
.. important::
When using ``local_entrypoint=True`` you also need to specify the ``remote`` argument. In this case, the eager
workflow runtime will be local, but all task/subworkflow invocations will occur on the specified Flyte cluster.
This argument is primarily used for testing and debugging eager workflow logic locally.
"""

if _fn is None:
Expand All @@ -458,9 +480,13 @@ async def eager_workflow(x: int) -> int:
remote=remote,
client_secret_group=client_secret_group,
client_secret_key=client_secret_key,
local_entrypoint=local_entrypoint,
**kwargs,
)

if local_entrypoint and remote is None:
raise ValueError("Must specify remote argument if local_entrypoint is True")

@wraps(_fn)
async def wrapper(*args, **kws):
# grab the "async_ctx" argument injected by PythonFunctionTask.execute
Expand All @@ -476,17 +502,21 @@ async def wrapper(*args, **kws):
execution_id = exec_params.execution_id

async_stack = AsyncStack(task_id, execution_id)
_remote = _prepare_remote(_remote, ctx, client_secret_group, client_secret_key)
_remote = _prepare_remote(_remote, ctx, client_secret_group, client_secret_key, local_entrypoint)

# make sure sub-nodes are cleaned up on termination signal
loop = asyncio.get_event_loop()
node_cleanup_partial = partial(node_cleanup_async, async_stack=async_stack)
cleanup_fn = partial(asyncio.ensure_future, node_cleanup_partial(signal.SIGTERM, loop))
signal.signal(signal.SIGTERM, partial(node_cleanup, loop=loop, async_stack=async_stack))

async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval):
async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval, local_entrypoint):
try:
out = await _fn(*args, **kws)
if _remote is not None:
with _remote.remote_context():
out = await _fn(*args, **kws)
else:
out = await _fn(*args, **kws)
# need to await for _fn to complete, then invoke the deck
await render_deck(async_stack)
return out
Expand All @@ -512,13 +542,23 @@ def _prepare_remote(
ctx: FlyteContext,
client_secret_group: Optional[str] = None,
client_secret_key: Optional[str] = None,
local_entrypoint: bool = False,
) -> Optional[FlyteRemote]:
"""Prepare FlyteRemote object for accessing Flyte cluster in a task running on the same cluster."""

if remote is None or ctx.execution_state.mode in {
is_local_execution_mode = ctx.execution_state.mode in {
ExecutionState.Mode.LOCAL_TASK_EXECUTION,
ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION,
}:
}

if remote is not None and local_entrypoint and is_local_execution_mode:
# when running eager workflows as a local entrypoint, we don't have to modify the remote object
# because we can assume that the user is running this from their local machine and can do browser-based
# authentication.
logger.info("Running eager workflow as local entrypoint")
return remote

if remote is None or is_local_execution_mode:
# if running the "eager workflow" (which is actually a task) locally, run the task as a function,
# which doesn't need a remote object
return None
Expand Down
Empty file.
61 changes: 61 additions & 0 deletions tests/flytekit/integration/experimental/eager_workflows.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import asyncio
import os
import typing
from functools import partial
from pathlib import Path

import pandas as pd

from flytekit import task, workflow
from flytekit.configuration import Config
from flytekit.experimental import EagerException, eager
from flytekit.remote import FlyteRemote
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset

remote = FlyteRemote(
config=Config.for_sandbox(),
Expand Down Expand Up @@ -39,6 +46,29 @@ def raises_exc(x: int) -> int:
return x


@task
def create_structured_dataset() -> StructuredDataset:
    """Wrap a small single-column dataframe (``a = [1, 2, 3]``) in a StructuredDataset."""
    frame = pd.DataFrame({"a": [1, 2, 3]})
    return StructuredDataset(dataframe=frame)


@task
def create_file() -> FlyteFile:
    """Write a one-line text file under ``/tmp`` and return it as a FlyteFile."""
    target = "/tmp/flytekit_test_file"
    with open(target, "w") as handle:
        handle.write("some data\n")
    return FlyteFile(path=target)


@task
def create_directory() -> FlyteDirectory:
    """Create a directory under ``/tmp`` holding one text file and return it as a FlyteDirectory."""
    dirname = "/tmp/flytekit_test_dir"
    # Create the directory (and any missing parents); tolerate it already existing.
    Path(dirname).mkdir(parents=True, exist_ok=True)
    payload_path = os.path.join(dirname, "file")
    with open(payload_path, "w") as handle:
        handle.write("some data\n")
    return FlyteDirectory(path=dirname)


@eager_partial
async def simple_eager_wf(x: int) -> int:
out = await add_one(x=x)
Expand Down Expand Up @@ -89,5 +119,36 @@ async def eager_wf_with_subworkflow(x: int) -> int:
return await double(x=out)


@eager_partial
async def eager_wf_structured_dataset() -> int:
    """Await ``create_structured_dataset``, load it as a pandas DataFrame and sum column ``a``."""
    dataset = await create_structured_dataset()
    frame = dataset.open(pd.DataFrame).all()
    column_total = frame["a"].sum()
    return int(column_total)


@eager_partial
async def eager_wf_flyte_file() -> str:
    """Await ``create_file``, download the result and return its stripped contents."""
    flyte_file = await create_file()
    flyte_file.download()
    with open(flyte_file.path) as handle:
        contents = handle.read()
    return contents.strip()


@eager_partial
async def eager_wf_flyte_directory() -> str:
    """Await ``create_directory``, download it and return the stripped contents of its ``file`` entry."""
    flyte_dir = await create_directory()
    flyte_dir.download()
    entry_path = os.path.join(flyte_dir.path, "file")
    with open(entry_path) as handle:
        contents = handle.read()
    return contents.strip()


@eager(remote=remote, local_entrypoint=True)
async def eager_wf_local_entrypoint(x: int) -> int:
    """Eager workflow declared with ``local_entrypoint=True``: awaits ``add_one`` then ``double``."""
    incremented = await add_one(x=x)
    doubled = await double(x=incremented)
    return doubled


if __name__ == "__main__":
    # Script entry point: run the simple eager workflow once and print its result.
    result = asyncio.run(simple_eager_wf(x=1))
    print(result)
18 changes: 16 additions & 2 deletions tests/flytekit/integration/experimental/test_eager_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
```
"""


import asyncio
import os
import subprocess
import time
Expand All @@ -31,6 +31,8 @@
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

from .eager_workflows import eager_wf_local_entrypoint

MODULE = "eager_workflows"
MODULE_PATH = Path(__file__).parent / f"{MODULE}.py"
CONFIG = os.environ.get("FLYTECTL_CONFIG", str(Path.home() / ".flyte" / "config-sandbox.yaml"))
Expand Down Expand Up @@ -70,6 +72,9 @@ def register():
("eager", "gather_eager_wf", 1, [2] * 10),
("eager", "nested_eager_wf", 1, 8),
("eager", "eager_wf_with_subworkflow", 1, 4),
("eager", "eager_wf_structured_dataset", None, 6),
("eager", "eager_wf_flyte_file", None, "some data"),
("eager", "eager_wf_flyte_directory", None, "some data"),
("workflow", "wf_with_eager_wf", 1, 8),
],
)
Expand Down Expand Up @@ -98,5 +103,14 @@ def test_eager_workflows(register, entity_type, entity_name, input, output):
if entity is None:
raise RuntimeError("failed to fetch entity")

execution = remote.execute(entity, inputs={"x": input}, wait=True)
inputs = {} if input is None else {"x": input}
execution = remote.execute(entity, inputs=inputs, wait=True)
assert execution.outputs["o0"] == output


@pytest.mark.skipif(
    os.environ.get("FLYTEKIT_CI", False), reason="Running workflows with sandbox cluster fails due to memory pressure"
)
def test_eager_workflow_local_entrypoint(register):
    """Run the local-entrypoint eager workflow with ``x=1`` and expect ``4``."""
    coroutine = eager_wf_local_entrypoint(x=1)
    assert asyncio.run(coroutine) == 4
Loading

0 comments on commit 16c73c0

Please sign in to comment.