Skip to content

Commit

Permalink
fix bug from pytorch#101651
Browse files Browse the repository at this point in the history
  • Loading branch information
yusheng.wei committed May 9, 2024
1 parent b958810 commit 3e4f56c
Showing 1 changed file with 41 additions and 7 deletions.
48 changes: 41 additions & 7 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ class OptimizedModule(torch.nn.Module):
forward method to optimized self.forward method.
"""

_torchdynamo_orig_callable: Callable[..., Any]
get_compiler_config: Callable[[], Any]
__torchdynamo_orig_callable: Callable[..., Any]
_get_compiler_config: Callable[[], Any]

_opt_mod_attributes = {
"_orig_mod",
"dynamo_ctx",
"_torchdynamo_orig_callable",
"get_compiler_config",
"__torchdynamo_orig_callable",
"_get_compiler_config",
"forward",
"_forward",
"__dict__",
Expand Down Expand Up @@ -170,6 +170,8 @@ def _initialize(self):

def __getstate__(self):
    """Return a picklable copy of ``self.__dict__``.

    Attributes holding compiled/unpicklable callables (the dynamo-wrapped
    forward, the recorded original callable, and the compiler-config
    getter) are dropped; the corresponding properties lazily rebuild
    fallbacks after unpickling.
    """
    state = dict(self.__dict__)
    # BUG FIX: the setter assigns via ``self.__torchdynamo_orig_callable``,
    # which Python name-mangles to
    # ``_OptimizedModule__torchdynamo_orig_callable`` inside this class.
    # Popping the raw ``"__torchdynamo_orig_callable"`` string was a no-op,
    # so the unpicklable callable leaked into the pickled state. Pop both
    # spellings to be safe.
    state.pop("_OptimizedModule__torchdynamo_orig_callable", None)
    state.pop("__torchdynamo_orig_callable", None)
    state.pop("_get_compiler_config", None)
    state.pop("forward", None)
    state.pop("__call__", None)
    return state
Expand Down Expand Up @@ -207,6 +209,27 @@ def __dir__(self):
attr for attr in super().__dir__() if attr not in orig_mod_attrs
]

@property
def _torchdynamo_orig_callable(self):
    """Return the original (pre-compilation) callable.

    Falls back to resolving the innermost function lazily when no
    original callable has been recorded — e.g. after unpickling, since
    ``__getstate__`` drops the stored attribute.
    """
    # BUG FIX: this name starts with two underscores, so inside this class
    # ``self.__torchdynamo_orig_callable`` is mangled to
    # ``self._OptimizedModule__torchdynamo_orig_callable``.  The previous
    # check ``hasattr(self, '__torchdynamo_orig_callable')`` used the
    # *unmangled* string, never found the attribute the setter stored, and
    # therefore always returned the fallback lambda.  Access the attribute
    # through the same (mangled) identifier the setter assigns to instead.
    try:
        fn = self.__torchdynamo_orig_callable
    except AttributeError:
        fn = None
    if fn is None:
        # Lazy fallback: resolve the innermost wrapped function on demand.
        return lambda: innermost_fn(self)
    return fn

@_torchdynamo_orig_callable.setter
def _torchdynamo_orig_callable(self, fn):
    # Only callables may be recorded; stored under the mangled name
    # ``_OptimizedModule__torchdynamo_orig_callable``.
    assert callable(fn)
    self.__torchdynamo_orig_callable = fn

@property
def get_compiler_config(self):
    """Return the compiler-config getter for this optimized module.

    If no getter has been recorded (e.g. after unpickling, which drops
    ``_get_compiler_config``), fall back to a lambda reading the config
    off ``self.dynamo_ctx``.
    """
    stored = getattr(self, "_get_compiler_config", None)
    if stored is not None:
        return stored
    # Lazy fallback rebuilt on demand from the dynamo context.
    return lambda: self.dynamo_ctx.compiler_config

@get_compiler_config.setter
def get_compiler_config(self, fn):
    # Record an explicit compiler-config getter; must be callable.
    assert callable(fn)
    self._get_compiler_config = fn

def remove_from_cache(f):
"""
Expand Down Expand Up @@ -323,6 +346,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
cleanup()
self.cleanup_fns.clear()

def __getstate__(self):
    """Return a picklable copy of ``self.__dict__``, dropping the
    callables (compiler callback, enter hook, call wrapper) that cannot
    be pickled."""
    state = dict(self.__dict__)
    for key in ("callback", "on_enter", "__call__"):
        state.pop(key, None)
    return state

def __call__(self, fn):
# public api for compiler config/options
def get_compiler_config():
Expand All @@ -347,7 +377,7 @@ def get_compiler_config():

# when compiling torch.nn.Module,
# provide public api OptimizedModule.get_compiler_config()
assert not hasattr(new_mod, "get_compiler_config")
assert not hasattr(new_mod, "_get_compiler_config")
new_mod.get_compiler_config = get_compiler_config

return new_mod
Expand Down Expand Up @@ -379,7 +409,11 @@ def get_compiler_config():
# call to a builtin without a frame for us to capture
fn = external_utils.wrap_inline(fn)

callback = self.callback
def do_nothing(*arg, **kwargs):
pass
callback = do_nothing
if hasattr(self, 'callback'):
callback = self.callback

is_jit_tracing = torch._C._is_tracing
is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
Expand Down Expand Up @@ -426,7 +460,7 @@ def _fn(*args, **kwargs):

# when compiling user function instead of nn.Module
# provide public api _fn.get_compiler_config()
assert not hasattr(_fn, "get_compiler_config")
assert not hasattr(_fn, "_get_compiler_config")
_fn.get_compiler_config = get_compiler_config # type: ignore[attr-defined]

# If the function is called using torch._dynamo.optimize decorator, we
Expand Down

0 comments on commit 3e4f56c

Please sign in to comment.