diff --git a/python/mxnet/_deferred_compute.py b/python/mxnet/_deferred_compute.py index a280769060c6..c756d4966f8b 100644 --- a/python/mxnet/_deferred_compute.py +++ b/python/mxnet/_deferred_compute.py @@ -54,8 +54,10 @@ def context(state=True): # other code unexpectedly, when used in concurrent code." # https://github.com/apache/incubator-mxnet/issues/17495#issuecomment-585461965 val = set_deferred_compute(state) - yield - set_deferred_compute(val) + try: + yield + finally: + set_deferred_compute(val) def get_symbol(input_arrays, output_arrays, input_names=None, *, sym_cls=Symbol):