diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 01d9db29a9..9cdb3f302e 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -74,7 +74,6 @@ jobs: - name: Build and push to local registry run: | docker build . -f Dockerfile.dev -t localhost:30000/flytekit:dev --build-arg PYTHON_VERSION=${{ matrix.python-version }} - docker push localhost:30000/flytekit:dev - name: Integration Test with coverage env: FLYTEKIT_IMAGE: localhost:30000/flytekit:dev diff --git a/flytekit/experimental/eager_function.py b/flytekit/experimental/eager_function.py index 2ade725c4f..264d0d641a 100644 --- a/flytekit/experimental/eager_function.py +++ b/flytekit/experimental/eager_function.py @@ -34,7 +34,7 @@ }} -

{entity_type}: `{entity_name}`

+

{entity_type}: {entity_name}

Execution: @@ -99,13 +99,14 @@ def __init__( async_stack: "AsyncStack", timeout: Optional[timedelta] = None, poll_interval: Optional[timedelta] = None, + local_entrypoint: bool = False, ): self.entity = entity self.ctx = ctx self.async_stack = async_stack self.execution_state = self.ctx.execution_state.mode - # TODO: move this out into the eager wrapper. self.remote = remote + self.local_entrypoint = local_entrypoint if self.remote is not None: logger.debug(f"Using remote config: {self.remote.config}") else: @@ -127,7 +128,7 @@ async def __call__(self, **kwargs): "If you need to use a subworkflow, use a static @workflow or nested @eager workflow." ) - if self.ctx.execution_state.is_local_execution(): + if not self.local_entrypoint and self.ctx.execution_state.is_local_execution(): # If running as a local workflow execution, just execute the python function try: if isinstance(self.entity, WorkflowBase): @@ -137,7 +138,8 @@ async def __call__(self, **kwargs): out = await out return out elif isinstance(self.entity, PythonTask): - out = self.entity._task_function(**kwargs) + # invoke the task-decorated entity + out = self.entity(**kwargs) if inspect.iscoroutine(out): out = await out return out @@ -167,9 +169,14 @@ async def __call__(self, **kwargs): self._execution = execution url = self.remote.generate_console_url(execution) + msg = f"Running flyte {type(self.entity)} {entity_name} on remote cluster: {url}" + if self.local_entrypoint: + logger.info(msg) + else: + logger.debug(msg) + node = AsyncNode(self, entity_name, execution, url) self.async_stack.set_node(node) - logger.debug(url) poll_interval = self._poll_interval or timedelta(seconds=30) time_to_give_up = datetime.max if self._timeout is None else datetime.utcnow() + self._timeout @@ -219,7 +226,16 @@ def __init__(self, async_entity, entity_name, execution=None, url=None): self.entity_name = entity_name self.async_entity = async_entity self.execution = execution - self.url = url + self._url = url + + @property 
+ def url(self) -> str: + # make sure that internal flyte sandbox endpoint is replaced with localhost endpoint when rendering the urls + # for flyte decks + endpoint_root = FLYTE_SANDBOX_INTERNAL_ENDPOINT.replace("http://", "") + if endpoint_root in self._url: + return self._url.replace(endpoint_root, "localhost:30080") + return self._url @property def entity_type(self) -> str: @@ -273,7 +289,7 @@ def get_io(dict_like): except Exception: return dict_like - output = "# Nodes\n\n


" + output = "

Nodes


" for node in async_stack.call_stack: node_inputs = get_io(node.execution.inputs) if node.execution.closure.phase in {WorkflowExecutionPhase.FAILED}: @@ -296,11 +312,12 @@ def get_io(dict_like): @asynccontextmanager async def eager_context( fn, - remote: FlyteRemote, + remote: Optional[FlyteRemote], ctx: FlyteContext, async_stack: AsyncStack, timeout: Optional[timedelta] = None, poll_interval: Optional[timedelta] = None, + local_entrypoint: bool = False, ): """This context manager overrides all tasks in the global namespace with async versions.""" @@ -310,7 +327,7 @@ async def eager_context( for k, v in fn.__globals__.items(): if isinstance(v, (PythonTask, WorkflowBase)): _original_cache[k] = v - fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval) + fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval, local_entrypoint) try: yield @@ -356,6 +373,7 @@ def eager( client_secret_key: Optional[str] = None, timeout: Optional[timedelta] = None, poll_interval: Optional[timedelta] = None, + local_entrypoint: bool = False, **kwargs, ): """Eager workflow decorator. @@ -367,6 +385,9 @@ def eager( workflow to complete or terminate. By default, the eager workflow will wait indefinitely until complete. :param poll_interval: The poll interval for checking if a task/workflow execution within the eager workflow has finished. If not specified, the default poll interval is 6 seconds. + :param local_entrypoint: If True, the eager workflow can be executed locally but will use the provided + :py:func:`~flytekit.remote.FlyteRemote` object to create task/workflow executions. This is useful for local + testing against a remote Flyte cluster. :param kwargs: keyword-arguments forwarded to :py:func:`~flytekit.task`. 
This type of workflow will execute all flyte entities within it eagerly, meaning that all python constructs can be @@ -430,12 +451,7 @@ async def eager_workflow(x: int) -> int: client_secret_key="my_client_secret_key", ) async def eager_workflow(x: int) -> int: - try: - out = await add_one(x) - except EagerException: - # The ValueError error is caught - # and raised as an EagerException - raise + out = await add_one(x) return await double(one) Where ``config.yaml`` contains is a flytectl-compatible config file. @@ -446,10 +462,16 @@ async def eager_workflow(x: int) -> int: .. code-block:: python - @eager(remote=FlyteRemote(config=Config.from_sandbox())) + @eager(remote=FlyteRemote(config=Config.for_sandbox())) async def eager_workflow(x: int) -> int: ... + .. important:: + + When using ``local_entrypoint=True`` you also need to specify the ``remote`` argument. In this case, the eager + workflow runtime will be local, but all task/subworkflow invocations will occur on the specified Flyte cluster. + This argument is primarily used for testing and debugging eager workflow logic locally. 
+ """ if _fn is None: @@ -458,9 +480,13 @@ async def eager_workflow(x: int) -> int: remote=remote, client_secret_group=client_secret_group, client_secret_key=client_secret_key, + local_entrypoint=local_entrypoint, **kwargs, ) + if local_entrypoint and remote is None: + raise ValueError("Must specify remote argument if local_entrypoint is True") + @wraps(_fn) async def wrapper(*args, **kws): # grab the "async_ctx" argument injected by PythonFunctionTask.execute @@ -476,7 +502,7 @@ async def wrapper(*args, **kws): execution_id = exec_params.execution_id async_stack = AsyncStack(task_id, execution_id) - _remote = _prepare_remote(_remote, ctx, client_secret_group, client_secret_key) + _remote = _prepare_remote(_remote, ctx, client_secret_group, client_secret_key, local_entrypoint) # make sure sub-nodes as cleaned up on termination signal loop = asyncio.get_event_loop() @@ -484,9 +510,13 @@ async def wrapper(*args, **kws): cleanup_fn = partial(asyncio.ensure_future, node_cleanup_partial(signal.SIGTERM, loop)) signal.signal(signal.SIGTERM, partial(node_cleanup, loop=loop, async_stack=async_stack)) - async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval): + async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval, local_entrypoint): try: - out = await _fn(*args, **kws) + if _remote is not None: + with _remote.remote_context(): + out = await _fn(*args, **kws) + else: + out = await _fn(*args, **kws) # need to await for _fn to complete, then invoke the deck await render_deck(async_stack) return out @@ -512,13 +542,23 @@ def _prepare_remote( ctx: FlyteContext, client_secret_group: Optional[str] = None, client_secret_key: Optional[str] = None, + local_entrypoint: bool = False, ) -> Optional[FlyteRemote]: """Prepare FlyteRemote object for accessing Flyte cluster in a task running on the same cluster.""" - if remote is None or ctx.execution_state.mode in { + is_local_execution_mode = ctx.execution_state.mode in { 
ExecutionState.Mode.LOCAL_TASK_EXECUTION, ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION, - }: + } + + if remote is not None and local_entrypoint and is_local_execution_mode: + # when running eager workflows as a local entrypoint, we don't have to modify the remote object + # because we can assume that the user is running this from their local machine and can do browser-based + # authentication. + logger.info("Running eager workflow as local entrypoint") + return remote + + if remote is None or is_local_execution_mode: # if running the "eager workflow" (which is actually task) locally, run the task as a function, # which doesn't need a remote object return None diff --git a/tests/flytekit/integration/experimental/__init__.py b/tests/flytekit/integration/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/integration/experimental/eager_workflows.py b/tests/flytekit/integration/experimental/eager_workflows.py index b8bc1c9a1b..2dbc28a640 100644 --- a/tests/flytekit/integration/experimental/eager_workflows.py +++ b/tests/flytekit/integration/experimental/eager_workflows.py @@ -1,11 +1,18 @@ import asyncio +import os import typing from functools import partial +from pathlib import Path + +import pandas as pd from flytekit import task, workflow from flytekit.configuration import Config from flytekit.experimental import EagerException, eager from flytekit.remote import FlyteRemote +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from flytekit.types.structured import StructuredDataset remote = FlyteRemote( config=Config.for_sandbox(), @@ -39,6 +46,29 @@ def raises_exc(x: int) -> int: return x +@task +def create_structured_dataset() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2, 3]}) + return StructuredDataset(dataframe=df) + + +@task +def create_file() -> FlyteFile: + fname = "/tmp/flytekit_test_file" + with open(fname, "w") as fh: + fh.write("some data\n") + return 
FlyteFile(path=fname) + + +@task +def create_directory() -> FlyteDirectory: + dirname = "/tmp/flytekit_test_dir" + Path(dirname).mkdir(exist_ok=True, parents=True) + with open(os.path.join(dirname, "file"), "w") as tmp: + tmp.write("some data\n") + return FlyteDirectory(path=dirname) + + @eager_partial async def simple_eager_wf(x: int) -> int: out = await add_one(x=x) @@ -89,5 +119,36 @@ async def eager_wf_with_subworkflow(x: int) -> int: return await double(x=out) +@eager_partial +async def eager_wf_structured_dataset() -> int: + dataset = await create_structured_dataset() + df = dataset.open(pd.DataFrame).all() + return int(df["a"].sum()) + + +@eager_partial +async def eager_wf_flyte_file() -> str: + file = await create_file() + file.download() + with open(file.path) as f: + data = f.read().strip() + return data + + +@eager_partial +async def eager_wf_flyte_directory() -> str: + directory = await create_directory() + directory.download() + with open(os.path.join(directory.path, "file")) as f: + data = f.read().strip() + return data + + +@eager(remote=remote, local_entrypoint=True) +async def eager_wf_local_entrypoint(x: int) -> int: + out = await add_one(x=x) + return await double(x=out) + + if __name__ == "__main__": print(asyncio.run(simple_eager_wf(x=1))) diff --git a/tests/flytekit/integration/experimental/test_eager_workflows.py b/tests/flytekit/integration/experimental/test_eager_workflows.py index a4b2186cf5..ad1bc44112 100644 --- a/tests/flytekit/integration/experimental/test_eager_workflows.py +++ b/tests/flytekit/integration/experimental/test_eager_workflows.py @@ -20,7 +20,7 @@ ``` """ - +import asyncio import os import subprocess import time @@ -31,6 +31,8 @@ from flytekit.configuration import Config from flytekit.remote import FlyteRemote +from .eager_workflows import eager_wf_local_entrypoint + MODULE = "eager_workflows" MODULE_PATH = Path(__file__).parent / f"{MODULE}.py" CONFIG = os.environ.get("FLYTECTL_CONFIG", str(Path.home() / ".flyte" / 
"config-sandbox.yaml")) @@ -70,6 +72,9 @@ def register(): ("eager", "gather_eager_wf", 1, [2] * 10), ("eager", "nested_eager_wf", 1, 8), ("eager", "eager_wf_with_subworkflow", 1, 4), + ("eager", "eager_wf_structured_dataset", None, 6), + ("eager", "eager_wf_flyte_file", None, "some data"), + ("eager", "eager_wf_flyte_directory", None, "some data"), ("workflow", "wf_with_eager_wf", 1, 8), ], ) @@ -98,5 +103,14 @@ def test_eager_workflows(register, entity_type, entity_name, input, output): if entity is None: raise RuntimeError("failed to fetch entity") - execution = remote.execute(entity, inputs={"x": input}, wait=True) + inputs = {} if input is None else {"x": input} + execution = remote.execute(entity, inputs=inputs, wait=True) assert execution.outputs["o0"] == output + + +@pytest.mark.skipif( + os.environ.get("FLYTEKIT_CI", False), reason="Running workflows with sandbox cluster fails due to memory pressure" +) +def test_eager_workflow_local_entrypoint(register): + result = asyncio.run(eager_wf_local_entrypoint(x=1)) + assert result == 4 diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py index abebf6f292..f588c04db0 100644 --- a/tests/flytekit/unit/experimental/test_eager_workflows.py +++ b/tests/flytekit/unit/experimental/test_eager_workflows.py @@ -1,13 +1,22 @@ import asyncio +import os import typing +from pathlib import Path import hypothesis.strategies as st +import pandas as pd import pytest -from hypothesis import given, infer, settings +from hypothesis import given, settings from flytekit import dynamic, task, workflow from flytekit.core.type_engine import TypeTransformerFailedError from flytekit.experimental import EagerException, eager +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from flytekit.types.structured import StructuredDataset + +DEADLINE = 2000 +INTEGER_ST = st.integers(max_value=10_000_000) @task @@ -38,8 +47,8 @@ def 
dynamic_wf(x: int) -> int: return double(x=out) -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def test_simple_eager_workflow(x_input: int): """Testing simple eager workflow with just tasks.""" @@ -52,8 +61,8 @@ async def eager_wf(x: int) -> int: assert result == (x_input + 1) * 2 -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def test_conditional_eager_workflow(x_input: int): """Test eager workfow with conditional logic.""" @@ -70,8 +79,8 @@ async def eager_wf(x: int) -> int: assert result == 1 -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def test_try_except_eager_workflow(x_input: int): """Test eager workflow with try/except logic.""" @@ -89,8 +98,8 @@ async def eager_wf(x: int) -> int: assert result == x_input -@given(x_input=infer, n_input=st.integers(min_value=1, max_value=20)) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST, n_input=st.integers(min_value=1, max_value=20)) +@settings(deadline=DEADLINE, max_examples=5) def test_gather_eager_workflow(x_input: int, n_input: int): """Test eager workflow with asyncio gather.""" @@ -103,8 +112,8 @@ async def eager_wf(x: int, n: int) -> typing.List[int]: assert results == [x_input + 1 for _ in range(n_input)] -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def test_eager_workflow_with_dynamic_exception(x_input: int): """Test eager workflow with dynamic workflow is not supported.""" @@ -121,8 +130,8 @@ async def nested_eager_wf(x: int) -> int: return await add_one(x=x) -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def 
test_nested_eager_workflow(x_input: int): """Testing running nested eager workflows.""" @@ -135,8 +144,8 @@ async def eager_wf(x: int) -> int: assert result == (x_input + 1) * 2 -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def test_eager_workflow_within_workflow(x_input: int): """Testing running eager workflow within a static workflow.""" @@ -158,8 +167,8 @@ def subworkflow(x: int) -> int: return add_one(x=x) -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def test_workflow_within_eager_workflow(x_input: int): """Testing running a static workflow within an eager workflow.""" @@ -172,10 +181,10 @@ async def eager_wf(x: int) -> int: assert result == (x_input + 1) * 2 -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) def test_local_task_eager_workflow_exception(x_input: int): - """Testing simple eager workflow with just tasks.""" + """Testing simple eager workflow with a local function task doesn't work.""" @task def local_task(x: int) -> int: @@ -189,8 +198,8 @@ async def eager_wf_with_local(x: int) -> int: asyncio.run(eager_wf_with_local(x=x_input)) -@given(x_input=infer) -@settings(deadline=1000, max_examples=5) +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) @pytest.mark.filterwarnings("ignore:coroutine 'AsyncEntity.__call__' was never awaited") def test_local_workflow_within_eager_workflow_exception(x_input: int): """Cannot call a locally-defined workflow within an eager workflow""" @@ -206,3 +215,59 @@ async def eager_wf(x: int) -> int: with pytest.raises(TypeTransformerFailedError): asyncio.run(eager_wf(x=x_input)) + + +@task +def create_structured_dataset() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2, 3]}) + return StructuredDataset(dataframe=df) + + 
+@task +def create_file() -> FlyteFile: + fname = "/tmp/flytekit_test_file" + with open(fname, "w") as fh: + fh.write("some data\n") + return FlyteFile(path=fname) + + +@task +def create_directory() -> FlyteDirectory: + dirname = "/tmp/flytekit_test_dir" + Path(dirname).mkdir(exist_ok=True, parents=True) + with open(os.path.join(dirname, "file"), "w") as tmp: + tmp.write("some data\n") + return FlyteDirectory(path=dirname) + + +def test_eager_workflow_with_offloaded_types(): + """Test that eager workflows work with offloaded types.""" + + @eager + async def eager_wf_structured_dataset() -> int: + dataset = await create_structured_dataset() + df = dataset.open(pd.DataFrame).all() + return df["a"].sum() + + @eager + async def eager_wf_flyte_file() -> str: + file = await create_file() + with open(file.path) as f: + data = f.read().strip() + return data + + @eager + async def eager_wf_flyte_directory() -> str: + directory = await create_directory() + with open(os.path.join(directory.path, "file")) as f: + data = f.read().strip() + return data + + result = asyncio.run(eager_wf_structured_dataset()) + assert result == 6 + + result = asyncio.run(eager_wf_flyte_file()) + assert result == "some data" + + result = asyncio.run(eager_wf_flyte_directory()) + assert result == "some data"