Skip to content

Commit

Permalink
fix: merge conflict
Browse files Browse the repository at this point in the history
Signed-off-by: Kenny Workman <[email protected]>
  • Loading branch information
kennyworkman committed Dec 9, 2021
2 parents 5c6a9b2 + 3b7fb4d commit 9f316f9
Show file tree
Hide file tree
Showing 32 changed files with 1,371 additions and 841 deletions.
3 changes: 1 addition & 2 deletions flytekit/clients/friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,11 @@ def get_node_execution(self, node_execution_identifier):
)
)

def get_node_execution_data(self, node_execution_identifier):
def get_node_execution_data(self, node_execution_identifier) -> _execution.NodeExecutionGetDataResponse:
"""
Returns signed URLs to LiteralMap blobs for a node execution's inputs and outputs (when available).
:param flytekit.models.core.identifier.NodeExecutionIdentifier node_execution_identifier:
:rtype: flytekit.models.execution.NodeExecutionGetDataResponse
"""
return _execution.NodeExecutionGetDataResponse.from_flyte_idl(
super(SynchronousFlyteClient, self).get_node_execution_data(
Expand Down
42 changes: 40 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from google.protobuf.json_format import MessageToDict as _MessageToDict
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from typing_extensions import get_origin

Expand All @@ -30,7 +31,7 @@
from flytekit.models import types as _type_models
from flytekit.models.annotation import TypeAnnotation as _annotation_model
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Schema
from flytekit.models.types import LiteralType, SimpleType

T = typing.TypeVar("T")
Expand Down Expand Up @@ -229,7 +230,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
)
schema = None
try:
schema = JSONSchema().dump(cast(DataClassJsonMixin, t).schema())
s = cast(DataClassJsonMixin, t).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
schema = JSONSchema().dump(s)
except Exception as e:
logger.warn("failed to extract schema for object %s, (will run schemaless) error: %s", str(t), e)

Expand All @@ -245,10 +252,39 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
raise AssertionError(
f"Dataclass {python_type} should be decorated with @dataclass_json to be " f"serialized correctly"
)
self._serialize_flyte_type(python_val, python_type)
return Literal(
scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct()))
)

def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
"""
If any field inside the dataclass is flyte type, we should use flyte type transformer for that field.
"""
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer

for f in dataclasses.fields(python_type):
v = python_val.__getattribute__(f.name)
if inspect.isclass(f.type) and issubclass(f.type, FlyteSchema):
FlyteSchemaTransformer().to_literal(FlyteContext.current_context(), v, f.type, None)
elif dataclasses.is_dataclass(f.type):
self._serialize_flyte_type(v, f.type)

def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type["FlyteSchema"]):
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer

for f in dataclasses.fields(expected_python_type):
v = python_val.__getattribute__(f.name)
if inspect.isclass(f.type) and issubclass(f.type, FlyteSchema):
t = FlyteSchemaTransformer()
t.to_python_value(
FlyteContext.current_context(),
Literal(scalar=Scalar(schema=Schema(v.remote_path, t._get_schema_type(f.type)))),
f.type,
)
elif dataclasses.is_dataclass(f.type):
self._deserialize_flyte_type(v, f.type)

def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
if t == int:
return int(val)
Expand Down Expand Up @@ -291,7 +327,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be "
f"serialized correctly"
)

dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic))
self._deserialize_flyte_type(dc, expected_python_type)
return self._fix_dataclass_int(expected_python_type, dc)

def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
Expand Down
File renamed without changes.
224 changes: 224 additions & 0 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import datetime
import logging
import os
import re
import subprocess
import typing
from dataclasses import dataclass

from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.interface import Interface
from flytekit.core.python_function_task import PythonInstanceTask
from flytekit.core.task import TaskPlugins
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


@dataclass
class OutputLocation:
"""
Args:
var: str The name of the output variable
var_type: typing.Type The type of output variable
location: os.PathLike The location where this output variable will be written to or a regex that accepts input
vars and generates the path. Of the form ``"{{ .inputs.v }}.tmp.md"``.
This example for a given input v, at path `/tmp/abc.csv` will resolve to `/tmp/abc.csv.tmp.md`
"""

var: str
var_type: typing.Type
location: typing.Union[os.PathLike, str]


def _stringify(v: typing.Any) -> str:
"""
Special cased return for the given value. Given the type returns the string version for the type.
Handles FlyteFile and FlyteDirectory specially. Downloads and returns the downloaded filepath
"""
if isinstance(v, FlyteFile):
v.download()
return v.path
if isinstance(v, FlyteDirectory):
v.download()
return v.path
if isinstance(v, datetime.datetime):
return v.isoformat()
return str(v)


def _interpolate(tmpl: str, regex: re.Pattern, validate_all_match: bool = True, **kwargs) -> str:
"""
Substitutes all templates that match the supplied regex
with the given inputs and returns the substituted string. The result is non destructive towards the given string.
"""
modified = tmpl
matched = set()
for match in regex.finditer(tmpl):
expr = match.groups()[0]
var = match.groups()[1]
if var not in kwargs:
raise ValueError(f"Variable {var} in Query (part of {expr}) not found in inputs {kwargs.keys()}")
matched.add(var)
val = kwargs[var]
# str conversion should be deliberate, with right conversion for each type
modified = modified.replace(expr, _stringify(val))

if validate_all_match:
if len(matched) < len(kwargs.keys()):
diff = set(kwargs.keys()).difference(matched)
raise ValueError(f"Extra Inputs have no matches in script template - missing {diff}")
return modified


def _dummy_task_func():
"""
A Fake function to satisfy the inner PythonTask requirements
"""
return None


T = typing.TypeVar("T")


class ShellTask(PythonInstanceTask[T]):
""" """

_INPUT_REGEX = re.compile(r"({{\s*.inputs.(\w+)\s*}})", re.IGNORECASE)
_OUTPUT_REGEX = re.compile(r"({{\s*.outputs.(\w+)\s*}})", re.IGNORECASE)

def __init__(
self,
name: str,
debug: bool = False,
script: typing.Optional[str] = None,
script_file: typing.Optional[str] = None,
task_config: T = None,
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
output_locs: typing.Optional[typing.List[OutputLocation]] = None,
**kwargs,
):
"""
Args:
name: str Name of the Task. Should be unique in the project
debug: bool Print the generated script and other debugging information
script: The actual script specified as a string
script_file: A path to the file that contains the script (Only script or script_file) can be provided
task_config: T Configuration for the task, can be either a Pod (or coming soon, BatchJob) config
inputs: A Dictionary of input names to types
output_locs: A list of :py:class:`OutputLocations`
**kwargs: Other arguments that can be passed to :ref:class:`PythonInstanceTask`
"""
if script and script_file:
raise ValueError("Only either of script or script_file can be provided")
if not script and not script_file:
raise ValueError("Either a script or script_file is needed")
if script_file:
if not os.path.exists(script_file):
raise ValueError(f"FileNotFound: the specified Script file at path {script_file} cannot be loaded")
script_file = os.path.abspath(script_file)

if task_config is not None:
if str(type(task_config)) != "flytekitplugins.pod.task.Pod":
raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

# Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used
# to run pre- and post- execute functions using the corresponding task plugin.
# We rename the function name here to ensure the generated task has a unique name and avoid duplicate task name
# errors.
# This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work.
plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
# Rename the internal task so that there are no conflicts at serialization time. Technically these internal
# tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
# at serialization time.
self._config_task_instance._name = f"_bash.{name}"
self._script = script
self._script_file = script_file
self._debug = debug
self._output_locs = output_locs if output_locs else []
outputs = self._validate_output_locs()
super().__init__(
name,
task_config,
task_type=self._config_task_instance.task_type,
interface=Interface(inputs=inputs, outputs=outputs),
**kwargs,
)

def _validate_output_locs(self) -> typing.Dict[str, typing.Type]:
outputs = {}
for v in self._output_locs:
if v is None:
raise ValueError("OutputLocation cannot be none")
if not isinstance(v, OutputLocation):
raise ValueError("Every output type should be an output location on the file-system")
if v.location is None:
raise ValueError(f"Output Location not provided for output var {v.var}")
if not issubclass(v.var_type, FlyteFile) and not issubclass(v.var_type, FlyteDirectory):
raise ValueError(
"Currently only outputs of type FlyteFile/FlyteDirectory and their derived types are supported"
)
outputs[v.var] = v.var_type
return outputs

@property
def script(self) -> typing.Optional[str]:
return self._script

@property
def script_file(self) -> typing.Optional[os.PathLike]:
return self._script_file

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
return self._config_task_instance.pre_execute(user_params)

def execute(self, **kwargs) -> typing.Any:
"""
Executes the given script by substituting the inputs and outputs and extracts the outputs from the filesystem
"""
logging.info(f"Running shell script as type {self.task_type}")
if self.script_file:
with open(self.script_file) as f:
self._script = f.read()

outputs: typing.Dict[str, str] = {}
if self._output_locs:
for v in self._output_locs:
outputs[v.var] = _interpolate(v.location, self._INPUT_REGEX, validate_all_match=False, **kwargs)

gen_script = _interpolate(self._script, self._INPUT_REGEX, **kwargs)
# For outputs it is not necessary that all outputs are used in the script, some are implicit outputs
# for example gcc main.c will generate a.out automatically
gen_script = _interpolate(gen_script, self._OUTPUT_REGEX, validate_all_match=False, **outputs)
if self._debug:
print("\n==============================================\n")
print(gen_script)
print("\n==============================================\n")

try:
subprocess.check_call(gen_script, shell=True)
except subprocess.CalledProcessError as e:
files = os.listdir("./")
fstr = "\n-".join(files)
logging.error(
f"Failed to Execute Script, return-code {e.returncode} \n"
f"StdErr: {e.stderr}\n"
f"StdOut: {e.stdout}\n"
f" Current directory contents: .\n-{fstr}"
)
raise

final_outputs = []
for v in self._output_locs:
if issubclass(v.var_type, FlyteFile):
final_outputs.append(FlyteFile(outputs[v.var]))
if issubclass(v.var_type, FlyteDirectory):
final_outputs.append(FlyteDirectory(outputs[v.var]))
if len(final_outputs) == 1:
return final_outputs[0]
if len(final_outputs) > 1:
return tuple(final_outputs)
return None

def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any:
return self._config_task_instance.post_execute(user_params, rval)
75 changes: 75 additions & 0 deletions flytekit/models/core/catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from flyteidl.core import catalog_pb2

from flytekit.models import common as _common_models
from flytekit.models.core import identifier as _identifier


class CatalogArtifactTag(_common_models.FlyteIdlEntity):
def __init__(self, artifact_id: str, name: str):
self._artifact_id = artifact_id
self._name = name

@property
def artifact_id(self) -> str:
return self._artifact_id

@property
def name(self) -> str:
return self._name

def to_flyte_idl(self) -> catalog_pb2.CatalogArtifactTag:
return catalog_pb2.CatalogArtifactTag(artifact_id=self.artifact_id, name=self.name)

@classmethod
def from_flyte_idl(cls, p: catalog_pb2.CatalogArtifactTag) -> "CatalogArtifactTag":
return cls(
artifact_id=p.artifact_id,
name=p.name,
)


class CatalogMetadata(_common_models.FlyteIdlEntity):
def __init__(
self,
dataset_id: _identifier.Identifier,
artifact_tag: CatalogArtifactTag,
source_task_execution: _identifier.TaskExecutionIdentifier,
):
self._dataset_id = dataset_id
self._artifact_tag = artifact_tag
self._source_task_execution = source_task_execution

@property
def dataset_id(self) -> _identifier.Identifier:
return self._dataset_id

@property
def artifact_tag(self) -> CatalogArtifactTag:
return self._artifact_tag

@property
def source_task_execution(self) -> _identifier.TaskExecutionIdentifier:
return self._source_task_execution

@property
def source_execution(self) -> _identifier.TaskExecutionIdentifier:
"""
This is a one of but for now there's only one thing in the one of
"""
return self._source_task_execution

def to_flyte_idl(self) -> catalog_pb2.CatalogMetadata:
return catalog_pb2.CatalogMetadata(
dataset_id=self.dataset_id.to_flyte_idl(),
artifact_tag=self.artifact_tag.to_flyte_idl(),
source_task_execution=self.source_task_execution.to_flyte_idl(),
)

@classmethod
def from_flyte_idl(cls, pb: catalog_pb2.CatalogMetadata) -> "CatalogMetadata":
return cls(
dataset_id=_identifier.Identifier.from_flyte_idl(pb.dataset_id),
artifact_tag=CatalogArtifactTag.from_flyte_idl(pb.artifact_tag),
# Add HasField check if more things are ever added to the one of
source_task_execution=_identifier.TaskExecutionIdentifier.from_flyte_idl(pb.source_task_execution),
)
Loading

0 comments on commit 9f316f9

Please sign in to comment.