
Commit 5b917ec
Fixes observer attachment to model based on config for wanda sparsifier
agrawal-aka committed Nov 26, 2024
1 parent 43966b6 commit 5b917ec
Showing 1 changed file with 18 additions and 6 deletions.
torchao/sparsity/wanda.py (18 additions, 6 deletions)
@@ -3,7 +3,7 @@
 
 import torch
 from torch import nn
-from torch.ao.pruning import BaseSparsifier
+from torch.ao.pruning import BaseSparsifier, get_arg_info_from_tensor_fqn
 from torch.ao.quantization import QConfig, default_placeholder_observer
 from torch.ao.quantization.quantize import _remove_qconfig

@@ -45,11 +45,23 @@ def __init__(
 
     # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
     def prepare(self, model: nn.Module, config: List[Dict]) -> None:
-        # activation: use PerChannelNormObserver
-        # use no-op placeholder weight observer
-        model.qconfig = QConfig(
-            activation=PerChannelNormObserver, weight=default_placeholder_observer
-        )  # type: ignore[assignment]
+        for module_config in config:
+            tensor_fqn = module_config.get("tensor_fqn", None)
+            if tensor_fqn is None:
+                raise ValueError("Each config must contain a 'tensor_fqn'.")
+
+            # Extract module information from tensor_fqn
+            info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
+            module = info_from_tensor_fqn["module"]
+
+            # Apply the qconfig directly to the module if it exists
+            # activation: use PerChannelNormObserver
+            # use no-op placeholder weight observer
+            if module is not None:
+                module.qconfig = QConfig(
+                    activation=PerChannelNormObserver, weight=default_placeholder_observer
+                )  # type: ignore[assignment]
         torch.ao.quantization.prepare(model, inplace=True)

         # call superclass prepare
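For context, here is a minimal, hypothetical usage sketch of the changed prepare path. It is not part of the commit: it assumes WandaSparsifier is exported from torchao.sparsity, that it accepts a sparsity_level keyword, and that config entries follow the {"tensor_fqn": "<module_fqn>.weight"} convention from torch.ao.pruning. The point of the fix is that observers now land only on the modules named in config, not on every module in the model.

import torch
from torch import nn

from torchao.sparsity import WandaSparsifier  # assumed public export

# Two linear layers; only the first is configured for sparsification.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

sparsifier = WandaSparsifier(sparsity_level=0.5)  # sparsity_level is an assumption

# After this commit, only model[0] carries the QConfig, so
# torch.ao.quantization.prepare attaches a PerChannelNormObserver to it
# alone; before the fix, the qconfig was set on the model root and
# observers were attached throughout the model.
sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])

# Feed calibration data so the activation observer accumulates norms.
with torch.no_grad():
    model(torch.randn(16, 8))

sparsifier.step()         # compute Wanda masks from weight magnitudes and activation norms
sparsifier.squash_mask()  # fold the computed masks into the weights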
