From 45d48242343d26afe2bb93b6bf621d7f157c74ee Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Apr 2023 14:33:40 -0700 Subject: [PATCH 1/3] set default ComSpec when running on Windows Signed-off-by: Kevin Su --- flytekit/extras/tasks/shell.py | 49 ++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 12ef36af3e..64160086f0 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -1,5 +1,6 @@ import datetime import os +import platform import string import subprocess import typing @@ -70,10 +71,10 @@ def format_field(self, value, format_spec): return super().format_field(value, format_spec) def interpolate( - self, - tmpl: str, - inputs: typing.Optional[typing.Dict[str, str]] = None, - outputs: typing.Optional[typing.Dict[str, str]] = None, + self, + tmpl: str, + inputs: typing.Optional[typing.Dict[str, str]] = None, + outputs: typing.Optional[typing.Dict[str, str]] = None, ) -> str: """ Interpolate python formatted string templates with variables from the input and output @@ -101,15 +102,15 @@ class ShellTask(PythonInstanceTask[T]): """ """ def __init__( - self, - name: str, - debug: bool = False, - script: typing.Optional[str] = None, - script_file: typing.Optional[str] = None, - task_config: T = None, - inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, - output_locs: typing.Optional[typing.List[OutputLocation]] = None, - **kwargs, + self, + name: str, + debug: bool = False, + script: typing.Optional[str] = None, + script_file: typing.Optional[str] = None, + task_config: T = None, + inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_locs: typing.Optional[typing.List[OutputLocation]] = None, + **kwargs, ): """ Args: @@ -213,6 +214,9 @@ def execute(self, **kwargs) -> typing.Any: print("\n==============================================\n") try: + if platform.system() == "Windows" and os.environ["ComSpec"] is None: + # https://github.com/python/cpython/issues/101283 + os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe" subprocess.check_call(gen_script, shell=True) except subprocess.CalledProcessError as e: files = os.listdir(".") @@ -245,15 +249,15 @@ class RawShellTask(ShellTask): """ """ def __init__( - self, - name: str, - debug: bool = False, - script: typing.Optional[str] = None, - script_file: typing.Optional[str] = None, - task_config: T = None, - inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, - output_locs: typing.Optional[typing.List[OutputLocation]] = None, - **kwargs, + self, + name: str, + debug: bool = False, + script: typing.Optional[str] = None, + script_file: typing.Optional[str] = None, + task_config: T = None, + inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_locs: typing.Optional[typing.List[OutputLocation]] = None, + **kwargs, ): """ The `RawShellTask` is a minimal extension of the existing `ShellTask`. It's purpose is to support wrapping a @@ -356,7 +360,6 @@ def execute(self, **kwargs) -> typing.Any: # This utility function allows for the specification of env variables, arguments, and the actual script within the # workflow definition rather than at `RawShellTask` instantiation def get_raw_shell_task(name: str) -> RawShellTask: - return RawShellTask( name=name, debug=True, From a4f8fa6c8fd3831ddcd8c56e3a4fc99955b318f6 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Apr 2023 14:38:24 -0700 Subject: [PATCH 2/3] lint Signed-off-by: Kevin Su --- flytekit/extras/tasks/shell.py | 44 +++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 64160086f0..2664a3f076 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -71,10 +71,10 @@ def format_field(self, value, format_spec): return super().format_field(value, format_spec) def interpolate( - self, - tmpl: str, - inputs: typing.Optional[typing.Dict[str, str]] = None, - outputs: typing.Optional[typing.Dict[str, str]] = None, + self, + tmpl: str, + inputs: typing.Optional[typing.Dict[str, str]] = None, + outputs: typing.Optional[typing.Dict[str, str]] = None, ) -> str: """ Interpolate python formatted string templates with variables from the input and output @@ -102,15 +102,15 @@ class ShellTask(PythonInstanceTask[T]): """ """ def __init__( - self, - name: str, - debug: bool = False, - script: typing.Optional[str] = None, - script_file: typing.Optional[str] = None, - task_config: T = None, - inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, - output_locs: typing.Optional[typing.List[OutputLocation]] = None, - **kwargs, + self, + name: str, + debug: bool = False, + script: typing.Optional[str] = None, + script_file: typing.Optional[str] = None, + task_config: T = None, + inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_locs: typing.Optional[typing.List[OutputLocation]] = None, + **kwargs, ): """ Args: @@ -249,15 +249,15 @@ class RawShellTask(ShellTask): """ """ def __init__( - self, - name: str, - debug: bool = False, - script: typing.Optional[str] = None, - script_file: typing.Optional[str] = None, - task_config: T = None, - inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, - output_locs: typing.Optional[typing.List[OutputLocation]] = None, - **kwargs, + self, + name: str, + debug: bool = False, + script: typing.Optional[str] = None, + script_file: typing.Optional[str] = None, + task_config: T = None, + inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, + output_locs: typing.Optional[typing.List[OutputLocation]] = None, + **kwargs, ): """ The `RawShellTask` is a minimal extension of the existing `ShellTask`. It's purpose is to support wrapping a From d138eb62649a115530d96550594c55fcb7738681 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 17 Apr 2023 15:12:41 -0700 Subject: [PATCH 3/3] nit Signed-off-by: Kevin Su --- flytekit/extras/tasks/shell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 2664a3f076..87b60126d6 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -214,7 +214,7 @@ def execute(self, **kwargs) -> typing.Any: print("\n==============================================\n") try: - if platform.system() == "Windows" and os.environ["ComSpec"] is None: + if platform.system() == "Windows" and os.environ.get("ComSpec") is None: # https://github.com/python/cpython/issues/101283 os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe" subprocess.check_call(gen_script, shell=True)