From 5dd180e2850b77e9f69a557201641a236043b92d Mon Sep 17 00:00:00 2001 From: Yuval Shavit <110620369+yuval-klotho@users.noreply.github.com> Date: Mon, 23 Jan 2023 10:03:35 -0500 Subject: [PATCH] python: support for `*args`/`**kwargs` in dispatcher Also, fix some fs/proxy stuff --- pkg/lang/python/aws_runtime/dispatcher_fargate.py | 10 +++++++++- .../python/aws_runtime/dispatcher_fargate.py.tmpl | 10 +++++++++- pkg/lang/python/aws_runtime/dispatcher_lambda.py | 14 +++++++++++--- .../python/aws_runtime/dispatcher_lambda.py.tmpl | 14 +++++++++++--- pkg/lang/python/aws_runtime/proxy_lambda.py | 2 +- 5 files changed, 41 insertions(+), 9 deletions(-) diff --git a/pkg/lang/python/aws_runtime/dispatcher_fargate.py b/pkg/lang/python/aws_runtime/dispatcher_fargate.py index 25f83598c..9828ed45f 100644 --- a/pkg/lang/python/aws_runtime/dispatcher_fargate.py +++ b/pkg/lang/python/aws_runtime/dispatcher_fargate.py @@ -2,6 +2,7 @@ import os import multiprocessing import types +import inspect app_port = os.getenv("KLOTHO_APP_PORT", 3000) log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper() @@ -59,7 +60,14 @@ async def proxy_root_post(obj: dict): function = getattr(module_obj, function_name, None) if not function: raise Exception(f"couldn't find function: {module_name}.{function_name}") - result = function(*params) + param_args = () + param_kwargs = {} + args_spec = inspect.getfullargspec(function) + if args_spec.varkw: + param_kwargs, params = params[-1], params[:-1] + if args_spec.varargs: + param_args, params = params[-1], params[:-1] + result = function(*params, *param_args, **param_kwargs) if isinstance(result, types.CoroutineType): result = await result return result diff --git a/pkg/lang/python/aws_runtime/dispatcher_fargate.py.tmpl b/pkg/lang/python/aws_runtime/dispatcher_fargate.py.tmpl index 25f83598c..9828ed45f 100644 --- a/pkg/lang/python/aws_runtime/dispatcher_fargate.py.tmpl +++ b/pkg/lang/python/aws_runtime/dispatcher_fargate.py.tmpl @@ -2,6 +2,7 @@ import logging import os import multiprocessing import types +import inspect app_port = os.getenv("KLOTHO_APP_PORT", 3000) log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper() @@ -59,7 +60,14 @@ def start_proxy_server(): function = getattr(module_obj, function_name, None) if not function: raise Exception(f"couldn't find function: {module_name}.{function_name}") - result = function(*params) + param_args = () + param_kwargs = {} + args_spec = inspect.getfullargspec(function) + if args_spec.varkw: + param_kwargs, params = params[-1], params[:-1] + if args_spec.varargs: + param_args, params = params[-1], params[:-1] + result = function(*params, *param_args, **param_kwargs) if isinstance(result, types.CoroutineType): result = await result return result diff --git a/pkg/lang/python/aws_runtime/dispatcher_lambda.py b/pkg/lang/python/aws_runtime/dispatcher_lambda.py index d4de05420..ff05715a2 100644 --- a/pkg/lang/python/aws_runtime/dispatcher_lambda.py +++ b/pkg/lang/python/aws_runtime/dispatcher_lambda.py @@ -5,6 +5,7 @@ import os import types import uuid +import inspect log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper() @@ -47,17 +48,24 @@ def init_asgi_handler(): async def rpc_handler(event, _context): payload_key = event.get('params') async with s3fs.open(payload_key) as f: - params = json.load(await f.read()) + params = json.loads(await f.read()) module_obj = try_import(event.get('module_name')) if not module_obj: raise Exception("couldn't find module for path: {module_path}") function = getattr(module_obj, event.get('function_to_call')) - result = function(*params) + param_args = () + param_kwargs = {} + args_spec = inspect.getfullargspec(function) + if args_spec.varkw: + param_kwargs, params = params[-1], params[:-1] + if args_spec.varargs: + param_args, params = params[-1], params[:-1] + result = function(*params, *param_args, **param_kwargs) if isinstance(result, types.CoroutineType): result = await result result_payload_key = str(uuid.uuid4()) - async with s3fs.open(result_payload_key) as f: + async with s3fs.open(result_payload_key, mode='w') as f: await f.write(json.dumps(result)) return result_payload_key diff --git a/pkg/lang/python/aws_runtime/dispatcher_lambda.py.tmpl b/pkg/lang/python/aws_runtime/dispatcher_lambda.py.tmpl index d4de05420..ff05715a2 100644 --- a/pkg/lang/python/aws_runtime/dispatcher_lambda.py.tmpl +++ b/pkg/lang/python/aws_runtime/dispatcher_lambda.py.tmpl @@ -5,6 +5,7 @@ import logging import os import types import uuid +import inspect log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper() @@ -47,17 +48,24 @@ def init_asgi_handler(): async def rpc_handler(event, _context): payload_key = event.get('params') async with s3fs.open(payload_key) as f: - params = json.load(await f.read()) + params = json.loads(await f.read()) module_obj = try_import(event.get('module_name')) if not module_obj: raise Exception("couldn't find module for path: {module_path}") function = getattr(module_obj, event.get('function_to_call')) - result = function(*params) + param_args = () + param_kwargs = {} + args_spec = inspect.getfullargspec(function) + if args_spec.varkw: + param_kwargs, params = params[-1], params[:-1] + if args_spec.varargs: + param_args, params = params[-1], params[:-1] + result = function(*params, *param_args, **param_kwargs) if isinstance(result, types.CoroutineType): result = await result result_payload_key = str(uuid.uuid4()) - async with s3fs.open(result_payload_key) as f: + async with s3fs.open(result_payload_key, mode='w') as f: await f.write(json.dumps(result)) return result_payload_key diff --git a/pkg/lang/python/aws_runtime/proxy_lambda.py b/pkg/lang/python/aws_runtime/proxy_lambda.py index f1d638c93..e10516a43 100644 --- a/pkg/lang/python/aws_runtime/proxy_lambda.py +++ b/pkg/lang/python/aws_runtime/proxy_lambda.py @@ -24,7 +24,7 @@ async def proxy_call(exec_group_name, module_name, function_name, params): Payload=json.dumps(payload_to_send)) dispatcher_param_key_result = json.load(result["Payload"]) async with s3fs.open(dispatcher_param_key_result) as f: - return json.load(await f.read()) + return json.loads(await f.read()) def get_exec_unit_lambda_function_name(logical_name):