diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 5c30762099..06113bcd68 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -36,6 +36,9 @@
 if not TORCH_VERSION_AT_LEAST_2_5:
     unwrap_tensor_subclass(model)
 
+# temporary workaround to recover the perf with quantized model under torch.compile
+torch.backends.mha.set_fastpath_enabled(False)
+
 model = torch.compile(model, mode='max-autotune')
 
 # Must run with no_grad when optimizing for inference