From c1fe0dd7cff8f707b4d3f7306d11ff7acd1d2208 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 24 Aug 2023 12:53:50 -0700 Subject: [PATCH] Add test Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 6 +++--- flytekit/types/directory/types.py | 7 +------ tests/flytekit/unit/core/test_type_engine.py | 10 ++++++++++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9efdd92fa2..9b1fa4fa70 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -64,9 +64,9 @@ def t1(directory: Annotated[FlyteDirectory, BatchSize(10)]) -> Annotated[FlyteDi ... return FlyteDirectory(...) - In the above example flytekit will download all files from the input `directory` in chunks of 10, i.e. first it - downloads 10 files, loads them to memory, then writes those 10 to local disk, then it loads the next 10, so on - and so forth. Similarly, for outputs, in this case flytekit is going to upload the resulting directory in chunks of + In the above example flytekit will download all files from the input `directory` in chunks of 10, i.e. first it + downloads 10 files, loads them to memory, then writes those 10 to local disk, then it loads the next 10, so on + and so forth. Similarly, for outputs, in this case flytekit is going to upload the resulting directory in chunks of 100. """ diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 0f82178abe..fb6e728736 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -13,10 +13,9 @@ from dataclasses_json import config, dataclass_json from fsspec.utils import get_protocol from marshmallow import fields -from typing_extensions import get_args from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, get_batch_size, is_annotated +from flytekit.core.type_engine import TypeEngine, TypeTransformer, get_batch_size from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar @@ -325,8 +324,6 @@ def to_literal( should_upload = True batch_size = get_batch_size(python_type) - if is_annotated(python_type): - python_type = get_args(python_type)[0] meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type))) # There are two kinds of literals we handle, either an actual FlyteDirectory, or a string path to a directory. @@ -385,8 +382,6 @@ def to_python_value( local_folder = ctx.file_access.get_random_local_directory() batch_size = get_batch_size(expected_python_type) - if is_annotated(expected_python_type): - expected_python_type = get_args(expected_python_type)[0] def _downloader(): return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 411ffa85dc..ff69f7bcd9 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -302,6 +302,16 @@ def test_dir_no_downloader_default(): assert pv.download() == local_dir +def test_dir_with_batch_size(): + flyte_dir = Annotated[FlyteDirectory, BatchSize(100)] + val = flyte_dir("s3://bucket/key") + transformer = TypeEngine.get_transformer(flyte_dir) + ctx = FlyteContext.current_context() + lt = transformer.get_literal_type(flyte_dir) + lv = transformer.to_literal(ctx, val, flyte_dir, lt) + assert val.path == transformer.to_python_value(ctx, lv, flyte_dir).remote_source + + def test_dict_transformer(): d = DictTransformer()