
torch.compile() on quantized model: No attribute "meta" #148072

Closed
Whadup opened this issue Feb 27, 2025 · 3 comments
Assignees
Labels
needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user oncall: cpu inductor CPU Inductor issues for Intel team to triage oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Whadup
Contributor

Whadup commented Feb 27, 2025

During evaluation of a compiled and quantized model, I get the error "'float' object has no attribute 'meta'" on the following line:

x = match.kwargs["x"].meta["val"]

I propose the following change:

-       x = match.kwargs["x"].meta["val"]
-       weight = match.kwargs["weight"].meta["val"]
-       scales = match.kwargs["scales"].meta["val"]
+       x = match.kwargs["x"]
+       if hasattr(x, 'meta'):
+           x = x.meta["val"]
+       weight = match.kwargs["weight"]
+       if hasattr(weight, 'meta'):
+           weight = weight.meta["val"]
+       scales = match.kwargs["scales"]
+       if hasattr(scales, 'meta'):
+           scales = scales.meta["val"]
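The idea behind the guard can be shown in isolation. A minimal sketch (names here are illustrative, not the actual Inductor pattern-matcher code): kwargs captured by a pattern match may be torch.fx Nodes carrying a recorded value in `.meta["val"]`, or plain Python scalars such as a constant float scale, which have no `.meta` attribute at all.

```python
class FakeNode:
    """Stand-in for a torch.fx.Node that carries example metadata."""
    def __init__(self, val):
        self.meta = {"val": val}

def unwrap(arg):
    # Use the node's recorded value when available, otherwise fall
    # back to the raw argument (e.g. a constant float scale).
    return arg.meta["val"] if hasattr(arg, "meta") else arg

kwargs = {"x": FakeNode("fake_tensor_value"), "scales": 0.5}

print(unwrap(kwargs["x"]))       # "fake_tensor_value", taken from meta["val"]
print(unwrap(kwargs["scales"]))  # 0.5, a plain float passes through unchanged
```

Without the `hasattr` check, `unwrap(0.5)` would raise exactly the `AttributeError: 'float' object has no attribute 'meta'` reported above.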

cc @chauhang @penguinwu

Here is an example to reproduce the behavior on a machine with an A100 GPU.
Requirements: torch, transformers, peft

from transformers import AutoModelForCausalLM
import peft
import torch

model = AutoModelForCausalLM.from_pretrained(
    "casperhansen/llama-3-8b-instruct-awq",
    device_map="auto",
)
model = peft.get_peft_model(
    model,
    peft.LoraConfig(
        task_type="CAUSAL_LM"
    )
)

torch._dynamo.config.cache_size_limit = 1024
for i, layer in enumerate(model.base_model.model.model.layers):
    model.base_model.model.model.layers[i] = torch.compile(layer)

with torch.amp.autocast("cuda"):
    model(
        input_ids = torch.tensor([[0, 1, 2]]).cuda(),
        attention_mask = torch.tensor([[1, 1, 1]]).cuda()
    )

Output:

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'float' object has no attribute 'meta'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
@desertfire desertfire added oncall: cpu inductor CPU Inductor issues for Intel team to triage triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user labels Feb 28, 2025
@desertfire
Contributor

Can you provide instructions for a reproduction?

@Whadup
Contributor Author

Whadup commented Mar 4, 2025

I provided an example in the edit.
I tried to extract the essence of my original training script, but it is still rather specific: without LoRA (peft) it works; without mixed precision (amp.autocast) it works; without the quantized base model (AWQ) it works. Only the combination triggers the error.

Using the fix I proposed in the original message, it works.

@leslie-fang-intel
Collaborator

Hi @Whadup, it seems you already have a PR. Assigning this issue to you. If not, please let me know. Thanks
