Skip to content

Commit

Permalink
Partial local tokenizer load (#9807)
Browse files Browse the repository at this point in the history
* Allow partial loading of a cached tokenizer

* Downgrade the log message for unresolved cached files from warning to info

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Sylvain Gugger <[email protected]>

* Re-raise the FileNotFoundError when local_files_only is False

Co-authored-by: Sylvain Gugger <[email protected]>
  • Loading branch information
LysandreJik and sgugger authored Jan 28, 2021
1 parent 25fcb5c commit 6cb0a6f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ def get_from_cache(
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
raise ValueError(
raise FileNotFoundError(
"Cannot find the requested files in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
Expand Down
35 changes: 26 additions & 9 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,27 +1730,41 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],

# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
try:
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
try:
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except FileNotFoundError as error:
if local_files_only:
unresolved_files.append(file_id)
else:
raise error

except requests.exceptions.HTTPError as err:
if "404 Client Error" in str(err):
logger.debug(err)
resolved_vocab_files[file_id] = None
else:
raise err

if len(unresolved_files) > 0:
logger.info(
f"Can't load following files from cache: {unresolved_files} and cannot check if these "
"files are necessary for the tokenizer to operate."
)

if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
msg = (
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
Expand All @@ -1760,6 +1774,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
raise EnvironmentError(msg)

for file_id, file_path in vocab_files.items():
if file_id not in resolved_vocab_files:
continue

if file_path == resolved_vocab_files[file_id]:
logger.info("loading file {}".format(file_path))
else:
Expand Down

0 comments on commit 6cb0a6f

Please sign in to comment.