Skip to content

Commit

Permalink
Improve Python exception handling in user provided functions
Browse files Browse the repository at this point in the history
Adds null checks around the returned object of user provided functions
and fails if the returned object is NULL. This stops us from
segfaulting and provides an immediate error as opposed to the error
propagating up further down the pipeline when we attempt to use the
NULL object.

Add null check for encoded string in Machida's TCPSinkEncoder

Verifies that the string we're outputting is infact a string so
we do not segfault if the user provides an encoder function that
does not return a string.
  • Loading branch information
JONBRWN committed May 14, 2018
1 parent 4e8294c commit c90a9bb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
9 changes: 4 additions & 5 deletions machida/cpp/python-wallaroo.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,7 @@ extern PyObject *computation_compute(PyObject *computation, PyObject *data,
pValue = PyObject_CallFunctionObjArgs(pFunc, data, NULL);
Py_DECREF(pFunc);

if (pValue != Py_None)
return pValue;
else
return NULL;
return pValue;
}

extern PyObject *sink_encoder_encode(PyObject *sink_encoder, PyObject *data)
Expand Down Expand Up @@ -238,7 +235,9 @@ extern long partition_function_partition_u64(PyObject *partition_function, PyObj
Py_DECREF(pFunc);

long rtn = PyInt_AsLong(pValue);
Py_DECREF(pValue);
if (pValue != NULL) {
Py_DECREF(pValue);
}
return rtn;
}

Expand Down
40 changes: 30 additions & 10 deletions machida/machida.pony
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,10 @@ class PyStateBuilder
_state_builder = state_builder

fun apply(): PyState =>
PyState(@state_builder_build_state(_state_builder))
let py_state = @state_builder_build_state(_state_builder)
Machida.print_errors()
if py_state.is_null() then Fail() end
PyState(py_state)

fun _serialise_space(): USize =>
Machida.user_serialization_get_size(_state_builder)
Expand Down Expand Up @@ -326,7 +329,7 @@ class PyComputation is Computation[PyData val, PyData val]
let r: Pointer[U8] val =
Machida.computation_compute(_computation, input.obj(), _is_multi)

if not r.is_null() then
if not Machida.is_py_none(r) then
Machida.process_computation_results(r, _is_multi)
else
None
Expand Down Expand Up @@ -406,14 +409,21 @@ class PyTCPEncoder is TCPSinkEncoder[PyData val]

fun apply(data: PyData val, wb: Writer): Array[ByteSeq] val =>
let byte_buffer = Machida.sink_encoder_encode(_sink_encoder, data.obj())
if not byte_buffer.is_null() and not Machida.is_py_none(byte_buffer) then
let arr = recover val
// create a temporary Array[U8] wrapper for the C array, then clone it
Array[U8].from_cpointer(@PyString_AsString(byte_buffer),
@PyString_Size(byte_buffer)).clone()
if not Machida.is_py_none(byte_buffer) then
let byte_string = @PyString_AsString(byte_buffer)

if not byte_string.is_null() then
let arr = recover val
// create a temporary Array[U8] wrapper for the C array, then clone it
Array[U8].from_cpointer(@PyString_AsString(byte_buffer),
@PyString_Size(byte_buffer)).clone()
end
Machida.dec_ref(byte_buffer)
wb.write(arr)
else
Machida.print_errors()
Fail()
end
Machida.dec_ref(byte_buffer)
wb.write(arr)
end
wb.done()

Expand Down Expand Up @@ -714,13 +724,15 @@ primitive Machida
=>
let r = @source_decoder_decode(source_decoder, data, size)
print_errors()
if r.is_null() then Fail() end
r

fun sink_encoder_encode(sink_encoder: Pointer[U8] val, data: Pointer[U8] val):
Pointer[U8] val
=>
let r = @sink_encoder_encode(sink_encoder, data)
print_errors()
if r.is_null() then Fail() end
r

fun computation_compute(computation: Pointer[U8] val, data: Pointer[U8] val,
Expand All @@ -729,6 +741,7 @@ primitive Machida
let method = if multi then "compute_multi" else "compute" end
let r = @computation_compute(computation, data, method.cstring())
print_errors()
if r.is_null() then Fail() end
r

fun stateful_computation_compute(computation: Pointer[U8] val,
Expand All @@ -738,7 +751,10 @@ primitive Machida
let method = if multi then "compute_multi" else "compute" end
let r =
@stateful_computation_compute(computation, data, state, method.cstring())

print_errors()
if r.is_null() then Fail() end

let msg = @PyTuple_GetItem(r, 0)
let persist = @PyTuple_GetItem(r, 1)

Expand All @@ -755,7 +771,10 @@ primitive Machida
data: Pointer[U8] val): U64
=>
let r = @partition_function_partition_u64(partition_function, data)
print_errors()
if err_occurred() and (r == -1) then
print_errors()
Fail()
end
r

fun py_list_int_to_pony_array_u64(py_array: Pointer[U8] val):
Expand Down Expand Up @@ -787,6 +806,7 @@ primitive Machida
=>
let r = @partition_function_partition(partition_function, data)
print_errors()
if r.is_null() then Fail() end
r

fun py_list_int_to_pony_array_pykey(py_array: Pointer[U8] val):
Expand Down

0 comments on commit c90a9bb

Please sign in to comment.