[PT FE] Optimize memory usage of patch_model (openvinotoolkit#27428)
### Details:
- *`no_jit_trace` was using extra memory to get and store the tracing state, which contains the whole graph produced up to that point, and that increased the memory consumption of tracing. For the FLUX model, for example, it used about 20 GB of extra memory.*
- *Saving `args` on the meta device didn't work without `no_jit_trace`. To work around this, we now pass the args directly to `forward` instead of stashing them in `Trampoline`. This gives a cleaner flow for the arguments and reduces the memory spent on saving them. However, it changes the behavior of `ModuleExtension.evaluate`: it now receives the args that were passed to `convert`, not the original args (see the sketch below).*
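
As an illustration of that behavioral change, here is a minimal, self-contained sketch. The two lambdas mirror the Embedding `evaluate` functions before and after this patch; the toy `Embedding` module and index tensor are made up for the example:

```python
import torch

emb = torch.nn.Embedding(10, 4)  # toy module, embedding_dim = 4
idx = torch.tensor([[1, 2, 3]])  # toy input indices, shape [1, 3]

# Before: `evaluate` received the original forward args, so the indices were args[0].
old_evaluate = lambda module, *args, **kwargs: torch.full(
    list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32)

# After: `evaluate` receives the args that `convert` passes to the target op,
# i.e. (module.weight, indices, ...), so the indices are now args[1].
new_evaluate = lambda module, *args, **kwargs: torch.full(
    list(args[1].shape) + [module.embedding_dim], 0.5, dtype=torch.float32)

print(old_evaluate(emb, idx).shape)              # torch.Size([1, 3, 4])
print(new_evaluate(emb, emb.weight, idx).shape)  # torch.Size([1, 3, 4])
```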

optimum-cli for FLUX with `torch_dtype=torch.bfloat16` before the change:

![image](https://github.com/user-attachments/assets/f070068a-e52e-4558-956e-95afa64d1dbc)
optimum-cli for FLUX with `torch_dtype=torch.bfloat16` after the change:

![image](https://github.com/user-attachments/assets/a76fe1df-2410-4b92-9b01-38ef40133b2b)

Note: optimum doesn't yet support `torch_dtype=torch.bfloat16` for FLUX.

### Tickets:
 - *CVS-151254*

---------

Signed-off-by: Maxim Vafin <[email protected]>
mvafin authored and NishantPrabhuFujitsu committed Nov 26, 2024
1 parent f77795b commit f7d7f38
Showing 1 changed file with 29 additions and 51 deletions.
`src/bindings/python/src/openvino/frontend/pytorch/patch_model.py` (29 additions, 51 deletions):

```diff
@@ -12,17 +12,7 @@
 log = logging.getLogger(__name__)
 
 
-class no_jit_trace:
-    def __enter__(self):
-        self.state = torch._C._get_tracing_state()
-        torch._C._set_tracing_state(None)
-
-    def __exit__(self, *args):
-        torch._C._set_tracing_state(self.state)
-        self.state = None
-
-
-def patch_model(model, module_extensions, orig_forward_name, use_meta=False):
+def patch_model(model, module_extensions, orig_forward_name):
     def module_patcher(m, name):
         extension = None
         if m in module_extensions:
@@ -34,41 +24,27 @@ def module_patcher(m, name):
         if extension:
             log.debug("Patching module %s", m)
-            # The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
+            # The Trampoline class is instantiated for every module replacement, so we can use
+            # class members individually for each module.
 
             class Trampoline(torch.autograd.Function):
                 # required to be saved in class
                 target_extension = extension
                 original_module = m
-                stashed_args = tuple()
-                stashed_kwargs = {}
 
                 @staticmethod
                 @torch.jit.ignore
-                def forward(*args, **kwargs):
-                    with no_jit_trace():
-                        # `module` is going to be passed to a user-defined function `evaluate`
-                        # `module` is patched: forward function was replaced, and we are actually in this patched function right in this code
-                        # if we pass `module` as-is to the user code below, and it happens to call forward it will lead to infinite recursion or fail
-                        # so we need to temporary patch the module back to the original forward and then return it back again
-                        # stash the current forward to be able to return it back
-                        patched_forward = m.forward
-                        # set original forward for the module
-                        m.forward = getattr(m, orig_forward_name)
-                        # call user code
-                        results = extension.evaluate(m, *Trampoline.stashed_args,
-                                                     **Trampoline.stashed_kwargs)
-                        m.forward = patched_forward  # return patched forward back
-                        return results
+                def forward(ctx, *args, **kwargs):
+                    # Temporarily restore the original forward function of `module` to avoid
+                    # recursion issues in `evaluate`, then revert it back.
+                    patched_forward = m.forward
+                    # set original forward for the module
+                    m.forward = getattr(m, orig_forward_name)
+                    # call user code
+                    results = extension.evaluate(m, *args, **kwargs)
+                    m.forward = patched_forward  # return patched forward back
+                    return results
 
             def new_forward(*args, **kwargs):
-                # use meta device to store args, to save memory
-                if use_meta:
-                    d = torch.device("meta")
-                    Trampoline.stashed_args = tuple(a.to(d) for a in args)
-                    Trampoline.stashed_kwargs = dict((k, v.to(d)) for k, v in kwargs.items())
-                else:
-                    Trampoline.stashed_args = args
-                    Trampoline.stashed_kwargs = kwargs
                 return extension.convert(m, Trampoline.apply, *args, **kwargs)
 
             # make signature of new_forward same as of forward
@@ -109,36 +85,38 @@ def __make_16bit_traceable(model: torch.nn.Module):
     extensions = {
         torch.nn.Linear: ModuleExtension(
             torch.nn.Linear, "ov_ext::linear",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                           module.weight,
-                                                                          module.bias)),
+                                                                          module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32)),
         torch.nn.Embedding: ModuleExtension(
             torch.nn.Embedding, "ov_ext::embedding",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(module.weight,
                                                                           args[0],
                                                                           module.padding_idx,
                                                                           module.scale_grad_by_freq,
-                                                                          module.sparse)),
+                                                                          module.sparse),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[1].shape) + [module.embedding_dim], 0.5, dtype=torch.float32)),
     }
     try:
         from transformers.pytorch_utils import Conv1D
         extensions[Conv1D] = ModuleExtension(
             Conv1D, "ov_ext::conv1d",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                           module.weight,
-                                                                          module.bias))
-    except:
+                                                                          module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32))
+    except ImportError:
         pass
     patch_model(model, extensions,
-                "_openvino_module_extension_patch_orig_forward", use_meta=True)
+                "_openvino_module_extension_patch_orig_forward")
+    dtype_to_patch = [torch.float16, torch.bfloat16]
     for _, module in model.named_modules():
-        if module.__class__ not in extensions and (any(p.dtype in [torch.float16, torch.bfloat16] for p in module.parameters(False))
-                                                   or any(b.dtype in [torch.float16, torch.bfloat16] for b in module.buffers(False))):
+        if (module.__class__ not in extensions and
+                (any(p.dtype in dtype_to_patch for p in module.parameters(False))
+                 or any(b.dtype in dtype_to_patch for b in module.buffers(False)))):
             log.debug("Casting module %s to float32", module)
             module.float()
```
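
For readers less familiar with the `ModuleExtension` mechanism, the following stand-alone snippet (plain PyTorch, with a made-up module and shapes) shows what the `evaluate` stub for `torch.nn.Linear` above produces at trace time: a constant float32 tensor of the correct output shape, so the bfloat16 weights never take part in the traced computation, while `convert` later emits the real `ov_ext::linear` target op.

```python
import torch

linear = torch.nn.Linear(16, 8).to(torch.bfloat16)  # toy 16-bit module
x = torch.randn(2, 16, dtype=torch.bfloat16)        # toy input

# The `evaluate` stub from the Linear ModuleExtension above: a dummy float32
# tensor with the right output shape, produced without using the bf16 weights.
evaluate = lambda module, *args, **kwargs: torch.full(
    list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32)

out = evaluate(linear, x)
print(out.shape, out.dtype)  # torch.Size([2, 8]) torch.float32
```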
