diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 152ad3554d..60d65bbde1 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -503,46 +503,45 @@ def get_mode(orig_string): if not isinstance(string, str): return string # it is hopefully already a mode... + if string in predefined_modes: + return predefined_modes[string] + + if string not in ("Mode", "DebugMode", "DEBUG_MODE", "NanGuardMode"): + raise ValueError(f"No predefined mode exist for string: {string}") + global instantiated_default_mode # The default mode is cached. However, config.mode can change # If instantiated_default_mode has the right class, use it. + if orig_string is None and instantiated_default_mode: - if string in predefined_modes: - default_mode_class = predefined_modes[string].__class__.__name__ - else: - default_mode_class = string + default_mode_class = string + # FIXME: This is flawed, we should use proper object comparison. if instantiated_default_mode.__class__.__name__ == default_mode_class: return instantiated_default_mode - if string in ("Mode", "DebugMode", "NanGuardMode"): - if string == "DebugMode": - # need to import later to break circular dependency. - from .debugmode import DebugMode + if string in ("DebugMode", "DEBUG_MODE"): + # need to import later to break circular dependency. + from .debugmode import DebugMode - # DebugMode use its own linker. - ret = DebugMode(optimizer=config.optimizer) - elif string == "NanGuardMode": - # need to import later to break circular dependency. - from .nanguardmode import NanGuardMode + # DebugMode use its own linker. + ret = DebugMode(optimizer=config.optimizer) + elif string == "NanGuardMode": + # need to import later to break circular dependency. + from .nanguardmode import NanGuardMode - # NanGuardMode use its own linker. - ret = NanGuardMode(True, True, True, optimizer=config.optimizer) - else: - # TODO: Can't we look up the name and invoke it rather than using eval here? - ret = eval(string + "(linker=config.linker, optimizer=config.optimizer)") - elif string in predefined_modes: - ret = predefined_modes[string] + # NanGuardMode use its own linker. + ret = NanGuardMode(True, True, True, optimizer=config.optimizer) else: - raise Exception(f"No predefined mode exist for string: {string}") + ret = Mode(linker=config.linker, optimizer=config.optimizer) if orig_string is None: - # Build and cache the default mode if config.optimizer_excluding: ret = ret.excluding(*config.optimizer_excluding.split(":")) if config.optimizer_including: ret = ret.including(*config.optimizer_including.split(":")) if config.optimizer_requiring: ret = ret.requiring(*config.optimizer_requiring.split(":")) + # Override the cache with the new class mode instantiated_default_mode = ret return ret diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index a81fd63905..144b5daef2 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -388,7 +388,8 @@ def add_compile_configvars(): config.add( "linker", "Default linker used if the pytensor flags mode is Mode", - EnumStr("cvm", linker_options), + # Not mutable because the default mode is cached after the first use. + EnumStr("cvm", linker_options, mutable=False), in_c_key=False, ) @@ -411,6 +412,7 @@ def add_compile_configvars(): EnumStr( "o4", ["o3", "o2", "o1", "unsafe", "fast_run", "fast_compile", "merge", "None"], + mutable=False, # Not mutable because the default mode is cached after the first use. ), in_c_key=False, ) diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index bef3ae25bf..0990dbeca0 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -1105,14 +1105,10 @@ def test_optimizations_preserved(self): ((a.T.T) * (dot(xm, (sm.T.T.T)) + x).T * (x / x) + s), ) old_default_mode = config.mode - old_default_opt = config.optimizer - old_default_link = config.linker try: try: str_f = pickle.dumps(f, protocol=-1) - config.mode = "Mode" - config.linker = "py" - config.optimizer = "None" + config.mode = "NUMBA" g = pickle.loads(str_f) # print g.maker.mode # print compile.mode.default_mode @@ -1121,8 +1117,6 @@ def test_optimizations_preserved(self): g = "ok" finally: config.mode = old_default_mode - config.optimizer = old_default_opt - config.linker = old_default_link if g == "ok": return diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index c965087ea2..291eac0782 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -13,6 +13,7 @@ from pytensor.graph.features import NoOutputFromInplace from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from pytensor.link.basic import LocalLinker +from pytensor.link.jax import JAXLinker from pytensor.tensor.math import dot, tanh from pytensor.tensor.type import matrix, vector @@ -142,3 +143,15 @@ class MyLinker(LocalLinker): test_mode = Mode(linker=MyLinker()) with pytest.raises(Exception): get_target_language(test_mode) + + +def test_predefined_modes_respected(): + default_mode = get_default_mode() + assert not isinstance(default_mode.linker, JAXLinker) + + with config.change_flags(mode="JAX"): + jax_mode = get_default_mode() + assert isinstance(jax_mode.linker, JAXLinker) + + default_mode_again = get_default_mode() + assert not isinstance(default_mode_again.linker, JAXLinker)