Skip to content

Commit

Permalink
Handles the case when no config is specified
Browse files Browse the repository at this point in the history
  • Loading branch information
agrawal-aka committed Nov 26, 2024
1 parent 5b917ec commit 366261e
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions torchao/sparsity/wanda.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,28 @@ def __init__(

# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting errors.
def prepare(self, model: nn.Module, config: List[Dict]) -> None:

for module_config in config:
tensor_fqn = module_config.get("tensor_fqn", None)
if tensor_fqn is None:
raise ValueError("Each config must contain a 'tensor_fqn'.")

# Extract module information from tensor_fqn
info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
module = info_from_tensor_fqn["module"]

# Apply the qconfig directly to the module if it exists
# activation: use PerChannelNormObserver
# use no-op placeholder weight observer
if module is not None:
module.qconfig = QConfig(
activation=PerChannelNormObserver, weight=default_placeholder_observer
) # type: ignore[assignment]
# activation: use PerChannelNormObserver
# use no-op placeholder weight observer
if config is None:
# If no config is provided, apply the qconfig to the entire model
model.qconfig = QConfig(
activation=PerChannelNormObserver, weight=default_placeholder_observer
) # type: ignore[assignment]
else:
for module_config in config:
tensor_fqn = module_config.get("tensor_fqn", None)
if tensor_fqn is None:
raise ValueError("Each config must contain a 'tensor_fqn'.")

# Extract module information from tensor_fqn
info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
module = info_from_tensor_fqn["module"]

# Apply the qconfig directly to the module if it exists
if module is not None:
module.qconfig = QConfig(
activation=PerChannelNormObserver, weight=default_placeholder_observer
) # type: ignore[assignment]
torch.ao.quantization.prepare(model, inplace=True)

# call superclass prepare
Expand Down

0 comments on commit 366261e

Please sign in to comment.