Data subsystem use fsspec #1531

Merged · 6 commits · Mar 4, 2023
4 changes: 1 addition & 3 deletions flytekit/core/context_manager.py
@@ -202,12 +202,10 @@ def raw_output_prefix(self) -> str:
return self._raw_output_prefix

@property
def working_directory(self) -> utils.AutoDeletingTempDir:
def working_directory(self) -> str:
"""
A handle to a special working directory for easily producing temporary files.

TODO: Usage examples
TODO: This does not always return an AutoDeletingTempDir
"""
return self._working_directory

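With working_directory now typed as a plain str, task code can treat it like any other path. A minimal sketch of the caller side, assuming it runs inside a task (the task body is illustrative, not part of this PR):

import os

import flytekit

@flytekit.task
def write_scratch() -> str:
    # working_directory is a str now, so a plain os.path.join works;
    # there is no AutoDeletingTempDir handle to manage.
    wd = flytekit.current_context().working_directory
    path = os.path.join(wd, "scratch.txt")
    with open(path, "w") as f:
        f.write("temporary output")
    return path
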
15 changes: 7 additions & 8 deletions flytekit/core/data_persistence.py
@@ -25,11 +25,10 @@
import pathlib
import tempfile
import typing
from typing import Union
from typing import Union, cast
from uuid import UUID

import fsspec
from fsspec.core import strip_protocol
from fsspec.utils import get_protocol

from flytekit import configuration
@@ -95,7 +94,7 @@ def __init__(

self._data_config = data_config if data_config else DataConfig.auto()
self._default_protocol = get_protocol(raw_output_prefix)
self._default_remote = self.get_filesystem(self._default_protocol)
self._default_remote = cast(fsspec.AbstractFileSystem, self.get_filesystem(self._default_protocol))
if os.name == "nt" and raw_output_prefix.startswith("file://"):
raise FlyteAssertion("Cannot use the file:// prefix on Windows.")
self._raw_output_prefix = (
@@ -113,11 +112,11 @@ def data_config(self) -> DataConfig:
return self._data_config

def get_filesystem(
self, protocol: str = None, anonymous: bool = False
self, protocol: typing.Optional[str] = None, anonymous: bool = False
) -> typing.Optional[fsspec.AbstractFileSystem]:
if not protocol:
return self._default_remote
kwargs = {}
kwargs = {} # type: typing.Dict[str, typing.Any]
if protocol == "file":
kwargs = {"auto_mkdir": True}
elif protocol == "s3":
@@ -134,9 +133,9 @@

return fsspec.filesystem(protocol, **kwargs) # type: ignore

def get_filesystem_for_path(self, path: str = "") -> fsspec.AbstractFileSystem:
def get_filesystem_for_path(self, path: str = "", anonymous: bool = False) -> fsspec.AbstractFileSystem:
protocol = get_protocol(path)
return self.get_filesystem(protocol)
return self.get_filesystem(protocol, anonymous=anonymous)

@staticmethod
def is_remote(path: Union[str, os.PathLike]) -> bool:
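The anonymous flag now threads through path-based filesystem lookup as well. A sketch of resolving filesystems through the provider; the bucket names and sandbox path are placeholders:

from flytekit.core.data_persistence import FileAccessProvider

fap = FileAccessProvider(local_sandbox_dir="/tmp/sandbox", raw_output_prefix="s3://my-bucket/raw")
fs = fap.get_filesystem_for_path("s3://my-bucket/data/")  # credentialed s3fs instance
anon_fs = fap.get_filesystem_for_path("s3://public-bucket/data/", anonymous=True)  # unsigned access
local_fs = fap.get_filesystem_for_path("/tmp/data")  # "file" protocol, created with auto_mkdir=True
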
@@ -322,7 +321,7 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart
"""
try:
with PerformanceTimer(f"Writing ({local_path} -> {remote_path})"):
self.put(local_path, remote_path, recursive=is_multipart)
self.put(cast(str, local_path), remote_path, recursive=is_multipart)
except Exception as ex:
raise FlyteAssertion(
f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n"
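For reference, a short usage sketch of put_data with placeholder paths; is_multipart maps straight onto a recursive fsspec put:

from flytekit.core.data_persistence import FileAccessProvider

fap = FileAccessProvider(local_sandbox_dir="/tmp/sandbox", raw_output_prefix="s3://my-bucket/raw")
fap.put_data("/tmp/model.pt", "s3://my-bucket/models/model.pt", is_multipart=False)  # single file
fap.put_data("/tmp/outputs", "s3://my-bucket/outputs", is_multipart=True)  # whole directory, recursive
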
52 changes: 37 additions & 15 deletions flytekit/remote/remote.py
@@ -6,17 +6,20 @@
from __future__ import annotations

import base64
import functools
import hashlib
import importlib
import os
import pathlib
import tempfile
import time
import typing
import uuid
from base64 import b64encode
from collections import OrderedDict
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta

import requests
from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
from flyteidl.core import literals_pb2 as literals_pb2

@@ -31,10 +34,15 @@
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceSpec
from flytekit.core.tracker import get_full_module_path
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.workflow import WorkflowBase
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException
from flytekit.exceptions.user import (
FlyteEntityAlreadyExistsException,
FlyteEntityNotExistException,
FlyteValueException,
)
from flytekit.loggers import remote_logger
from flytekit.models import common as common_models
from flytekit.models import filters as filter_models
@@ -62,7 +70,7 @@
from flytekit.remote.lazy_entity import LazyEntity
from flytekit.remote.remote_callable import RemoteEntity
from flytekit.tools.fast_registration import fast_package
from flytekit.tools.script_mode import fast_register_single_script, hash_file
from flytekit.tools.script_mode import compress_single_script, hash_file
from flytekit.tools.translator import (
FlyteControlPlaneEntity,
FlyteLocalEntity,
@@ -728,7 +736,23 @@ def _upload_file(
content_md5=md5_bytes,
filename=to_upload.name,
)
self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url)

encoded_md5 = b64encode(md5_bytes)
with open(str(to_upload), "+rb") as local_file:
content = local_file.read()
content_length = len(content)
rsp = requests.put(
upload_location.signed_url,
data=content,
headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5},
)

if rsp.status_code != requests.codes["OK"]:
raise FlyteValueException(
rsp.status_code,
f"Request to send data {upload_location.signed_url} failed.",
)

remote_logger.debug(
f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}"
)
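The upload now goes through requests directly instead of file_access.put_data, and Content-MD5 must carry the base64 of the raw MD5 digest (the same digest used to request the signed URL). A standalone sketch with a placeholder URL:

import hashlib
from base64 import b64encode

import requests

with open("/tmp/script_mode.tar.gz", "rb") as f:
    content = f.read()
md5_bytes = hashlib.md5(content).digest()  # raw digest bytes, not hexdigest()
rsp = requests.put(
    "https://storage.example.com/signed-put-url",  # placeholder signed URL
    data=content,
    headers={"Content-Length": str(len(content)), "Content-MD5": b64encode(md5_bytes)},
)
rsp.raise_for_status()  # the PR raises FlyteValueException on non-200 instead
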
@@ -795,16 +819,14 @@ def register_script(
if image_config is None:
image_config = ImageConfig.auto_default_image()

upload_location, md5_bytes = fast_register_single_script(
source_path,
module_name,
functools.partial(
self.client.get_upload_signed_url,
project=project or self.default_project,
domain=domain or self.default_domain,
filename="scriptmode.tar.gz",
),
)
with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
mod = importlib.import_module(module_name)
compress_single_script(source_path, str(archive_fname), get_full_module_path(mod, mod.__name__))
md5_bytes, upload_native_url = self._upload_file(
archive_fname, project or self.default_project, domain or self.default_domain
)

serialization_settings = SerializationSettings(
project=project,
domain=domain,
@@ -813,7 +835,7 @@
fast_serialization_settings=FastSerializationSettings(
enabled=True,
destination_dir=destination_dir,
distribution_location=upload_location.native_url,
distribution_location=upload_native_url,
),
)

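register_script now builds the tarball inline and pushes it through _upload_file. A sketch of just the archive step, assuming an importable module named my_workflows:

import importlib
import os
import pathlib
import tempfile

from flytekit.core.tracker import get_full_module_path
from flytekit.tools.script_mode import compress_single_script, hash_file

source_path, module_name = ".", "my_workflows"  # placeholders
with tempfile.TemporaryDirectory() as tmp_dir:
    archive = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
    mod = importlib.import_module(module_name)
    # single-script tarball rooted at source_path, keyed by the full module path
    compress_single_script(source_path, str(archive), get_full_module_path(mod, mod.__name__))
    md5_bytes, _ = hash_file(archive)  # raw digest, later sent as Content-MD5
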
24 changes: 0 additions & 24 deletions flytekit/tools/script_mode.py
@@ -1,18 +1,12 @@
import gzip
import hashlib
import importlib
import os
import shutil
import tarfile
import tempfile
import typing
from pathlib import Path

from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2

from flytekit.core import context_manager
from flytekit.core.tracker import get_full_module_path


def compress_single_script(source_path: str, destination: str, full_module_name: str):
"""
@@ -96,24 +90,6 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo:
return tar_info


def fast_register_single_script(
source_path: str, module_name: str, create_upload_location_fn: typing.Callable
) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes):

# Open a temp directory and dump the contents of the digest.
with tempfile.TemporaryDirectory() as tmp_dir:
archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz")
mod = importlib.import_module(module_name)
compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__))

flyte_ctx = context_manager.FlyteContextManager.current_context()
md5, _ = hash_file(archive_fname)
upload_location = create_upload_location_fn(content_md5=md5)
flyte_ctx.file_access.put_data(archive_fname, upload_location.signed_url)

return upload_location, md5


def hash_file(file_path: typing.Union[os.PathLike, str]) -> (bytes, str):
"""
Hash a file and produce a digest to be used as a version
89 changes: 76 additions & 13 deletions flytekit/types/structured/basic_dfs.py
@@ -1,12 +1,18 @@
import os
import typing
from pathlib import Path
from typing import TypeVar

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from botocore.exceptions import NoCredentialsError
from fsspec.core import split_protocol, strip_protocol
from fsspec.utils import get_protocol

from flytekit import FlyteContext
from flytekit import FlyteContext, logger
from flytekit.configuration import DataConfig
from flytekit.core.data_persistence import s3_setup_args
from flytekit.deck import TopFrameRenderer
from flytekit.deck.renderer import ArrowRenderer
from flytekit.models import literals
@@ -23,6 +29,15 @@
T = TypeVar("T")


def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]:
protocol = get_protocol(uri)
if protocol == "s3":
kwargs = s3_setup_args(cfg.s3, anon)
if kwargs:
return kwargs
return None
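
get_storage_options only produces kwargs for s3 URIs; everything else falls back to fsspec defaults. A sketch (the exact kwargs depend on the DataConfig in effect):

from flytekit.configuration import DataConfig
from flytekit.types.structured.basic_dfs import get_storage_options

cfg = DataConfig.auto()
s3_opts = get_storage_options(cfg, "s3://my-bucket/df.parquet")  # s3fs kwargs, or None if empty
anon_opts = get_storage_options(cfg, "s3://public-bucket/df.parquet", anon=True)  # unsigned variant
assert get_storage_options(cfg, "/tmp/df.parquet") is None  # local path: no storage options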


class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(pd.DataFrame, None, PARQUET)
@@ -33,6 +48,26 @@ def encode(
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
if not ctx.file_access.is_remote(uri):
Path(uri).mkdir(parents=True, exist_ok=True)
path = os.path.join(uri, f"{0:05}")
df = typing.cast(pd.DataFrame, structured_dataset.dataframe)
df.to_parquet(
path,
coerce_timestamps="us",
allow_truncated_timestamps=False,
storage_options=get_storage_options(ctx.file_access.data_config, path),
)
structured_dataset_type.format = PARQUET
return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))

def encode(
self,
ctx: FlyteContext,
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:

path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
df = typing.cast(pd.DataFrame, structured_dataset.dataframe)
@@ -53,6 +88,24 @@ def decode(
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pd.DataFrame:
uri = flyte_value.uri
columns = None
kwargs = get_storage_options(ctx.file_access.data_config, uri)
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
try:
return pd.read_parquet(uri, columns=columns, storage_options=kwargs)
except NoCredentialsError:
logger.debug("S3 source detected, attempting anonymous S3 access")
kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True)
return pd.read_parquet(uri, columns=columns, storage_options=kwargs)

def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pd.DataFrame:
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()
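
The decoder's credential fallback in isolation: pandas forwards storage_options to fsspec/s3fs, so a NoCredentialsError on the first read can be retried unsigned. Bucket and columns are placeholders:

import pandas as pd
from botocore.exceptions import NoCredentialsError

uri = "s3://some-public-bucket/data/00000"
try:
    df = pd.read_parquet(uri, columns=["a", "b"], storage_options=None)  # credentialed attempt
except NoCredentialsError:
    df = pd.read_parquet(uri, columns=["a", "b"], storage_options={"anon": True})  # unsigned retry
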
@@ -73,13 +126,13 @@ def encode(
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_path()
df = structured_dataset.dataframe
local_dir = ctx.file_access.get_random_local_directory()
local_path = os.path.join(local_dir, f"{0:05}")
pq.write_table(df, local_path)
ctx.file_access.upload_directory(local_dir, path)
return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
if not ctx.file_access.is_remote(uri):
Path(uri).mkdir(parents=True, exist_ok=True)
path = os.path.join(uri, f"{0:05}")
filesystem = ctx.file_access.get_filesystem_for_path(path)
pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem)
return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))


class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
@@ -92,13 +145,23 @@ def decode(
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pa.Table:
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(path, local_dir, is_multipart=True)
uri = flyte_value.uri
if not ctx.file_access.is_remote(uri):
Path(uri).parent.mkdir(parents=True, exist_ok=True)
_, path = split_protocol(uri)

columns = None
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pq.read_table(local_dir, columns=columns)
return pq.read_table(local_dir)
try:
fs = ctx.file_access.get_filesystem_for_path(uri)
return pq.read_table(path, filesystem=fs, columns=columns)
except NoCredentialsError as e:
logger.debug("S3 source detected, attempting anonymous S3 access")
fs = ctx.file_access.get_filesystem_for_path(uri, anonymous=True)
if fs is not None:
return pq.read_table(path, filesystem=fs, columns=columns)
raise e


StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True)
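The arrow handlers no longer round-trip through a local directory; they hand pyarrow an fsspec filesystem plus a protocol-stripped path. The same pattern outside flytekit, with a placeholder bucket:

import fsspec
import pyarrow as pa
import pyarrow.parquet as pq

fs = fsspec.filesystem("s3")  # pass anon=True here for the unsigned fallback path
table = pa.table({"a": [1, 2, 3]})
pq.write_table(table, "my-bucket/out/00000", filesystem=fs)  # note: no s3:// prefix on the path
back = pq.read_table("my-bucket/out/00000", filesystem=fs, columns=["a"])
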
@@ -8,15 +8,13 @@
:toctree: generated/

ArrowToParquetEncodingHandler
FSSpecPersistence
PandasToParquetEncodingHandler
ParquetToArrowDecodingHandler
ParquetToPandasDecodingHandler
"""

__all__ = [
"ArrowToParquetEncodingHandler",
"FSSpecPersistence",
"PandasToParquetEncodingHandler",
"ParquetToArrowDecodingHandler",
"ParquetToPandasDecodingHandler",
Expand All @@ -28,7 +26,6 @@

from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler
from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler
from .persist import FSSpecPersistence

S3 = "s3"
ABFS = "abfs"