Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

When using the task and workflow decorator, correctly wrap the fu… #780

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime as _datetime
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Type, Union

from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
Expand Down Expand Up @@ -195,7 +196,7 @@ def wrapper(fn) -> PythonFunctionTask:
execution_mode=execution_mode,
task_resolver=task_resolver,
)

update_wrapper(task_instance, fn)
return task_instance

if _task_function:
Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from flytekit.common import constants as _common_constants
Expand Down Expand Up @@ -730,6 +731,7 @@ def wrapper(fn):
docstring=Docstring(callable_=fn),
)
workflow_instance.compile()
update_wrapper(workflow_instance, fn)
return workflow_instance

if _workflow_function:
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-greatexpectations/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def valid_wf(dataset: str = "yellow_tripdata_sample_2019-01.csv") -> int:
task_object(dataset=dataset)
return my_task(csv_file=dataset)

@pytest.mark.xfail(strict=True)
@workflow
def invalid_wf(dataset: str = "yellow_tripdata_sample_2019-02.csv") -> int:
task_object(dataset=dataset)
Expand All @@ -158,7 +157,8 @@ def invalid_wf(dataset: str = "yellow_tripdata_sample_2019-02.csv") -> int:
valid_result = valid_wf()
assert valid_result == 10000

invalid_wf()
with pytest.raises(ValidationError, match=r".*passenger_count -> expect_column_min_to_be_between.*"):
invalid_wf()


def test_ge_workflow():
Expand Down
58 changes: 58 additions & 0 deletions tests/flytekit/unit/core/test_wrapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from functools import wraps

from flytekit import task, workflow


def test_task_correctly_wrapped():
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
@task
def my_task(a: int) -> int:
return a

assert my_task.__wrapped__ == my_task._task_function


def test_stacked_decorators():
def task_decorator_1(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
print("running task_decorator_1")
return fn(*args, **kwargs)

return wrapper

def task_decorator_2(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
print("running task_decorator_2")
return fn(*args, **kwargs)

return wrapper

def task_decorator_3(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
print("running task_decorator_3")
return fn(*args, **kwargs)

return wrapper

@task
@task_decorator_1
@task_decorator_2
@task_decorator_3
def my_task(x: int) -> int:
"""Some function doc"""
print("running my_task")
return x + 1

assert my_task.__wrapped__.__doc__ == "Some function doc"
assert my_task.__wrapped__ == my_task._task_function
assert my_task(x=10) == 11


def test_wf_correctly_wrapped():
@workflow
def my_workflow(a: int) -> int:
return a

assert my_workflow.__wrapped__ == my_workflow._workflow_function