Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix resolve_weight_file_from_hf_hub #10

Merged
merged 1 commit into from
Dec 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 103 additions & 111 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from paddle.utils.download import is_url as is_remote_url
from tqdm.auto import tqdm

from paddlenlp.utils.downloader import get_path_from_url_with_filelock, hf_file_exists
from paddlenlp.utils.downloader import get_path_from_url_with_filelock
from paddlenlp.utils.env import (
CONFIG_NAME,
LEGACY_CONFIG_NAME,
Expand Down Expand Up @@ -366,50 +366,28 @@ def resolve_weight_file_from_hf_hub(
subfolder (str, optional) An optional value corresponding to a folder inside the repo.
"""
is_sharded = False

if use_safetensors:
# SAFE WEIGHTS
if hf_file_exists(repo_id, SAFE_WEIGHTS_INDEX_NAME, subfolder=subfolder):
file_name = SAFE_WEIGHTS_INDEX_NAME
is_sharded = True
elif hf_file_exists(repo_id, SAFE_WEIGHTS_NAME, subfolder=subfolder):
file_name = SAFE_WEIGHTS_NAME
else:
raise EntryNotFoundError(
message=f"can not find the safetensors weight file from: https://huggingface.co/{repo_id}",
response=None,
)
file_name_list = [
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
]
else:
if convert_from_torch:
# TORCH WEIGHTS
if hf_file_exists(repo_id, PYTORCH_WEIGHTS_INDEX_NAME, subfolder=subfolder):
file_name = PYTORCH_WEIGHTS_INDEX_NAME
is_sharded = True
elif hf_file_exists(repo_id, PYTORCH_WEIGHTS_NAME, subfolder=subfolder):
file_name = PYTORCH_WEIGHTS_NAME
else:
raise EntryNotFoundError(
message=f"can not find the pytorch weight file from: https://huggingface.co/{repo_id}",
response=None,
)
else:
if hf_file_exists(repo_id, PADDLE_WEIGHTS_INDEX_NAME, subfolder=subfolder):
file_name = PADDLE_WEIGHTS_INDEX_NAME
is_sharded = True
elif hf_file_exists(repo_id, PADDLE_WEIGHTS_NAME, subfolder=subfolder):
file_name = PADDLE_WEIGHTS_NAME
else:
raise EntryNotFoundError(
message=f"can not find the paddle weight file from: https://huggingface.co/{repo_id}",
response=None,
)

file_name_list = [file_name]
file_name_list = [
PYTORCH_WEIGHTS_INDEX_NAME,
PADDLE_WEIGHTS_INDEX_NAME,
PYTORCH_WEIGHTS_NAME,
PADDLE_WEIGHTS_NAME,
SAFE_WEIGHTS_NAME, # (NOTE,lxl): 兼容极端情况
]
resolved_file = None
for fn in file_name_list:
resolved_file = cached_file_for_hf_hub(
repo_id, fn, cache_dir, subfolder, _raise_exceptions_for_missing_entries=False
)
if resolved_file is not None:
if resolved_file.endswith(".json"):
is_sharded = True
break

if resolved_file is None:
Expand Down Expand Up @@ -1458,6 +1436,30 @@ def _resolve_model_file_path(
is_sharded = False
sharded_metadata = None

# -1. when it's from HF
if from_hf_hub or convert_from_torch:
resolved_archive_file, is_sharded = resolve_weight_file_from_hf_hub(
pretrained_model_name_or_path,
cache_dir=cache_dir,
convert_from_torch=convert_from_torch,
subfolder=subfolder,
use_safetensors=use_safetensors,
)
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
resolved_sharded_files = None
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_sharded_files, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
from_aistudio=from_aistudio,
from_hf_hub=from_hf_hub,
cache_dir=cache_dir,
subfolder=subfolder,
)

return resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded

if pretrained_model_name_or_path is not None:
# the following code use a lot of os.path.join, hence setting subfolder to empty str if None
if subfolder is None:
Expand Down Expand Up @@ -1561,95 +1563,85 @@ def get_file_path(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, v
filename = pretrained_model_name_or_path
resolved_archive_file = get_path_from_url_with_filelock(pretrained_model_name_or_path)
else:
# -1. when it's from HF
if from_hf_hub:
resolved_archive_file, is_sharded = resolve_weight_file_from_hf_hub(
pretrained_model_name_or_path,

# set correct filename
if use_safetensors is not False:
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else:
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)

try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
convert_from_torch=convert_from_torch,
subfolder=subfolder,
use_safetensors=use_safetensors,
from_aistudio=from_aistudio,
_raise_exceptions_for_missing_entries=False,
)
else:
resolved_archive_file = None
if pretrained_model_name_or_path in cls.pretrained_init_configuration:
# fetch the weight url from the `pretrained_resource_files_map`
resource_file_url = cls.pretrained_resource_files_map["model_state"][
pretrained_model_name_or_path
]
resolved_archive_file = cached_file(
resource_file_url, _add_variant(PADDLE_WEIGHTS_NAME, variant), **cached_file_kwargs
)

if resolved_archive_file is None:
resolved_archive_file = cached_file(
pretrained_model_name_or_path, filename, **cached_file_kwargs
)

# set correct filename
if use_safetensors is not False:
filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
else:
# xxx.pdparams in pretrained_resource_files_map renamed model_state.pdparams
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)

try:
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
subfolder=subfolder,
from_aistudio=from_aistudio,
_raise_exceptions_for_missing_entries=False,
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
resolved_archive_file = None
if pretrained_model_name_or_path in cls.pretrained_init_configuration:
# fetch the weight url from the `pretrained_resource_files_map`
resource_file_url = cls.pretrained_resource_files_map["model_state"][
pretrained_model_name_or_path
]
resolved_archive_file = cached_file(
resource_file_url, _add_variant(PADDLE_WEIGHTS_NAME, variant), **cached_file_kwargs
)

if resolved_archive_file is None:
resolved_archive_file = cached_file(
pretrained_model_name_or_path, filename, **cached_file_kwargs
if resolved_archive_file is not None:
is_sharded = True
elif use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
)

else:
# xxx.pdparams in pretrained_resource_files_map renamed model_state.pdparams
# This repo has no safetensors file of any kind, we switch to PyTorch.
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)

# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
if resolved_archive_file is not None:
is_sharded = True
elif use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
)
else:
# This repo has no safetensors file of any kind, we switch to PyTorch.
filename = _add_variant(PADDLE_WEIGHTS_NAME, variant)
resolved_archive_file = cached_file(
pretrained_model_name_or_path, filename, **cached_file_kwargs
)
if resolved_archive_file is None and filename == _add_variant(PADDLE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
# raise ValueError(resolved_archive_file)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {_add_variant(PADDLE_WEIGHTS_NAME, variant)}."
pretrained_model_name_or_path, filename, **cached_file_kwargs
)
except Exception as e:
logger.info(e)
# For any other exception, we throw a generic error.
if resolved_archive_file is None and filename == _add_variant(PADDLE_WEIGHTS_NAME, variant):
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path,
_add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant),
**cached_file_kwargs,
)
# raise ValueError(resolved_archive_file)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://paddlenlp.bj.bcebos.com'"
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {_add_variant(PADDLE_WEIGHTS_NAME, variant)}."
)
except Exception as e:
logger.info(e)
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://paddlenlp.bj.bcebos.com'"
)

if is_local:
logger.info(f"Loading weights file {archive_file}")
Expand Down
Loading