[PT FE] Optimize memory usage of patch_model #27428
Merged
Details:
`no_jit_trace` was using extra memory to get and store the trace state, which contained the whole graph produced up to that point; this increased the memory consumption of tracing. For the FLUX model, for example, it used about 20 GB of extra memory.

`args` on the meta device didn't work without `no_jit_trace`. To work around this issue we now pass `args` directly to `forward` without saving them in `Trampoline`. This gives a cleaner flow for the arguments and reduces the memory needed to keep them around. However, it changes the behavior of `evaluate` in `ModuleExtension`: it now uses the args that were passed to `convert`, not the original args.
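A rough sketch of the pattern described above. The names `Trampoline` and `ModuleExtension.evaluate` mirror the description, but the classes below are illustrative stand-ins, not the actual OpenVINO implementation:

```python
class ModuleExtension:
    """Illustrative extension: by default just call the wrapped forward."""

    def evaluate(self, target, *args, **kwargs):
        # After the change, evaluate receives the args that were passed
        # in at conversion time, not a separately stored original copy.
        return target(*args, **kwargs)


class Trampoline:
    """Redirects a patched module's forward through an extension."""

    def __init__(self, target, extension):
        self.target = target
        self.extension = extension
        # Before the change, inputs were stashed on the instance
        # (e.g. self.stashed_args = args), which kept every input tensor
        # alive for the lifetime of the trampoline and inflated memory.

    def __call__(self, *args, **kwargs):
        # After the change, args flow straight through to evaluate, so no
        # extra references are held once the call returns.
        return self.extension.evaluate(self.target, *args, **kwargs)
```

Because the trampoline no longer holds references to the inputs, they can be freed as soon as the traced call finishes, which is where the memory saving comes from.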


optimum-cli for FLUX with `torch_dtype=torch.bfloat16` before the change: [screenshot]

optimum-cli for FLUX with `torch_dtype=torch.bfloat16` after the change: [screenshot]

Note: optimum doesn't yet support `torch_dtype=torch.bfloat16` for FLUX.

Tickets: