diff --git a/neuralmagic/benchmarks/common.py b/neuralmagic/benchmarks/common.py index 459ee47eddadc..fc78bcc10a357 100644 --- a/neuralmagic/benchmarks/common.py +++ b/neuralmagic/benchmarks/common.py @@ -1,5 +1,6 @@ import itertools import json +import os from argparse import Namespace from pathlib import Path from typing import Iterable, NamedTuple @@ -14,9 +15,13 @@ def download_model(model: str) -> None: """ - Downloads a hugging face model to cache - """ - download_weights_from_hf(model) + Downloads a hugging face model to cache + """ + cache_dir = os.getenv("HF_HOME") + allow_patterns = ["*.safetensors", "*.bin"] + download_weights_from_hf(model, + cache_dir=cache_dir, + allow_patterns=allow_patterns) get_tokenizer(model)