Skip to content

Commit

Permalink
Tracebacks no longer have JAX-internal frames prepended by default
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Aug 3, 2023
1 parent a22c477 commit d400527
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
13 changes: 8 additions & 5 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,12 +1030,15 @@ def _update_disable_jit_thread_local(val):
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
"Valid values are:\n"
" * \"off\": disables traceback filtering.\n"
" * \"auto\": use \"tracebackhide\" if running under a sufficiently "
"new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to "
" * \"auto\": use \"tracebackhide\" if running under a sufficiently"
" new IPython, or \"remove_frames\" otherwise.\n"
" * \"tracebackhide\": adds \"__tracebackhide__\" annotations to"
" hidden stack frames, which some traceback printers support.\n"
" * \"remove_frames\": removes hidden frames from tracebacks, and adds "
" the unfiltered traceback as a __cause__ of the exception.\n")
" * \"remove_frames\": removes hidden frames from tracebacks, and adds"
" the unfiltered traceback as a __cause__ of the exception.\n"
" * \"quiet_remove_frames\": removes hidden frames from tracebacks, and adds"
" a brief message (to the __cause__ of the exception) describing that this has"
" happened.\n")

# This flag is for internal use.
# TODO(tianjianlu): Removes once we always enable cusparse lowering.
Expand Down
35 changes: 23 additions & 12 deletions jax/_src/traceback_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ def format_exception_only(e: BaseException) -> str:

class UnfilteredStackTrace(Exception): pass

class SimplifiedTraceback(Exception):
def __str__(self):
return ("For simplicity, JAX has removed its internal frames from the traceback of "
"of the following exception. Set JAX_TRACEBACK_FILTERING=off to include "
"these.")

SimplifiedTraceback.__module__ = "jax.errors"

def _running_under_ipython() -> bool:
"""Returns true if we appear to be in an IPython session."""
try:
Expand All @@ -133,7 +141,7 @@ def _filtering_mode() -> str:
if (_running_under_ipython() and _ipython_supports_tracebackhide()):
mode = "tracebackhide"
else:
mode = "remove_frames"
mode = "quiet_remove_frames"
return mode

def api_boundary(fun: C) -> C:
Expand Down Expand Up @@ -171,21 +179,24 @@ def reraise_with_filtered_traceback(*args, **kwargs):
if mode == "tracebackhide":
_add_tracebackhide_to_hidden_frames(e.__traceback__)
raise
assert mode == "remove_frames", mode

filtered_tb, unfiltered, mode = None, None, None
filtered_tb, unfiltered = None, None
try:
filtered_tb = filter_traceback(e.__traceback__)
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
unfiltered = UnfilteredStackTrace(msg)
unfiltered.with_traceback(_add_call_stack_frames(e.__traceback__))
unfiltered.__context__ = e.__context__
unfiltered.__cause__ = e.__cause__
unfiltered.__suppress_context__ = e.__suppress_context__
if mode == "quiet_remove_frames":
jax_error = SimplifiedTraceback()
elif mode == "remove_frames":
msg = format_exception_only(e)
msg = f'{msg}\n\n{_jax_message_append}'
jax_error = UnfilteredStackTrace(msg)
jax_error.with_traceback(_add_call_stack_frames(e.__traceback__))
else:
raise ValueError(f"JAX_TRACEBACK_FILTERING={mode} is not a valid value.")
jax_error.__cause__ = e.__cause__
jax_error.__context__ = e.__context__
jax_error.__suppress_context__ = e.__suppress_context__
e.__cause__ = jax_error
e.__context__ = None
e.__cause__ = unfiltered

e.__traceback__ = filtered_tb
# In Python < 3.11, there seems to be no way to alter the currently
# raised exception traceback, except via the C API. The interpreter
Expand Down
1 change: 1 addition & 0 deletions jax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
TracerIntegerConversionError as TracerIntegerConversionError,
UnexpectedTracerError as UnexpectedTracerError,
)
from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback

0 comments on commit d400527

Please sign in to comment.