Skip to content

Commit

Permalink
Improved type-hints for stage and source decorators (nv-morpheus#1831)
Browse files Browse the repository at this point in the history
* Improved type hints for both the input parameters and the return types

Closes  nv-morpheus#1812

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: nv-morpheus#1831
  • Loading branch information
dagardner-nv authored Aug 7, 2024
1 parent ad9249c commit 929531f
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions morpheus/pipeline/stage_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
from morpheus.messages import MultiMessage

logger = logging.getLogger(__name__)
GeneratorType = typing.Callable[..., collections.abc.Iterator[typing.Any]]

_InputT = typing.TypeVar('_InputT')
_OutputT = typing.TypeVar('_OutputT')
_P = typing.ParamSpec('_P')

GeneratorType = typing.Callable[_P, collections.abc.Iterator[_OutputT]]
ComputeSchemaType = typing.Callable[[_pipeline.StageSchema], None]


Expand Down Expand Up @@ -134,7 +139,12 @@ class PreAllocatedWrappedFunctionStage(_pipeline.PreallocatorMixin, WrappedFunct
"""


def source(gen_fn: GeneratorType = None, *, name: str = None, compute_schema_fn: ComputeSchemaType = None):
def source(
gen_fn: GeneratorType = None,
*,
name: str = None,
compute_schema_fn: ComputeSchemaType = None
) -> typing.Callable[typing.Concatenate[Config, _P], WrappedFunctionSourceStage]:
"""
Decorator for wrapping a function as a source stage. The function must be a generator method, and provide a
provide a return type annotation.
Expand Down Expand Up @@ -162,7 +172,7 @@ def source(gen_fn: GeneratorType = None, *, name: str = None, compute_schema_fn:
# Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the
# annotations to reflect that the returned function requires a config and returns a stage
@functools.wraps(gen_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
def wrapper(config: Config, **kwargs) -> WrappedFunctionSourceStage:
def wrapper(config: Config, **kwargs: _P.kwargs) -> WrappedFunctionSourceStage:
nonlocal name
nonlocal compute_schema_fn

Expand Down Expand Up @@ -271,12 +281,15 @@ def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) ->
return node


def stage(on_data_fn: typing.Callable = None,
DecoratedStageType = typing.Callable[typing.Concatenate[Config, _P], WrappedFunctionStage]


def stage(on_data_fn: typing.Callable[typing.Concatenate[_InputT, _P], _OutputT] = None,
*,
name: str = None,
accept_type: type = None,
compute_schema_fn: ComputeSchemaType = None,
needed_columns: dict[str, TypeId] = None):
needed_columns: dict[str, TypeId] = None) -> DecoratedStageType:
"""
Decorator for wrapping a function as a stage. The function must receive at least one argument, the first argument
must be the incoming message, and must return a value.
Expand Down Expand Up @@ -317,7 +330,7 @@ def stage(on_data_fn: typing.Callable = None,
# Use wraps to ensure user's don't lose their function name and docstrinsgs, however we do want to override the
# annotations to reflect that the returned function requires a config and returns a stage
@functools.wraps(on_data_fn, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
def wrapper(config: Config, **kwargs) -> WrappedFunctionStage:
def wrapper(config: Config, **kwargs: _P.kwargs) -> WrappedFunctionStage:
nonlocal name
nonlocal accept_type
nonlocal compute_schema_fn
Expand Down

0 comments on commit 929531f

Please sign in to comment.