Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] Switch hybrid_forward to forward in test_fc_int8_fp32_outputs (
Browse files Browse the repository at this point in the history
#20398)

* Switch hybrid_forward to forward in test_fc_int8_fp32_outputs

* Add use_np decorator
  • Loading branch information
bgawrych authored Jun 29, 2021
1 parent 835e250 commit 38e1416
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/python/mkl/subgraphs/test_fc_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,10 @@ def infer_shape(self, x, *args):
rtol=1e-2, atol=1e-2, etol=0.01)


@mx.util.use_np
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
def test_fc_int8_and_fp32_outputs(data_shape):
@pytest.mark.parametrize('flatten', [True, False])
def test_fc_int8_and_fp32_outputs(data_shape, flatten):

# /---> Quantizable op
# Input ---> FC -|
Expand All @@ -185,15 +187,15 @@ def test_fc_int8_and_fp32_outputs(data_shape):
class MultiOutputFC(nn.HybridBlock):
def __init__(self, **kwargs):
super(MultiOutputFC, self).__init__(**kwargs)
self.dense0 = nn.Dense(64)
self.dense1 = nn.Dense(64)
self.dense0 = nn.Dense(64, flatten=flatten)
self.dense1 = nn.Dense(64, flatten=flatten)

def hybrid_forward(self, F, x):
def forward(self, x):
x = self.dense0(x)
y = self.dense1(x) # quantizable
z = F.softmax(x) # non quantizable
y = self.dense1(x) # quantizable
z = mx.npx.softmax(x) # non quantizable
return y + z

attrs = {'fc': {}}
net = MultiOutputFC()
check_fusion(net, data_shape, attrs, check_quantization=True)
check_fusion(net, data_shape, attrs, check_quantization=flatten)

0 comments on commit 38e1416

Please sign in to comment.