Skip to content

Commit

Permalink
Respect predefined modes in get_default_mode
Browse files Browse the repository at this point in the history
Also make linker and optimizer non-mutable config as the mode is cached after using them for the first time.
  • Loading branch information
ricardoV94 committed Jan 23, 2025
1 parent a0fe30d commit 837f98e
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 30 deletions.
43 changes: 21 additions & 22 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,46 +503,45 @@ def get_mode(orig_string):
if not isinstance(string, str):
return string # it is hopefully already a mode...

if string in predefined_modes:
return predefined_modes[string]

if string not in ("Mode", "DebugMode", "DEBUG_MODE", "NanGuardMode"):
raise ValueError(f"No predefined mode exist for string: {string}")

Check warning on line 510 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L510

Added line #L510 was not covered by tests

global instantiated_default_mode
# The default mode is cached. However, config.mode can change
# If instantiated_default_mode has the right class, use it.

if orig_string is None and instantiated_default_mode:
if string in predefined_modes:
default_mode_class = predefined_modes[string].__class__.__name__
else:
default_mode_class = string
default_mode_class = string
# FIXME: This is flawed, we should use proper object comparison.
if instantiated_default_mode.__class__.__name__ == default_mode_class:
return instantiated_default_mode

if string in ("Mode", "DebugMode", "NanGuardMode"):
if string == "DebugMode":
# need to import later to break circular dependency.
from .debugmode import DebugMode
if string in ("DebugMode", "DEBUG_MODE"):
# need to import later to break circular dependency.
from .debugmode import DebugMode

# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif string == "NanGuardMode":
# need to import later to break circular dependency.
from .nanguardmode import NanGuardMode
# DebugMode use its own linker.
ret = DebugMode(optimizer=config.optimizer)
elif string == "NanGuardMode":
# need to import later to break circular dependency.
from .nanguardmode import NanGuardMode

Check warning on line 530 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L530

Added line #L530 was not covered by tests

# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)
else:
# TODO: Can't we look up the name and invoke it rather than using eval here?
ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)")
elif string in predefined_modes:
ret = predefined_modes[string]
# NanGuardMode use its own linker.
ret = NanGuardMode(True, True, True, optimizer=config.optimizer)

Check warning on line 533 in pytensor/compile/mode.py

View check run for this annotation

Codecov / codecov/patch

pytensor/compile/mode.py#L533

Added line #L533 was not covered by tests
else:
raise Exception(f"No predefined mode exist for string: {string}")
ret = Mode(linker=config.linker, optimizer=config.optimizer)

if orig_string is None:
# Build and cache the default mode
if config.optimizer_excluding:
ret = ret.excluding(*config.optimizer_excluding.split(":"))
if config.optimizer_including:
ret = ret.including(*config.optimizer_including.split(":"))
if config.optimizer_requiring:
ret = ret.requiring(*config.optimizer_requiring.split(":"))
# Override the cache with the new class mode
instantiated_default_mode = ret

return ret
Expand Down
4 changes: 3 additions & 1 deletion pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,8 @@ def add_compile_configvars():
config.add(
"linker",
"Default linker used if the pytensor flags mode is Mode",
EnumStr("cvm", linker_options),
# Not mutable because the default mode is cached after the first use.
EnumStr("cvm", linker_options, mutable=False),
in_c_key=False,
)

Expand All @@ -411,6 +412,7 @@ def add_compile_configvars():
EnumStr(
"o4",
["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"],
mutable=False, # Not mutable because the default mode is cached after the first use.
),
in_c_key=False,
)
Expand Down
8 changes: 1 addition & 7 deletions tests/compile/function/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,14 +1105,10 @@ def test_optimizations_preserved(self):
((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s),
)
old_default_mode = config.mode
old_default_opt = config.optimizer
old_default_link = config.linker
try:
try:
str_f = pickle.dumps(f, protocol=-1)
config.mode = "Mode"
config.linker = "py"
config.optimizer = "None"
config.mode = "NUMBA"
g = pickle.loads(str_f)
# print g.maker.mode
# print compile.mode.default_mode
Expand All @@ -1121,8 +1117,6 @@ def test_optimizations_preserved(self):
g = "ok"
finally:
config.mode = old_default_mode
config.optimizer = old_default_opt
config.linker = old_default_link

if g == "ok":
return
Expand Down
13 changes: 13 additions & 0 deletions tests/compile/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytensor.graph.features import NoOutputFromInplace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
from pytensor.link.basic import LocalLinker
from pytensor.link.jax import JAXLinker
from pytensor.tensor.math import dot, tanh
from pytensor.tensor.type import matrix, vector

Expand Down Expand Up @@ -142,3 +143,15 @@ class MyLinker(LocalLinker):
test_mode = Mode(linker=MyLinker())
with pytest.raises(Exception):
get_target_language(test_mode)


def test_predefined_modes_respected():
default_mode = get_default_mode()
assert not isinstance(default_mode.linker, JAXLinker)

with config.change_flags(mode="JAX"):
jax_mode = get_default_mode()
assert isinstance(jax_mode.linker, JAXLinker)

default_mode_again = get_default_mode()
assert not isinstance(default_mode_again.linker, JAXLinker)

0 comments on commit 837f98e

Please sign in to comment.