-
Notifications
You must be signed in to change notification settings - Fork 120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
concurrency without model cloning #573
base: main
Are you sure you want to change the base?
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
kwargs["infer_context"] = infer_context | ||
return super().generate(*args, **kwargs) | ||
|
||
def __call__(self, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dtrawins, please explain why __call__
should be different from forward
behavior. __call__
will eventually call forward
without any added semantics on top of that. So why cannot we move this code to forward?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@slyalin that is indeed not required. potentially it could be used to pass the infer_request context. Assuming we create new request in the forward method if generate method didn't pass the context, that is not required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
THanks a lot for your work @dtrawins
def compile(self): | ||
if self.request is None: | ||
if self.compiled_model is None: | ||
super().compile() | ||
self.request = self.request.create_infer_request() | ||
self.compiled_model = self.request |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we don't need to call self.request.create_infer_request()
then there is not need to override this method, I this we should we remove it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also if we want to rename request
to compiled_model
I think we should do it for all OVModels + add a warning stating that the request attribute will be deprecated in the future, it could make sense to do it in an other PR instead
if self.stateful: | ||
# Need a marker to differentiate the first generate iteration from the others in | ||
# the first condition at the function beginning above. | ||
# It should be something that is not None and it should be True when converted to Boolean. | ||
past_key_values = ((),) | ||
past_key_values = ((inputs["beam_idx"]), infer_request) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not related to past_key_values
so I don't think we should update past_key_values
here, the resulting output will not be what it's expected for example :
output = model(**tokens)
pkv = output.past_key_values
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is a special case for using stateful models. Such models are not using past_key_values because they preserve those information in the inference state instead. That field is used here to pass the beam_idx used for beam search algorithm and pass the inference execution context between generation cycles.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My point is that it's not related to past_key_values
so we shouldn't update this variable with beam_idx / inference execution context
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@slyalin can you add your comments here? The idea was to reused this variable for stateful models because they don't use it at all. That was the only method we found that could be used to pass the beam_idx and execution context (which includes the state data) without changing the model API. The other alternative was with using model.clone() method for each thread which would also using a separate execution context without duplicating memory consumption #564. Would cloning be better method to support concurrency in the execution? Is there some other option we are not aware of? I guess it is a bit unique situation with the stateful models in openvino so probably it is not handled in transformers lib.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My point is that it's not related to past_key_values
Definitely it is related to past_key_values
even more than the old ((),)
value. beam_idx
together with infer_request
are used to track past_key_values
for a particular sequence. Literally, infer_request
has a model state that consists of past_key_values
tensors, and beam_idx
allows indirect rows reordering in that state in case of beam search. This PR just makes it more explicit than it was before and moves these attributes from the model class instance to each sequence, which allows having multiple sequences for a single model class instance.
@echarlaix, do you have a better alternative to pass these values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@echarlaix if we create new modelOutput data class and it is returned by the Forward method, how it could be passed back to the Forward method in the next cycle?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you try something like :
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput
@dataclass
class CausalLMOutputWithPast(ModelOutput):
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
beam_idx: Optional[int] = None
inference_session = None
and then overwritte _update_model_kwargs_for_generation
https://github.com/huggingface/transformers/blob/45c065109074d60c587d3e562f16531d02a422f6/src/transformers/generation/utils.py#L630 by adding somethign like :
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: dict[str],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> dict[str]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
standardize_cache_format=standardize_cache_format,
)
if "beam_idx" in outputs:
model_kwargs["beam_idx"] = outputs["beam_idx"]
return model_kwargs
(same for inference_session)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me know if you need help on this @dtrawins
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@echarlaix @eaidova @slyalin Could you have a look if the latest version is passing the context fine now?
I'm not reusing past_key_values for stateful models with th generation context. There are additional fields in the forward
output beam_idx and infer_request. Now only 9 tests is left to fix but seams unrelated to concurrency. Probably rebase from main is needed.
Anyway can one comment if beam_idx would be populated correctly. It is not defined now in reorder_caches for stateful models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What i confirmed is beam_idx was not passed correctly. The same, initial beam_idx was circulating for the whole pipeline resulting in incorrect accuracy with beam search. Somehow it was not detected by functional tests.
Anyway my proposal is to pass the beam_idx content from reorder_caches method inside past_key_value. I tested it gives correct results and the code is in my opinion clean. The forward method returns empty past_key_values as expected for stateful models. In case someone would like to manage the pipeline for stateless models outside of transformers using just forward method, it would be still possible - beam_idx should be passed inside past_key_value and inference_request context via model_args. Anyway that is probably unlikely use case scenario. Would it be acceptable?
@@ -86,6 +86,7 @@ def __init__( | |||
|
|||
self.model = model | |||
self.request = None | |||
self.compiled_model = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure why we need a new attribute here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is needed to create new infer_request in the context of generate method for each concurrent thread. So far we had in the model class request attribute which was pointing to a static infer_request and can not be used to allocate new request. Generally there is a bit confusing setup when the request attribute is set to the compiled_model object in the based class but latest it is overwritten to become the infer_request. Eventually the recommendation would be to switch to using compiled_model attribute instead and create infer_requests dynamically. It was proposed to make this switch in a separate PR.
@@ -343,8 +343,10 @@ def normalized_config(self): | |||
def compile(self): | |||
if self.request is None: | |||
super().compile() | |||
self.compiled_model = self.request |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it could make sense to also set self.compiled_model
to None (along with self.request
) when the model is statically reshaped or moved to an other device https://github.com/huggingface/optimum-intel/blob/2a397e37dd606cdeafce6b356f5e7f869630ea1b/optimum/intel/openvino/modeling_base.py#L442C9-L442C21
an option could be to add a clear_requests
method as done for seq2seq models
Currently it should work anyway as self.compiled_model
will be correctly updated after calling .compile()
(as self.request is set to None after each of these steps)
if self.stateful: | ||
# Need a marker to differentiate the first generate iteration from the others in | ||
# the first condition at the function beginning above. | ||
# It should be something that is not None and it should be True when converted to Boolean. | ||
past_key_values = ((),) | ||
past_key_values = ((inputs["beam_idx"]), infer_request) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@echarlaix, do you have a better alternative to pass these values?
why not introducing a class inheriting from ModelOutput
like CausalLMOutputWithPast
https://github.com/huggingface/transformers/blob/536ea2aca234fb48c5c69769431d643b0d93b233/src/transformers/modeling_outputs.py#L678 with dedicated beam_idx
/ inference_request
arguments ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks for iterating on this @dtrawins !
@@ -661,8 +704,7 @@ def _reorder_cache( | |||
batch_size = beam_idx.shape[0] | |||
indices = np.array(range(batch_size * self.config.num_attention_heads)) | |||
indices = indices.reshape([batch_size, self.config.num_attention_heads]) | |||
self.next_beam_idx = np.take(indices, beam_idx, 0).flatten() | |||
return past_key_values | |||
return ((np.take(indices, beam_idx, 0).flatten()), past_key_values[1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldn't it be :
return ((np.take(indices, beam_idx, 0).flatten()), past_key_values[1]) | |
return past_key_values |
@@ -322,6 +340,7 @@ def normalized_config(self): | |||
def compile(self): | |||
if self.request is None: | |||
super().compile() | |||
self.compiled_model = self.request | |||
self.request = self.request.create_infer_request() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not remove this
self.request = self.request.create_infer_request() |
and use self.request instead of self.compiled_model
? (self.request
doesn't seem to be used anywhere)
tests/openvino/test_modeling.py
Outdated
@parameterized.expand(SUPPORTED_ARCHITECTURES) | ||
def test_compare_to_transformers_multithreading(self, model_arch): | ||
model_id = MODEL_NAMES[model_arch] | ||
not_stateful = ["gpt_bigcode"] | ||
if is_openvino_version("<", "2024.0"): | ||
not_stateful.append("mixtral") | ||
|
||
if is_openvino_version("<", "2024.1"): | ||
not_stateful.extend(["llama", "gemma"]) | ||
|
||
if "gptq" in model_arch: | ||
self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM") | ||
if model_arch in ["chatglm", "baichuan2"]: | ||
self.skipTest("Models " + model_id + "doesn't support concurrent execution in AutoModelForCausalLM") | ||
|
||
set_seed(SEED) | ||
model_kwargs = {} | ||
if model_arch in self.REMOTE_CODE_MODELS: | ||
model_kwargs = {"trust_remote_code": True} | ||
|
||
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs) | ||
self.assertIsInstance(ov_model.config, PretrainedConfig) | ||
self.assertTrue(ov_model.use_cache) | ||
self.assertEqual( | ||
ov_model.stateful, self.IS_SUPPORT_STATEFUL and ov_model.config.model_type not in not_stateful | ||
) | ||
|
||
transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) | ||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) | ||
if model_arch == "qwen": | ||
transformers_model.to(torch.float32) | ||
inputs_list = ["This is a cat", "This is a dog", "Yet another test"] | ||
tokens_list = [ | ||
tokenizer(inputs, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None) | ||
for inputs in inputs_list | ||
] | ||
|
||
def run_ov_model(tokens, transformers_model, ov_model): | ||
# global ov_model, transformers_model | ||
# position_ids = None | ||
# if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: | ||
# input_shape = tokens["input_ids"].shape | ||
# position_ids = ( | ||
# torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) | ||
# ) | ||
set_seed(SEED) | ||
ov_outputs = ov_model(**tokens) | ||
|
||
self.assertTrue("logits" in ov_outputs) | ||
self.assertIsInstance(ov_outputs.logits, torch.Tensor) | ||
# self.assertTrue("past_key_values" in ov_outputs) | ||
# self.assertIsInstance(ov_outputs.past_key_values, tuple) | ||
# if self.IS_SUPPORT_STATEFUL and model_arch != "gpt_bigcode": | ||
# self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) | ||
with torch.no_grad(): | ||
transformers_outputs = transformers_model(**tokens) | ||
# Compare tensor outputs | ||
self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) | ||
# self.assertTrue(False) | ||
|
||
run_on_multiple_threads(run_ov_model, tokens_list, (transformers_model, ov_model)) | ||
|
||
del transformers_model | ||
del ov_model | ||
gc.collect() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The time taken to run all tests is already non negligible so I think we should merge it with test_compare_to_transformers
(to not duplicate steps like export)
@@ -608,6 +674,42 @@ def test_pipeline(self, model_arch): | |||
del model | |||
gc.collect() | |||
|
|||
@parameterized.expand(SUPPORTED_ARCHITECTURES) | |||
def test_pipeline_multithreading(self, model_arch): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment, can be merged with test_pipeline
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None | ||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None | ||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None | ||
infer_request: Optional[openvino.runtime.InferRequest] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could we rename it to something like request
or inference_request
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
current name is aligned with openvino api name, so for me infer_request sounds better
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that'd be clearer for users who are not familiar with the openvino ecosystem, also we don't use infer_request
anywhere in optimum-intel so was thinking about something a bit more explicit
I think we're close to merge, just waiting for couple of points above to be addressed, let me know if you need any help from my side @dtrawins (fixing conflicts / applying suggested changes) |
Support for multi threading in execution?