Skip to content

Commit

Permalink
python: support for *args/**kwargs in dispatcher
Browse files Browse the repository at this point in the history
Also, fix some fs/proxy stuff
  • Loading branch information
Yuval Shavit authored Jan 23, 2023
1 parent b836758 commit 5dd180e
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 9 deletions.
10 changes: 9 additions & 1 deletion pkg/lang/python/aws_runtime/dispatcher_fargate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import multiprocessing
import types
import inspect

app_port = os.getenv("KLOTHO_APP_PORT", 3000)
log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper()
Expand Down Expand Up @@ -59,7 +60,14 @@ async def proxy_root_post(obj: dict):
function = getattr(module_obj, function_name, None)
if not function:
raise Exception(f"couldn't find function: {module_name}.{function_name}")
result = function(*params)
param_args = ()
param_kwargs = {}
args_spec = inspect.getfullargspec(function)
if args_spec.varkw:
param_kwargs, params = params[-1], params[:-1]
if args_spec.varargs:
param_args, params = params[-1], params[:-1]
result = function(*params, *param_args, **param_kwargs)
if isinstance(result, types.CoroutineType):
result = await result
return result
Expand Down
10 changes: 9 additions & 1 deletion pkg/lang/python/aws_runtime/dispatcher_fargate.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import logging
import os
import multiprocessing
import types
import inspect

app_port = os.getenv("KLOTHO_APP_PORT", 3000)
log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper()
Expand Down Expand Up @@ -59,7 +60,14 @@ def start_proxy_server():
function = getattr(module_obj, function_name, None)
if not function:
raise Exception(f"couldn't find function: {module_name}.{function_name}")
result = function(*params)
param_args = ()
param_kwargs = {}
args_spec = inspect.getfullargspec(function)
if args_spec.varkw:
param_kwargs, params = params[-1], params[:-1]
if args_spec.varargs:
param_args, params = params[-1], params[:-1]
result = function(*params, *param_args, **param_kwargs)
if isinstance(result, types.CoroutineType):
result = await result
return result
Expand Down
14 changes: 11 additions & 3 deletions pkg/lang/python/aws_runtime/dispatcher_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import types
import uuid
import inspect

log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper()

Expand Down Expand Up @@ -47,17 +48,24 @@ def init_asgi_handler():
async def rpc_handler(event, _context):
payload_key = event.get('params')
async with s3fs.open(payload_key) as f:
params = json.load(await f.read())
params = json.loads(await f.read())
module_obj = try_import(event.get('module_name'))
if not module_obj:
raise Exception("couldn't find module for path: {module_path}")
function = getattr(module_obj, event.get('function_to_call'))
result = function(*params)
param_args = ()
param_kwargs = {}
args_spec = inspect.getfullargspec(function)
if args_spec.varkw:
param_kwargs, params = params[-1], params[:-1]
if args_spec.varargs:
param_args, params = params[-1], params[:-1]
result = function(*params, *param_args, **param_kwargs)
if isinstance(result, types.CoroutineType):
result = await result

result_payload_key = str(uuid.uuid4())
async with s3fs.open(result_payload_key) as f:
async with s3fs.open(result_payload_key, mode='w') as f:
await f.write(json.dumps(result))
return result_payload_key

Expand Down
14 changes: 11 additions & 3 deletions pkg/lang/python/aws_runtime/dispatcher_lambda.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import logging
import os
import types
import uuid
import inspect

log_level = os.getenv("KLOTHO_LOG_LEVEL", "DEBUG").upper()

Expand Down Expand Up @@ -47,17 +48,24 @@ def init_asgi_handler():
async def rpc_handler(event, _context):
payload_key = event.get('params')
async with s3fs.open(payload_key) as f:
params = json.load(await f.read())
params = json.loads(await f.read())
module_obj = try_import(event.get('module_name'))
if not module_obj:
raise Exception("couldn't find module for path: {module_path}")
function = getattr(module_obj, event.get('function_to_call'))
result = function(*params)
param_args = ()
param_kwargs = {}
args_spec = inspect.getfullargspec(function)
if args_spec.varkw:
param_kwargs, params = params[-1], params[:-1]
if args_spec.varargs:
param_args, params = params[-1], params[:-1]
result = function(*params, *param_args, **param_kwargs)
if isinstance(result, types.CoroutineType):
result = await result

result_payload_key = str(uuid.uuid4())
async with s3fs.open(result_payload_key) as f:
async with s3fs.open(result_payload_key, mode='w') as f:
await f.write(json.dumps(result))
return result_payload_key

Expand Down
2 changes: 1 addition & 1 deletion pkg/lang/python/aws_runtime/proxy_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def proxy_call(exec_group_name, module_name, function_name, params):
Payload=json.dumps(payload_to_send))
dispatcher_param_key_result = json.load(result["Payload"])
async with s3fs.open(dispatcher_param_key_result) as f:
return json.load(await f.read())
return json.loads(await f.read())


def get_exec_unit_lambda_function_name(logical_name):
Expand Down

0 comments on commit 5dd180e

Please sign in to comment.