From 123a39f8741a40b339244feb9e1e6ffb4cd32b6f Mon Sep 17 00:00:00 2001 From: Sebastian Lee Date: Tue, 31 Jan 2023 15:14:06 +0100 Subject: [PATCH] Reuse tokenizer instead of loading new one. --- haystack/nodes/prompt/prompt_node.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py index e559b7e31b..10b787785e 100644 --- a/haystack/nodes/prompt/prompt_node.py +++ b/haystack/nodes/prompt/prompt_node.py @@ -9,7 +9,14 @@ import requests import torch -from transformers import pipeline, AutoModelForSeq2SeqLM, StoppingCriteria, StoppingCriteriaList, AutoTokenizer +from transformers import ( + pipeline, + AutoModelForSeq2SeqLM, + StoppingCriteria, + StoppingCriteriaList, + PreTrainedTokenizerFast, + PreTrainedTokenizer, +) from haystack import MultiLabel from haystack.environment import HAYSTACK_REMOTE_API_BACKOFF_SEC, HAYSTACK_REMOTE_API_MAX_RETRIES @@ -215,9 +222,8 @@ class StopWordsCriteria(StoppingCriteria): Stops text generation if any one of the stop words is generated. """ - def __init__(self, model_name_or_path: str, stop_words: List[str]): + def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stop_words: List[str]): super().__init__() - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.stop_words = tokenizer.encode(stop_words, add_special_tokens=False, return_tensors="pt") def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: @@ -245,7 +251,6 @@ def __init__( """ Creates an instance of HFLocalInvocationLayer used to invoke local Hugging Face models. - :param model_name_or_path: The name or path of the underlying model. :param max_length: The maximum length of the output text. :param use_auth_token: The token to use as HTTP bearer authorization for remote files. @@ -342,7 +347,7 @@ def invoke(self, *args, **kwargs): if key in kwargs } if stop_words: - sw = StopWordsCriteria(model_name_or_path=self.model_name_or_path, stop_words=stop_words) + sw = StopWordsCriteria(tokenizer=self.pipe.tokenizer, stop_words=stop_words) model_input_kwargs["stopping_criteria"] = StoppingCriteriaList([sw]) output = self.pipe(prompt, max_length=self.max_length, **model_input_kwargs) generated_texts = [o["generated_text"] for o in output if "generated_text" in o]