Caching of offloaded objects (#762)
* Remove flyteidl from install_requires

Signed-off-by: Eduardo Apolinario <[email protected]>

* Expose hash in Literal

Signed-off-by: Eduardo Apolinario <[email protected]>

* Set hash in TypeEngine

Signed-off-by: Eduardo Apolinario <[email protected]>

* Modify cache key calculation to take hash into account

Signed-off-by: Eduardo Apolinario <[email protected]>

* Opt-in PandasDataFrameTransformer

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add unit tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Iterate using a flyteidl branch

Signed-off-by: Eduardo Apolinario <[email protected]>

* Regenerate requirements files

Signed-off-by: Eduardo Apolinario <[email protected]>

* Regenerate requirements files

Signed-off-by: Eduardo Apolinario <[email protected]>

* Move _hash_overridable to StructureDatasetTransformerEngine

Signed-off-by: Eduardo Apolinario <[email protected]>

* Move HashMethod to flytekit.core.hash

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix `unit_test` make target

Signed-off-by: Eduardo Apolinario <[email protected]>

* Split `unit_test` make target in two lines

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add assert to structured dataset compatibility test

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove TODO

Signed-off-by: Eduardo Apolinario <[email protected]>

* Regenerate plugins requirements files pointing to the right version of flyteidl.

Signed-off-by: Eduardo Apolinario <[email protected]>

* Set hash as a property of the literal

Signed-off-by: Eduardo Apolinario <[email protected]>

* Install plugins requirements in CI.

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add hash.setter

Signed-off-by: Eduardo Apolinario <[email protected]>

* Install flyteidl directly

Signed-off-by: Eduardo Apolinario <[email protected]>

* Revert "Regenerate plugins requirements files pointing to the right version of flyteidl."

This reverts commit c2dbb54.

Signed-off-by: Eduardo Apolinario <[email protected]>

* wip - Add support for univariate lists

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add support for lists of annotated objects

Signed-off-by: Eduardo Apolinario <[email protected]>

* Revamp generation of cache key (to cover case of literals collections and maps)

Signed-off-by: Eduardo Apolinario <[email protected]>

* Leave TODO for warning

Signed-off-by: Eduardo Apolinario <[email protected]>

* Revert "Add support for lists of annotated objects"

This reverts commit 4b5f608.

Signed-off-by: Eduardo Apolinario <[email protected]>

* Revert "wip - Add support for univariate lists"

This reverts commit adaa448.

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove docstring

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add flyteidl>=0.23.0

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove mentions to branch flyteidl@add-hash-to-literal

Signed-off-by: Eduardo Apolinario <[email protected]>

* Bump flyteidl in plugins requirements

Signed-off-by: Eduardo Apolinario <[email protected]>

* Regenerate plugins requirements again

Signed-off-by: Eduardo Apolinario <[email protected]>

* Restore papermill/requirements.txt

Signed-off-by: Eduardo Apolinario <[email protected]>

* Point flytekitplugins-spark to the offloaded-objects-caching branch in papermill tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Set flyteidl>=0.23.0 in papermill dev-requirements

Co-authored-by: Eduardo Apolinario <[email protected]>
eapolinario and eapolinario authored Mar 2, 2022
1 parent 10ab48e commit 0da523c
Showing 36 changed files with 1,020 additions and 336 deletions.
3 changes: 2 additions & 1 deletion Makefile
@@ -50,7 +50,8 @@ test: lint unit_test

 .PHONY: unit_test
 unit_test:
-	FLYTE_SDK_USE_STRUCTURED_DATASET=TRUE pytest tests/flytekit/unit tests/flytekit_compatibility
+	FLYTE_SDK_USE_STRUCTURED_DATASET=FALSE pytest tests/flytekit_compatibility && \
+	FLYTE_SDK_USE_STRUCTURED_DATASET=TRUE pytest tests/flytekit/unit

 requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt
 requirements-spark2.txt: requirements-spark2.in install-piptools
33 changes: 24 additions & 9 deletions dev-requirements.txt
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with python 3.8
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# make dev-requirements.txt
@@ -32,6 +32,7 @@ certifi==2021.10.8
# requests
cffi==1.15.0
# via
# -c requirements.txt
# bcrypt
# cryptography
# pynacl
@@ -71,7 +72,10 @@ croniter==1.3.4
# -c requirements.txt
# flytekit
cryptography==36.0.1
# via paramiko
# via
# -c requirements.txt
# paramiko
# secretstorage
dataclasses-json==0.5.6
# via
# -c requirements.txt
@@ -112,7 +116,7 @@ docstring-parser==0.13
# flytekit
filelock==3.6.0
# via virtualenv
flyteidl==0.22.3
flyteidl==0.23.0
# via
# -c requirements.txt
# flytekit
@@ -133,9 +137,9 @@ google-cloud-core==2.2.2
# via google-cloud-bigquery
google-crc32c==1.3.0
# via google-resumable-media
google-resumable-media==2.2.1
google-resumable-media==2.3.0
# via google-cloud-bigquery
googleapis-common-protos==1.54.0
googleapis-common-protos==1.55.0
# via
# -c requirements.txt
# flyteidl
@@ -156,12 +160,17 @@ idna==3.3
# via
# -c requirements.txt
# requests
importlib-metadata==4.11.1
importlib-metadata==4.11.2
# via
# -c requirements.txt
# keyring
iniconfig==1.1.1
# via pytest
jeepney==0.7.1
# via
# -c requirements.txt
# keyring
# secretstorage
jinja2==3.0.3
# via
# -c requirements.txt
@@ -275,7 +284,9 @@ pyasn1==0.4.8
pyasn1-modules==0.2.8
# via google-auth
pycparser==2.21
# via cffi
# via
# -c requirements.txt
# cffi
pynacl==1.5.0
# via paramiko
pyparsing==3.0.7
@@ -307,7 +318,7 @@ python-json-logger==2.0.2
# via
# -c requirements.txt
# flytekit
python-slugify==6.1.0
python-slugify==6.1.1
# via
# -c requirements.txt
# cookiecutter
@@ -349,6 +360,10 @@ retry==0.9.2
# flytekit
rsa==4.8
# via google-auth
secretstorage==3.3.1
# via
# -c requirements.txt
# keyring
six==1.16.0
# via
# -c requirements.txt
@@ -399,7 +414,7 @@ urllib3==1.26.8
# flytekit
# requests
# responses
virtualenv==20.13.1
virtualenv==20.13.2
# via pre-commit
websocket-client==0.59.0
# via
25 changes: 15 additions & 10 deletions doc-requirements.txt
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with python 3.8
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# make doc-requirements.txt
@@ -10,7 +10,7 @@ alabaster==0.7.12
# via sphinx
arrow==1.2.2
# via jinja2-time
astroid==2.9.3
astroid==2.10.0
# via sphinx-autoapi
babel==2.9.1
# via sphinx
@@ -42,7 +42,9 @@ cookiecutter==1.7.3
croniter==1.3.4
# via flytekit
cryptography==36.0.1
# via -r doc-requirements.in
# via
# -r doc-requirements.in
# secretstorage
css-html-js-minify==2.5.5
# via sphinx-material
dataclasses-json==0.5.6
@@ -61,7 +63,7 @@ docutils==0.17.1
# via
# sphinx
# sphinx-panels
flyteidl==0.22.3
flyteidl==0.23.0
# via flytekit
furo @ git+https://github.com/flyteorg/furo@main
# via -r doc-requirements.in
@@ -75,10 +77,14 @@ idna==3.3
# via requests
imagesize==1.3.0
# via sphinx
importlib-metadata==4.11.1
importlib-metadata==4.11.2
# via
# keyring
# sphinx
jeepney==0.7.1
# via
# keyring
# secretstorage
jinja2==3.0.3
# via
# cookiecutter
@@ -125,10 +131,7 @@ protobuf==3.19.4
# googleapis-common-protos
# protoc-gen-swagger
protoc-gen-swagger==0.1.0
# via
# flyteidl
# flytekit

# via flyteidl
py==1.11.0
# via retry
pyarrow==6.0.1
@@ -149,7 +152,7 @@ python-dateutil==2.8.2
# pandas
python-json-logger==2.0.2
# via flytekit
python-slugify[unidecode]==6.1.0
python-slugify[unidecode]==6.1.1
# via
# cookiecutter
# sphinx-material
@@ -174,6 +177,8 @@ responses==0.18.0
# via flytekit
retry==0.9.2
# via flytekit
secretstorage==3.3.1
# via keyring
six==1.16.0
# via
# cookiecutter
1 change: 1 addition & 0 deletions flytekit/__init__.py
@@ -169,6 +169,7 @@
 from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager
 from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins
 from flytekit.core.dynamic_workflow_task import dynamic
+from flytekit.core.hash import HashMethod
 from flytekit.core.launch_plan import LaunchPlan
 from flytekit.core.map_task import map_task
 from flytekit.core.notification import Email, PagerDuty, Slack
20 changes: 20 additions & 0 deletions flytekit/core/hash.py
@@ -1,3 +1,23 @@
+from typing import Callable, Generic, TypeVar
+
+T = TypeVar("T")
+
+
 class HashOnReferenceMixin(object):
     def __hash__(self):
         return hash(id(self))
+
+
+class HashMethod(Generic[T]):
+    """
+    Flyte-specific object used to wrap the hash function for a specific type
+    """
+
+    def __init__(self, function: Callable[[T], str]):
+        self._function = function
+
+    def calculate(self, obj: T) -> str:
+        """
+        Calculate hash for `obj`.
+        """
+        return self._function(obj)
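
To illustrate how the `HashMethod` wrapper above is meant to travel with a type, here is a minimal sketch. It assumes Python 3.9+ for `typing.Annotated`, and it redefines a stand-in `HashMethod` that mirrors the class in this diff so the snippet runs without flytekit installed. The `hash_rows` function and the `HashedRows` alias are hypothetical names chosen for the example, not part of the commit.

```python
import hashlib
from typing import Annotated, Callable, Generic, TypeVar

T = TypeVar("T")


class HashMethod(Generic[T]):
    """Stand-in mirroring the class added in flytekit/core/hash.py."""

    def __init__(self, function: Callable[[T], str]):
        self._function = function

    def calculate(self, obj: T) -> str:
        return self._function(obj)


def hash_rows(rows: list) -> str:
    """Hypothetical content hash: SHA-256 over the repr of the rows."""
    return hashlib.sha256(repr(rows).encode("utf-8")).hexdigest()


# The hash function is attached to the type as Annotated metadata.
hasher = HashMethod(hash_rows)
HashedRows = Annotated[list, hasher]

# Equal contents give equal digests; different contents give different ones.
digest_a = hasher.calculate([1, 2, 3])
digest_b = hasher.calculate([1, 2, 3])
digest_c = hasher.calculate([1, 2, 4])
```

Because the wrapped function returns a string derived from the object's contents, two objects stored at different locations can still map to the same cache entry, which appears to be the point of the commit title.
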
30 changes: 28 additions & 2 deletions flytekit/core/local_cache.py
@@ -1,16 +1,42 @@
+import base64
 from typing import Optional

+import cloudpickle
 from diskcache import Cache

-from flytekit.models.literals import LiteralMap
+from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

 # Location on the filesystem where serialized objects will be stored
 # TODO: read from config
 CACHE_LOCATION = "~/.flyte/local-cache"


+def _recursive_hash_placement(literal: Literal) -> Literal:
+    if literal.collection is not None:
+        literals = [_recursive_hash_placement(literal) for literal in literal.collection.literals]
+        return Literal(collection=LiteralCollection(literals=literals))
+    elif literal.map is not None:
+        literal_map = {}
+        for key, literal in literal.map.literals.items():
+            literal_map[key] = _recursive_hash_placement(literal)
+        return Literal(map=LiteralMap(literal_map))
+
+    # Base case
+    if literal.hash is not None:
+        return Literal(hash=literal.hash)
+    else:
+        return literal
+
+
 def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str:
-    return f"{task_name}-{cache_version}-{input_literal_map}"
+    # Traverse the literals and replace the literal with a new literal that only contains the hash
+    literal_map_overridden = {}
+    for key, literal in input_literal_map.literals.items():
+        literal_map_overridden[key] = _recursive_hash_placement(literal)
+
+    # Pickle the literal map and use base64 encoding to generate a representation of it
+    b64_encoded = base64.b64encode(cloudpickle.dumps(LiteralMap(literal_map_overridden)))
+    return f"{task_name}-{cache_version}-{b64_encoded}"


 class LocalTaskCache(object):
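
The effect of the new cache key calculation in `local_cache.py` can be sketched without flytekit. The `Lit` dataclass below is a hypothetical stand-in for flytekit's `Literal`, and the sketch serializes with `json` instead of `cloudpickle` to stay within the standard library, so the key format differs from the real one. What it shows is the idea in `_recursive_hash_placement`: a literal that carries a hash is reduced to that hash before the key is built.

```python
import base64
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional


@dataclass
class Lit:
    """Hypothetical stand-in for flytekit's Literal; not the real model class."""

    value: Optional[str] = None
    hash: Optional[str] = None
    collection: Optional[List["Lit"]] = None
    map: Optional[Dict[str, "Lit"]] = None


def replace_with_hash(lit: Lit) -> Lit:
    """Recursively reduce every hashed literal to a literal holding only its hash."""
    if lit.collection is not None:
        return Lit(collection=[replace_with_hash(item) for item in lit.collection])
    if lit.map is not None:
        return Lit(map={k: replace_with_hash(v) for k, v in lit.map.items()})
    # Base case: keep only the hash when one is set, otherwise keep the literal.
    return Lit(hash=lit.hash) if lit.hash is not None else lit


def cache_key(task_name: str, cache_version: str, inputs: Dict[str, Lit]) -> str:
    reduced = {k: asdict(replace_with_hash(v)) for k, v in inputs.items()}
    # json stands in for cloudpickle here (an assumption made for portability).
    payload = json.dumps(reduced, sort_keys=True).encode("utf-8")
    return f"{task_name}-{cache_version}-{base64.b64encode(payload).decode('ascii')}"


# Same hash, different storage location: the keys match.
key_1 = cache_key("train", "v1", {"df": Lit(value="s3://bucket/a", hash="abc")})
key_2 = cache_key("train", "v1", {"df": Lit(value="s3://bucket/b", hash="abc")})
# No hash set: the literal's value still decides the key.
key_3 = cache_key("train", "v1", {"df": Lit(value="s3://bucket/a")})
key_4 = cache_key("train", "v1", {"df": Lit(value="s3://bucket/b")})
```

Under these assumptions `key_1` and `key_2` are identical while `key_3` and `key_4` differ, which is the behaviour an offloaded object needs: its storage URI changes between runs, but its content hash does not.
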
35 changes: 29 additions & 6 deletions flytekit/core/type_engine.py
@@ -10,11 +10,6 @@
 from abc import ABC, abstractmethod
 from typing import NamedTuple, Optional, Type, cast

-try:
-    from typing import Annotated, get_args, get_origin
-except ImportError:
-    from typing_extensions import Annotated, get_origin, get_args
-
 from dataclasses_json import DataClassJsonMixin, dataclass_json
 from google.protobuf import json_format as _json_format
 from google.protobuf import reflection as _proto_reflection
@@ -24,9 +19,11 @@
 from google.protobuf.struct_pb2 import Struct
 from marshmallow_enum import EnumField, LoadDumpOptions
 from marshmallow_jsonschema import JSONSchema
+from typing_extensions import Annotated, get_args, get_origin

 from flytekit.core.annotation import FlyteAnnotation
 from flytekit.core.context_manager import FlyteContext
+from flytekit.core.hash import HashMethod
 from flytekit.core.type_helpers import load_type_from_tag
 from flytekit.exceptions import user as user_exceptions
 from flytekit.loggers import logger
@@ -56,10 +53,13 @@ class TypeTransformer(typing.Generic[T]):
     Base transformer type that should be implemented for every python native type that can be handled by flytekit
     """

-    def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
+    def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True, hash_overridable: bool = False):
         self._t = t
         self._name = name
         self._type_assertions_enabled = enable_type_assertions
+        # `hash_overridable` indicates that the literals produced by this type transformer can set their hashes if needed.
+        # See (link to documentation where this feature is explained).
+        self._hash_overridable = hash_overridable

     @property
     def name(self):
@@ -79,6 +79,10 @@ def type_assertions_enabled(self) -> bool:
         """
         return self._type_assertions_enabled

+    @property
+    def hash_overridable(self) -> bool:
+        return self._hash_overridable
+
     def assert_type(self, t: Type[T], v: T):
         if not hasattr(t, "__origin__") and not isinstance(v, t):
             raise TypeError(f"Type of Val '{v}' is not an instance of {t}")
@@ -640,7 +644,25 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
         transformer = cls.get_transformer(python_type)
         if transformer.type_assertions_enabled:
             transformer.assert_type(python_type, python_val)
+
+        # In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
+        hash = None
+        if transformer.hash_overridable and get_origin(python_type) is Annotated:
+            # We are now dealing with one of two cases:
+            # 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using
+            #    the method indicated in the annotation.
+            # 2. The annotated type is being used for a different purpose other than calculating hash values, in which case
+            #    we should just continue.
+            for annotation in get_args(python_type)[1:]:
+                if not isinstance(annotation, HashMethod):
+                    continue
+                hash = annotation.calculate(python_val)
+                break
+
         lv = transformer.to_literal(ctx, python_val, python_type, expected)
+
+        if hash is not None:
+            lv.hash = hash
         return lv

     @classmethod
@@ -852,6 +874,7 @@ def to_literal(
         for k, v in python_val.items():
             if type(k) != str:
                 raise ValueError("Flyte MapType expects all keys to be strings")
+            # TODO: log a warning for Annotated objects that contain HashMethod
             k_type, v_type = self.get_dict_types(python_type)
             lit_map[k] = TypeEngine.to_literal(ctx, v, v_type, expected.map_value_type)
         return Literal(map=LiteralMap(literals=lit_map))
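
The annotation scan that this diff adds to `TypeEngine.to_literal` can be isolated as a small function. The sketch below assumes Python 3.9+ (it imports `Annotated`, `get_args` and `get_origin` from `typing` rather than `typing_extensions`), uses a stand-in `HashMethod`, and introduces a hypothetical helper `find_hash` whose `hash_overridable` argument plays the role of the transformer's property.

```python
from typing import Annotated, Any, Callable, Optional, get_args, get_origin


class HashMethod:
    """Stand-in mirroring flytekit.core.hash.HashMethod."""

    def __init__(self, function: Callable[[Any], str]):
        self._function = function

    def calculate(self, obj: Any) -> str:
        return self._function(obj)


def find_hash(python_type: Any, python_val: Any, hash_overridable: bool = True) -> Optional[str]:
    """Return the hash from the first HashMethod annotation, or None if there is none."""
    if not hash_overridable or get_origin(python_type) is not Annotated:
        return None
    # get_args(Annotated[X, a, b]) is (X, a, b); skip the underlying type X.
    for annotation in get_args(python_type)[1:]:
        if isinstance(annotation, HashMethod):
            return annotation.calculate(python_val)
    return None


with_hash = find_hash(Annotated[str, "some doc", HashMethod(str.upper)], "abc")
plain_type = find_hash(str, "abc")
other_metadata = find_hash(Annotated[str, "some doc"], "abc")
opted_out = find_hash(Annotated[str, HashMethod(str.upper)], "abc", hash_overridable=False)
```

Annotations that are not `HashMethod` instances are skipped, so `Annotated` can still carry unrelated metadata, and a transformer that has not opted in via `hash_overridable` never has a hash set on its literals.
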