
Commit

Remove-warns (#26483)
* fix stripping

* remove some warnings and update some warnings

* revert changes for other PR
ArthurZucker authored Oct 2, 2023
1 parent 1b8decb commit e4dad4f
Showing 4 changed files with 9 additions and 28 deletions.
src/transformers/models/llama/tokenization_llama.py (4 additions, 4 deletions)

@@ -125,7 +125,7 @@ def __init__(
 
         if legacy is None:
             logger.warning_once(
-                f"You are using the default legacy behaviour of the {self.__class__}. If you see this, DO NOT PANIC! This is"
+                f"You are using the default legacy behaviour of the {self.__class__}. This is"
                 " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
                 " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
                 " means, and thouroughly read the reason why this was added as explained in"
@@ -138,7 +138,7 @@ def __init__(
         self.add_bos_token = add_bos_token
         self.add_eos_token = add_eos_token
         self.use_default_system_prompt = use_default_system_prompt
-        self.sp_model = self.get_spm_processor()
+        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
 
         super().__init__(
             bos_token=bos_token,
@@ -160,9 +160,9 @@ def unk_token_length(self):
         return len(self.sp_model.encode(str(self.unk_token)))
 
     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
-    def get_spm_processor(self):
+    def get_spm_processor(self, from_slow=False):
         tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        if self.legacy:  # no dependency on protobuf
+        if self.legacy or from_slow:  # no dependency on protobuf
             tokenizer.Load(self.vocab_file)
             return tokenizer
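Note (not part of the diff): a minimal usage sketch of the kwargs this file change touches, assuming a local checkpoint; the path below is a placeholder. With this change, `legacy=False` opts into the new tokenization behaviour (so the warning above, which only fires when `legacy` is None, is never emitted), and a `from_slow` flag passed through the kwargs is popped in `__init__` and forwarded to `get_spm_processor`, which then loads the SentencePiece file directly.

from transformers import LlamaTokenizer

# `legacy=False` selects the new behaviour explicitly; no legacy warning is logged.
tok = LlamaTokenizer.from_pretrained("path/to/llama-checkpoint", legacy=False)  # placeholder path

# `from_slow=True` is popped from the kwargs and handed to get_spm_processor,
# which then calls tokenizer.Load(self.vocab_file) without the protobuf round-trip.
tok_slow_load = LlamaTokenizer.from_pretrained("path/to/llama-checkpoint", from_slow=True)  # placeholder path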
src/transformers/models/t5/tokenization_t5.py (5 additions, 4 deletions)

@@ -186,7 +186,7 @@ def __init__(
 
         if legacy is None:
             logger.warning_once(
-                f"You are using the default legacy behaviour of the {self.__class__}. If you see this, DO NOT PANIC! This is"
+                f"You are using the default legacy behaviour of the {self.__class__}. This is"
                 " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
                 " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
                 " means, and thouroughly read the reason why this was added as explained in"
@@ -195,7 +195,7 @@ def __init__(
             legacy = True
 
         self.legacy = legacy
-        self.sp_model = self.get_spm_processor()
+        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
         self.vocab_file = vocab_file
         self._extra_ids = extra_ids
 
@@ -210,9 +210,10 @@ def __init__(
             **kwargs,
         )
 
-    def get_spm_processor(self):
+    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
+    def get_spm_processor(self, from_slow=False):
         tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
-        if self.legacy:  # no dependency on protobuf
+        if self.legacy or from_slow:  # no dependency on protobuf
             tokenizer.Load(self.vocab_file)
             return tokenizer
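For readers who want the branch in isolation, here is a hedged, standalone sketch of the condition `get_spm_processor` now uses (it is not the library code itself); "spiece.model" is a placeholder vocab file path and the non-legacy branch is deliberately omitted because the hunk above does not show it.

import sentencepiece as spm

def load_sp_model(vocab_file, legacy=True, from_slow=False, sp_model_kwargs=None):
    # Mirrors the updated condition: the legacy path, or an explicit from_slow
    # request, loads the SentencePiece model file as-is (no protobuf dependency).
    tokenizer = spm.SentencePieceProcessor(**(sp_model_kwargs or {}))
    if legacy or from_slow:
        tokenizer.Load(vocab_file)
        return tokenizer
    raise NotImplementedError("non-legacy path omitted in this sketch")

sp = load_sp_model("spiece.model")  # placeholder path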
src/transformers/tokenization_utils.py (0 additions, 5 deletions)

@@ -979,11 +979,6 @@ def _decode(
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
 
-        if spaces_between_special_tokens:
-            logger.warning_once(
-                "spaces_between_special_tokens is deprecated and will be removed in transformers v5. It was adding spaces between `added_tokens`, not special tokens, "
-                "and does not exist in our fast implementation. Future tokenizers will handle the decoding process on a per-model rule."
-            )
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
         legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
             token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
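For context, a hedged sketch of the call site affected here: `spaces_between_special_tokens` is a keyword argument of the slow tokenizers' `decode`, and after this commit passing it no longer logs the deprecation warning; the diff does not change what `decode` returns. "t5-small" is used as an example checkpoint.

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
ids = tok("Hello world").input_ids

# Controls whether a space is inserted between decoded added tokens; with this
# commit the call is silent (the warning_once above was removed).
text = tok.decode(ids, spaces_between_special_tokens=True)
print(text)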
src/transformers/tokenization_utils_base.py (0 additions, 15 deletions)

@@ -2204,11 +2204,6 @@ def _from_pretrained(
                         f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary."
                     )
         else:
-            logger.warning_once(
-                "Loading the tokenizer from the `special_tokens_map.json` and the `added_tokens.json` will be removed in `transformers 5`, "
-                " it is kept for forward compatibility, but it is recommended to update your `tokenizer_config.json` by uploading it again."
-                " You will see the new `added_tokens_decoder` attribute that will store the relevant information."
-            )
             # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
             if special_tokens_map_file is not None:
                 with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
@@ -2277,16 +2272,6 @@ def _from_pretrained(
         # uses the information stored in `added_tokens_decoder`. Checks after addition that we have the same ids
         if init_kwargs.get("slow_to_fast", False):
             tokenizer.add_tokens([token for _, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])])
-            warnings = ""
-            for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0]):
-                if tokenizer.convert_tokens_to_ids(str(token)) != index:
-                    warnings += f"\texpected id: {tokenizer.convert_tokens_to_ids(str(token))}, found: {index}, token: `{token}`,\n"
-            if len(warnings) > 1:
-                logger.warn(
-                    f"You are converting a {slow_tokenizer.__class__.__name__} to a {cls.__name__}, but"
-                    f" wrong indexes were founds when adding the `added_tokens` from the `slow` tokenizer to the `fast`. "
-                    f" The following tokens had unexpected id :\n{warnings}. You should try using `from_slow`."
-                )
         # finally we add all the special_tokens to make sure eveything is initialized
         tokenizer.add_tokens(tokenizer.all_special_tokens_extended, special_tokens=True)
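Since the id-mismatch diagnostic in the `slow_to_fast` branch is removed, anyone who still wants that check can run it outside the library. A hedged sketch follows; the checkpoint path is a placeholder and `get_added_vocab` is the public accessor used here in place of the internal `added_tokens_decoder`.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/checkpoint", from_slow=True)  # placeholder path

# Re-creates the removed check: compare each added token's stored index with
# the id the tokenizer actually assigns to it.
for token, index in sorted(tok.get_added_vocab().items(), key=lambda x: x[1]):
    converted = tok.convert_tokens_to_ids(token)
    if converted != index:
        print(f"unexpected id for `{token}`: expected {index}, found {converted}")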
