Skip to content

Commit

Permalink
Added test case in test_wanda.py for custom config
Browse files Browse the repository at this point in the history
  • Loading branch information
agrawal-aka committed Nov 27, 2024
1 parent 366261e commit 63eeb7e
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/sparsity/test_wanda.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,39 @@ def test_two_layer_mlp_unstructured(self):

sparsifier.squash_mask()

def test_two_layer_mlp_unstructured_custom_config(self):
model = nn.Sequential(
nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
) # C_in by C_out
X1 = torch.randn(100, 128) # B1 by C_in
X2 = torch.randn(50, 128) # B2 by C_in

# Define custom config to sparsify only the first Linear layer for testing
config = [{"tensor_fqn": "0.weight"}]

sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=config)

model(X1)
model(X2)
sparsifier.step()

cnt = 0
for m in model.modules():
if isinstance(m, nn.Linear):
cnt += 1
sparsity_level = (m.weight == 0).float().mean()
if cnt == 1: # First Linear layer should have 50% sparsity
assert (
sparsity_level == 0.5
), f"sparsity for linear layer {cnt} should be 0.5"
else: # Other layers should not be sparsified
assert (
sparsity_level != 0.5
), f"sparsity for linear layer {cnt} should not be 0.5"

sparsifier.squash_mask()


if __name__ == "__main__":
unittest.main()

0 comments on commit 63eeb7e

Please sign in to comment.