diff --git a/scripts/hf_eval.py b/scripts/hf_eval.py
index 48fb7fdf0d..ce2a106a37 100644
--- a/scripts/hf_eval.py
+++ b/scripts/hf_eval.py
@@ -45,7 +45,7 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
     if compile:
-        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        model = torch.compile(model, fullgraph=True)
 
     if quantization == "int8dq":
         change_linear_weights_to_int8_dqtensors(model)
@@ -56,6 +56,13 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
         change_linear_weights_to_int4_woqtensors(model.to(device=device))
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
+    elif quantization == "fp8":
+        from float8_experimental.inference import quantize_to_float8, ActivationCasting, QuantConfig, ScalingGranularity
+        model.to(device)
+        quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC), scaling_granularity=ScalingGranularity.TensorWise)
+
+    pass # no quantization applied, model is already on device and precision dtype.
+
     with torch.no_grad():
         result = evaluate(
             HFLM(
@@ -78,7 +85,7 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp8", "None"], help='Which quantization technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
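
For reference, here is a minimal, self-contained sketch (not part of the patch) of the float8_experimental inference API that the new `fp8` branch calls, applied to a toy module instead of a Hugging Face model. The toy layer sizes, batch size, and device are illustrative assumptions; it requires float8_experimental to be installed and a CUDA GPU with float8 support.

```python
# Standalone sketch of the fp8 path added above, using the same
# float8_experimental.inference entry points as scripts/hf_eval.py.
# The toy model, shapes, and device below are illustrative only.
import torch
import torch.nn as nn
from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    ScalingGranularity,
    quantize_to_float8,
)

# Toy stand-in for the HF model: a small stack of Linear layers in bf16 on CUDA.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.to(device="cuda", dtype=torch.bfloat16)

# Quantize for float8 inference with dynamic activation casting and
# tensor-wise scaling, mirroring the call in the diff.
quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    scaling_granularity=ScalingGranularity.TensorWise,
)

with torch.no_grad():
    out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
```

As in the patch, the model is moved to the target device first, quantization is applied in place, and inference then runs under `torch.no_grad()`.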