Bugfix: Offload of GGML-quantized model in torch.inference_mode()
#7525
Summary
This PR contains a bugfix for an edge case with model unloading (from VRAM to RAM). Thanks to @JPPhoto for finding it.
The bug was triggered under the following conditions:
- A GGML-quantized model is offloaded from VRAM to RAM, which saves its tensors to a state_dict.
- The offload happens inside a torch.inference_mode() context manager.

Offloading the model inside the torch.inference_mode() context manager results in an error when the GGML tensors are saved to a state_dict (roughly the pattern sketched below).
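A rough sketch of the triggering pattern, with placeholder helpers standing in for Invoke's model cache (these names are illustrative, not the actual APIs):

```python
import torch

def load_ggml_model_to_vram():
    """Placeholder: load a GGML-quantized model onto the GPU."""
    ...

def offload_model_to_ram(model):
    """Placeholder: VRAM -> RAM offload, which saves the GGML tensors to a state_dict."""
    ...

# The entire load -> generate -> offload sequence ran inside inference_mode,
# so every tensor created during the load was an "inference tensor".
with torch.inference_mode():
    model = load_ggml_model_to_vram()
    # ... run the generation ...
    offload_model_to_ram(model)  # the error surfaced here, while saving the state_dict
```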
Explanation
From the torch.inference_mode() docs: code run under this mode gets better performance because view tracking and version counter bumps are disabled for tensors created in it. Disabling version counter bumps results in the aforementioned error when saving GGMLTensors to a state_dict.

This incompatibility between GGMLTensors and torch.inference_mode() is likely caused by the custom tensor type implementation. There may very well be a way to get these to cooperate, but for now it is much simpler to remove the torch.inference_mode() contexts.
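As a small illustration of the underlying mechanism (generic PyTorch behavior, not the Invoke code path): tensors created under inference mode have no version counter, so mutating them outside the context fails.

```python
import torch

with torch.inference_mode():
    t = torch.ones(3)  # created as an "inference tensor"

print(t.is_inference())  # True

# In-place updates outside inference mode need a version counter bump,
# which inference tensors do not support, so PyTorch raises a RuntimeError.
try:
    t.add_(1)
except RuntimeError as err:
    print(err)
```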
Note that there are several other uses of torch.inference_mode() in the Invoke codebase, but they are all tight wrappers around the inference forward pass and do not contain the model load/unload process.
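For contrast, a simplified sketch of the two patterns (a toy model, not the Invoke code): a tight inference_mode wrapper around the forward pass is fine, while code that also copies or offloads weights can use torch.no_grad() instead.

```python
import torch

model = torch.nn.Linear(8, 8)   # stand-in for a loaded model
latents = torch.randn(1, 8)

# Fine: inference_mode wraps only the forward pass, so no tensors that need
# to be mutated or saved later are created inside the context.
with torch.inference_mode():
    output = model(latents)

# For code paths that also copy/offload weights, torch.no_grad() still skips
# gradient tracking but does not turn newly created tensors into inference
# tensors, so saving them to a state_dict keeps working.
with torch.no_grad():
    cpu_state = {k: v.to("cpu") for k, v in model.state_dict().items()}
```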
Related Issues / Discussions
Original discussion: https://discord.com/channels/1020123559063990373/1149506274971631688/1326180753159094303
QA Instructions
Find a sequence of operations that triggers the condition. For me, this was:
Tests:
- Changed torch.inference_mode() to torch.no_grad() and compared generation time before and after the change. Before: 50.354s, After: 51.536s.
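A minimal sketch of this kind of timing comparison, with a toy model standing in for the actual generation pipeline (the reported numbers come from a full Invoke generation, not from this script):

```python
import time
import torch

# Toy stand-in for the real model.
model = torch.nn.Sequential(
    torch.nn.Linear(2048, 2048),
    torch.nn.GELU(),
    torch.nn.Linear(2048, 2048),
)
x = torch.randn(64, 2048)

def timed(ctx) -> float:
    """Time repeated forward passes under the given context manager."""
    start = time.perf_counter()
    with ctx:
        for _ in range(100):
            model(x)
    return time.perf_counter() - start

print(f"torch.no_grad():        {timed(torch.no_grad()):.3f}s")
print(f"torch.inference_mode(): {timed(torch.inference_mode()):.3f}s")
```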
Checklist
- What's New copy (if doing a release after this PR)