
[PT FE] Optimize memory usage of patch_model #27428

Merged
merged 3 commits into openvinotoolkit:master on Nov 11, 2024

Conversation

mvafin
Contributor

@mvafin mvafin commented Nov 6, 2024

Details:

  • no_jit_trace was using extra memory to get and store the trace state, which contained the entire graph produced up to that point, increasing the memory consumption of tracing. For example, for the FLUX model it used about 20 GB of extra memory.
  • Saving args on the meta device didn't work without no_jit_trace. To work around this, args are now passed directly to forward instead of being saved in Trampoline. This gives a cleaner flow for arguments and reduces the memory used to store them. However, it changes the behavior of evaluate in ModuleExtension: it now uses the args that were passed to convert, not the original args.
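The arg-passing change in the second bullet can be sketched as follows. This is a minimal illustration, not OpenVINO's actual Trampoline implementation; the `target` attribute and the pass-through `forward` are assumptions made for the sketch:

```python
import torch

class Trampoline(torch.nn.Module):
    """Sketch of the patching pattern described above.

    Before the change, the patched call stored its args on the module
    (e.g. as self.saved_args), which kept those tensors alive for the
    whole duration of tracing. After the change, args flow straight
    through forward(), so nothing extra is retained.
    """

    def __init__(self, target):
        super().__init__()
        self.target = target

    def forward(self, *args, **kwargs):
        # Args are consumed immediately instead of being stored on the
        # module, so large input tensors can be freed as usual.
        return self.target(*args, **kwargs)
```

Because the trampoline no longer caches its inputs, any later `evaluate` call has to rely on the args supplied at convert time, which is the behavior change the description notes.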

optimum-cli for FLUX with torch_dtype=torch.bfloat16 before change:
![image](https://github.com/user-attachments/assets/f070068a-e52e-4558-956e-95afa64d1dbc)
optimum-cli for FLUX with torch_dtype=torch.bfloat16 after change:
![image](https://github.com/user-attachments/assets/a76fe1df-2410-4b92-9b01-38ef40133b2b)

Note: optimum doesn't yet support torch_dtype=torch.bfloat16 for FLUX.
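The meta-device idea from the description can be illustrated with a short sketch. `make_meta_example` is a hypothetical helper written for this illustration, not part of the PR:

```python
import torch

def make_meta_example(shape, dtype=torch.bfloat16):
    """Allocate an example input on the "meta" device.

    Meta tensors carry only shape/dtype metadata and no real storage,
    so even a very large example input costs essentially no memory.
    This is why keeping real (non-meta) args alive during tracing was
    the expensive part.
    """
    return torch.empty(shape, dtype=dtype, device="meta")
```

A meta tensor created this way reports `is_meta == True` and cannot be read or written; it exists purely so that shape and dtype information can flow through the model.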

Tickets:
 - CVS-151254

@mvafin mvafin requested a review from a team as a code owner November 6, 2024 10:49
@github-actions github-actions bot added category: Python API OpenVINO Python bindings category: PyTorch FE OpenVINO PyTorch Frontend labels Nov 6, 2024
@mvafin mvafin added this pull request to the merge queue Nov 11, 2024
Merged via the queue into openvinotoolkit:master with commit f9118af Nov 11, 2024
166 checks passed
@mvafin mvafin deleted the mvafin/pt_fe/optimize_patch branch November 11, 2024 10:54
NishantPrabhuFujitsu pushed a commit to NishantPrabhuFujitsu/openvino that referenced this pull request Nov 26, 2024

### Tickets:
 - *CVS-151254*

---------

Signed-off-by: Maxim Vafin <[email protected]>
github-merge-queue bot pushed a commit that referenced this pull request Nov 27, 2024
### Details:
 - *Cherry-pick #27428 and #27413 in 24.6 branch*

### Tickets:
 - *ticket-id*

---------

Signed-off-by: Maxim Vafin <[email protected]>