Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Aug 24, 2023
1 parent 36a6a6c commit c1fe0dd
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
6 changes: 3 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def t1(directory: Annotated[FlyteDirectory, BatchSize(10)]) -> Annotated[FlyteDi
...
return FlyteDirectory(...)
In the above example flytekit will download all files from the input `directory` in chunks of 10, i.e. first it
downloads 10 files, loads them to memory, then writes those 10 to local disk, then it loads the next 10, so on
and so forth. Similarly, for outputs, in this case flytekit is going to upload the resulting directory in chunks of
In the above example flytekit will download all files from the input `directory` in chunks of 10, i.e. first it
downloads 10 files, loads them to memory, then writes those 10 to local disk, then it loads the next 10, so on
and so forth. Similarly, for outputs, in this case flytekit is going to upload the resulting directory in chunks of
100.
"""

Expand Down
7 changes: 1 addition & 6 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from dataclasses_json import config, dataclass_json
from fsspec.utils import get_protocol
from marshmallow import fields
from typing_extensions import get_args

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, get_batch_size, is_annotated
from flytekit.core.type_engine import TypeEngine, TypeTransformer, get_batch_size
from flytekit.models import types as _type_models
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
Expand Down Expand Up @@ -325,8 +324,6 @@ def to_literal(
should_upload = True
batch_size = get_batch_size(python_type)

if is_annotated(python_type):
python_type = get_args(python_type)[0]
meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type)))

# There are two kinds of literals we handle, either an actual FlyteDirectory, or a string path to a directory.
Expand Down Expand Up @@ -385,8 +382,6 @@ def to_python_value(
local_folder = ctx.file_access.get_random_local_directory()

batch_size = get_batch_size(expected_python_type)
if is_annotated(expected_python_type):
expected_python_type = get_args(expected_python_type)[0]

def _downloader():
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)
Expand Down
10 changes: 10 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ def test_dir_no_downloader_default():
assert pv.download() == local_dir


def test_dir_with_batch_size():
flyte_dir = Annotated[FlyteDirectory, BatchSize(100)]
val = flyte_dir("s3://bucket/key")
transformer = TypeEngine.get_transformer(flyte_dir)
ctx = FlyteContext.current_context()
lt = transformer.get_literal_type(flyte_dir)
lv = transformer.to_literal(ctx, val, flyte_dir, lt)
assert val.path == transformer.to_python_value(ctx, lv, flyte_dir).remote_source


def test_dict_transformer():
d = DictTransformer()

Expand Down

0 comments on commit c1fe0dd

Please sign in to comment.