-
Notifications
You must be signed in to change notification settings - Fork 120
/
Copy pathmodeling_seq2seq.py
1171 lines (1020 loc) · 53.3 KB
/
modeling_seq2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
from pathlib import Path
from tempfile import gettempdir
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import numpy as np
import openvino
import torch
import transformers
from openvino.runtime import Core
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
Pix2StructForConditionalGeneration,
WhisperForConditionalGeneration,
)
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM
from .utils import _print_compiled_model_properties
if TYPE_CHECKING:
from transformers import PretrainedConfig
# Module-wide OpenVINO runtime handle; the encoder/decoder wrappers in this module call
# `core.compile_model(...)` on it when they compile their models.
core = Core()
# Logger named after this module.
logger = logging.getLogger(__name__)
# Class name substituted for the `{processor_class}` placeholder of `TRANSLATION_EXAMPLE`.
_TOKENIZER_FOR_DOC = "AutoTokenizer"
# Documentation of the constructor arguments shared by the OpenVINO seq2seq model classes; it is
# passed to `add_start_docstrings` on each of them.
INPUTS_DOCSTRING = r"""
Arguments:
encoder (`openvino.runtime.Model`):
The OpenVINO Runtime model associated to the encoder.
decoder (`openvino.runtime.Model`):
The OpenVINO Runtime model associated to the decoder.
decoder_with_past (`openvino.runtime.Model`):
The OpenVINO Runtime model associated to the decoder with past key values.
config (`transformers.PretrainedConfig`):
[PretrainedConfig](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig)
is an instance of the configuration associated to the model. Initializing with a config file does
not load the weights associated with the model, only the configuration.
"""
# Documentation of the encoder inputs; attached to `OVEncoder.forward` through
# `add_start_docstrings_to_model_forward`.
ENCODER_INPUTS_DOCSTRING = r"""
Arguments:
input_ids (`torch.LongTensor`):
Indices of input sequence tokens in the vocabulary of shape `(batch_size, encoder_sequence_length)`.
attention_mask (`torch.LongTensor`):
Mask to avoid performing attention on padding token indices, of shape
`(batch_size, encoder_sequence_length)`. Mask values selected in `[0, 1]`.
"""
# Documentation of the decoder inputs; attached to `OVDecoder.forward` through
# `add_start_docstrings_to_model_forward`.
# NOTE: every argument header must follow the "name (`type`, *optional*):" layout, with balanced
# backticks/parentheses and a trailing colon, otherwise the entry is not rendered as an argument.
DECODER_INPUTS_DOCSTRING = r"""
Arguments:
input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_hidden_states (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
encoder_attention_mask (`torch.LongTensor`, *optional*):
Mask to avoid performing cross-attention on padding token indices of encoder `input_ids`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
# Documentation of the `OVModelForSeq2SeqLM.forward` inputs.
# NOTE: `str.format` is called on this text before it is attached to `forward`, so it must not
# contain literal curly braces.
SEQ2SEQ_MODEL_DOCSTRING = r"""
Arguments:
input_ids (`torch.LongTensor`):
Indices of input sequence tokens in the vocabulary of shape `(batch_size, encoder_sequence_length)`.
attention_mask (`torch.LongTensor`):
Mask to avoid performing attention on padding token indices, of shape
`(batch_size, encoder_sequence_length)`. Mask values selected in `[0, 1]`.
decoder_input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_outputs (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
# Usage examples appended to the `OVModelForSeq2SeqLM.forward` documentation. The
# `{processor_class}`, `{model_class}` and `{checkpoint}` placeholders are filled with `str.format`.
TRANSLATION_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}
>>> from optimum.intel import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> text = "He never went out without a book under his arm, and he often came back with two."
>>> inputs = tokenizer(text, return_tensors="pt")
>>> gen_tokens = model.generate(**inputs)
>>> outputs = tokenizer.batch_decode(gen_tokens)
```
Example using `transformers.pipeline`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.intel import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer)
>>> text = "He never went out without a book under his arm, and he often came back with two."
>>> outputs = pipe(text)
```
"""
# Documentation of the `OVModelForPix2Struct.forward` inputs.
# NOTE: `str.format` is called on this text before it is attached to `forward`, so it must not
# contain literal curly braces.
PIX2STRUCT_MODEL_DOCSTRING = r"""
Args:
flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
Flattened pixel patches. The `hidden_size` is obtained by the following formula: `hidden_size` =
`num_channels` * `patch_size` * `patch_size`
The process of flattening the pixel patches is done by `Pix2StructProcessor`.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices.
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
`past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
# Class name substituted for the `{processor_class}` placeholder of the Pix2Struct and speech examples.
_PROCESSOR_FOR_DOC = "AutoProcessor"
# Usage example appended to the `OVModelForPix2Struct.forward` documentation. The
# `{processor_class}`, `{model_class}` and `{checkpoint}` placeholders are filled with `str.format`.
PIX2STRUCT_EXAMPLE = r"""
Example of pix2struct:
```python
>>> from transformers import {processor_class}
>>> from optimum.intel import {model_class}
>>> from PIL import Image
>>> import requests
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
>>> inputs = processor(images=image, text=question, return_tensors="pt")
>>> gen_tokens = model.generate(**inputs)
>>> outputs = processor.batch_decode(gen_tokens, skip_special_tokens=True)
```
"""
# Documentation of the `OVModelForSpeechSeq2Seq.forward` inputs; concatenated with
# `AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE` when attached to `forward`.
SPEECH_SEQ2SEQ_MODEL_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor`):
Mel features extracted from the raw speech waveform, of shape
`(batch_size, feature_size, encoder_sequence_length)`.
decoder_input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_outputs (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
"""
# Usage examples appended to the `OVModelForSpeechSeq2Seq.forward` documentation. The
# `{processor_class}`, `{model_class}` and `{checkpoint}` placeholders are filled with `str.format`.
AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}
>>> from optimum.intel.openvino import {model_class}
>>> from datasets import load_dataset
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor.feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> gen_tokens = model.generate(inputs=inputs.input_features)
>>> outputs = processor.tokenizer.batch_decode(gen_tokens)
```
Example using `transformers.pipeline`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.intel.openvino import {model_class}
>>> from datasets import load_dataset
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> speech_recognition = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> pred = speech_recognition(ds[0]["audio"]["array"])
```
"""
@add_start_docstrings(
    """
    Sequence-to-sequence model with a language modeling head for OpenVINO inference.
    """,
    INPUTS_DOCSTRING,
)
class OVModelForSeq2SeqLM(OVBaseModelForSeq2SeqLM, GenerationMixin):
    # Ties together an `OVEncoder`, an `OVDecoder` and, when `use_cache` is set, a second `OVDecoder`
    # fed with past key/values. (A `#` comment is used instead of a class docstring so that the text
    # assembled by `add_start_docstrings` stays untouched.)
    auto_model_class = AutoModelForSeq2SeqLM
    main_input_name = "input_ids"
    export_feature = "text2text-generation"

    def __init__(
        self,
        encoder: openvino.runtime.Model,
        decoder: openvino.runtime.Model,
        decoder_with_past: openvino.runtime.Model = None,
        config: transformers.PretrainedConfig = None,
        **kwargs,
    ):
        """
        Wrap the OpenVINO encoder/decoder models and, unless disabled, compile them.

        Arguments:
            encoder (`openvino.runtime.Model`):
                The OpenVINO model of the encoder.
            decoder (`openvino.runtime.Model`):
                The OpenVINO model of the decoder.
            decoder_with_past (`openvino.runtime.Model`, *optional*):
                The OpenVINO model of the decoder consuming past key/values.
            config (`transformers.PretrainedConfig`, *optional*):
                The model configuration.
            kwargs:
                Forwarded to the parent constructor. `compile` (defaults to `True`) controls whether the
                sub-models are compiled right away.
        """
        super().__init__(
            encoder=encoder, decoder=decoder, decoder_with_past=decoder_with_past, config=config, **kwargs
        )
        self.device = torch.device("cpu")
        self.decoder_with_past = None
        should_compile = kwargs.get("compile", True)

        self.encoder = OVEncoder(self.encoder_model, parent_model=self)
        self.decoder = OVDecoder(self.decoder_model, parent_model=self)
        if self.use_cache:
            self.decoder_with_past = OVDecoder(self.decoder_with_past_model, parent_model=self)

        if should_compile:
            self.compile()

        # Avoid warnings when creating a transformers pipeline
        AutoConfig.register(self.base_model_prefix, AutoConfig)
        try:
            self.auto_model_class.register(AutoConfig, self.__class__)
        except AttributeError:
            pass

    def to(self, device: str):
        """
        Select the OpenVINO device used for inference and drop the existing compiled requests.

        A `device` that is not a string is ignored (a debug message is logged). Returns `self`.
        """
        if not isinstance(device, str):
            logger.debug(f"device must be of type {str} but got {type(device)} instead")
            return self

        self._device = device.upper()
        for component in (self.encoder, self.decoder):
            component._device = self._device
        if self.use_cache:
            self.decoder_with_past._device = self._device
        self.clear_requests()
        return self

    @add_start_docstrings_to_model_forward(
        SEQ2SEQ_MODEL_DOCSTRING.format("batch_size, sequence_length")
        + TRANSLATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="OVModelForSeq2SeqLM",
            checkpoint="echarlaix/t5-small-openvino",
        )
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        # First prediction pass: the encoder only runs when its outputs were not handed in
        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Inputs common to both decoder variants
        shared_decoder_inputs = {
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
        }

        # The cached decoder is only usable when a cache exists and the model was loaded with one
        has_usable_cache = past_key_values is not None and self.decoder_with_past is not None
        if has_usable_cache:
            decoder_outputs = self.decoder_with_past(
                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
                past_key_values=past_key_values,
                **shared_decoder_inputs,
            )
        else:
            decoder_outputs = self.decoder(input_ids=decoder_input_ids, **shared_decoder_inputs)

        return Seq2SeqLMOutput(logits=decoder_outputs.logits, past_key_values=decoder_outputs.past_key_values)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ) -> Dict:
        """
        Assemble the keyword arguments of `forward` for one generation step.

        `input_ids` holds the decoder tokens; when `past_key_values` is falsy the legacy `past` entry of
        `kwargs` is used instead.
        """
        cache = past_key_values or kwargs.get("past", None)
        return dict(
            decoder_input_ids=input_ids,
            past_key_values=cache,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
        )

    def get_encoder(self):
        """Return the `OVEncoder` wrapper of this model."""
        return self.encoder

    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
    @staticmethod
    def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
        """
        Reorder the first two cached states of every layer along axis 0 according to `beam_idx`.

        The remaining entries of each layer (cached cross-attention states) are kept as they are since
        they are always the same.
        """

        def _reorder_layer(layer_past):
            reordered_self_attention = tuple(np.take(state, beam_idx, 0) for state in layer_past[:2])
            return reordered_self_attention + layer_past[2:]

        return tuple(_reorder_layer(layer_past) for layer_past in past)

    def reshape(self, batch_size: int, sequence_length: int):
        """
        Propagates the given input shapes on the model's layers, fixing the inputs shapes of the model.

        The compiled requests are dropped afterwards.

        Arguments:
            batch_size (`int`):
                The batch size.
            sequence_length (`int`):
                The sequence length.
        """
        super().reshape(batch_size, sequence_length)
        self.clear_requests()
        return self

    def half(self):
        """
        Converts all the model weights to FP16 for more efficient inference on GPU.

        The compiled requests are dropped afterwards.
        """
        super().half()
        self.clear_requests()
        return self

    def clear_requests(self):
        """Reset the `request` of every sub-model to `None`; `_compile` only compiles when it is `None`."""
        for component in (self.encoder, self.decoder):
            component.request = None
        if self.use_cache:
            self.decoder_with_past.request = None

    def compile(self):
        """Compile the encoder, the decoder and, when `use_cache` is set, the decoder with past."""
        for component in (self.encoder, self.decoder):
            component._compile()
        if self.use_cache:
            self.decoder_with_past._compile()
class OVEncoder:
    """
    Encoder model for OpenVINO inference.

    Arguments:
        model (`openvino.runtime.Model`):
            The OpenVINO model of the encoder. It is compiled by `_compile`, which `forward` invokes.
        parent_model (`OVModelForSeq2SeqLM`):
            The seq2seq model owning this encoder. Its `_device`, `main_input_name`, `ov_config` and
            `model_save_dir` attributes are read by this class.
    """

    def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM):
        self.model = model
        self.parent_model = parent_model
        # OpenVINO device name, copied from the parent model at construction time
        self._device = self.parent_model._device
        # torch device the outputs are moved to
        self.device = torch.device("cpu")
        # Maps every model input name to its position
        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
        self.main_input_name = self.parent_model.main_input_name or "input_ids"
        # Compiled model, created by `_compile`
        self.request = None

    @add_start_docstrings_to_model_forward(ENCODER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.LongTensor = None,
        **kwargs,
    ) -> BaseModelOutput:
        self._compile()

        # Model inputs: the main input is keyed by `main_input_name`; when `input_ids` is not given its
        # value is looked up in `kwargs` under that same name
        inputs = {self.main_input_name: input_ids if input_ids is not None else kwargs.get(self.main_input_name)}

        # Add the attention_mask inputs when needed
        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask

        # Run inference
        last_hidden_state = torch.from_numpy(
            self.request(inputs, share_inputs=True, share_outputs=True)["last_hidden_state"]
        ).to(self.device)

        return BaseModelOutput(last_hidden_state=last_hidden_state)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _compile(self):
        """
        Compile the encoder for `self._device` unless a compiled request already exists.

        This method runs at the start of every `forward` call, so it returns immediately once the model
        is compiled instead of rebuilding the OpenVINO configuration each time.
        """
        if self.request is not None:
            return

        ov_config = {**self.parent_model.ov_config}
        # Default the compilation cache to `<model_save_dir>/model_cache` for GPU devices, unless a
        # cache directory is already configured or the model lives under the temporary directory
        if (
            "CACHE_DIR" not in ov_config.keys()
            and not str(self.parent_model.model_save_dir).startswith(gettempdir())
            and "gpu" in self._device.lower()
        ):
            cache_dir = Path(self.parent_model.model_save_dir).joinpath("model_cache")
            ov_config["CACHE_DIR"] = str(cache_dir)

        logger.info(f"Compiling the encoder to {self._device} ...")
        self.request = core.compile_model(self.model, self._device, ov_config)
        # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html
        if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2:
            logger.info(f"{self._device} SUPPORTED_PROPERTIES:")
            _print_compiled_model_properties(self.request)
class OVDecoder:
    """
    Decoder model for OpenVINO inference.

    Arguments:
        model (`openvino.runtime.Model`):
            The OpenVINO model of the decoder. Whether it consumes past key/values is detected from
            its input and output names.
        parent_model (`OVModelForSeq2SeqLM`):
            The seq2seq model owning this decoder. Its `_device`, `ov_config` and `model_save_dir`
            attributes are read by this class.
    """

    def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2SeqLM):
        self.model = model
        self.parent_model = parent_model
        # OpenVINO device name, copied from the parent model at construction time
        self._device = self.parent_model._device
        # torch device the logits are moved to
        self.device = torch.device("cpu")
        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
        self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
        self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
        self.key_value_output_names = [key for key in self.output_names if "key_values" in key or "present" in key]
        is_legacy = any("past_key_values" in key.get_any_name() for key in self.model.outputs)

        if len(self.key_value_input_names) > 0 and not is_legacy:
            # Decoder with cache: 2 key/value outputs per layer are regrouped in `forward`
            self.use_past = True
            self.num_pkv = 2
        else:
            # Outputs are regrouped by 4 per layer in `forward`
            self.use_past = False
            self.num_pkv = 4

        # Infer request, created by `_compile`
        self.request = None

    @add_start_docstrings_to_model_forward(DECODER_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
    ) -> Seq2SeqLMOutput:
        self._compile()
        # Model inputs
        inputs = {}

        if past_key_values is not None:
            # Flatten the past_key_values
            past_key_values = tuple(
                past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
            )

            # Add the past_key_values to the decoder inputs
            inputs = dict(zip(self.key_value_input_names, past_key_values))

        inputs["input_ids"] = input_ids

        # Add the encoder_attention_mask inputs when needed
        if "encoder_attention_mask" in self.input_names and encoder_attention_mask is not None:
            inputs["encoder_attention_mask"] = encoder_attention_mask

        # Add the encoder_hidden_states inputs when needed
        if "encoder_hidden_states" in self.input_names and encoder_hidden_states is not None:
            inputs["encoder_hidden_states"] = encoder_hidden_states

        # Add the decoder_attention_mask inputs when needed
        if "decoder_attention_mask" in self.input_names and decoder_attention_mask is not None:
            inputs["decoder_attention_mask"] = decoder_attention_mask

        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)

        # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
        # self-attention layer and 2 to the cross-attention layer)
        out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)

        # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
        # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
        # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
        if self.use_past is False:
            out_past_key_values = tuple(
                out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
            )
        else:
            # grab the cross attention key/values from the inputs: the flattened inputs hold 4 entries
            # per layer while the outputs hold 2, hence the `2 * i + 2` offset into the inputs
            out_past_key_values = tuple(
                out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv]
                for i in range(0, len(out_past_key_values), self.num_pkv)
            )

        return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def _compile(self):
        """
        Compile the decoder for `self._device` and create its infer request, unless one already exists.

        This method runs at the start of every `forward` call (i.e. at every decoding step), so it
        returns immediately once the request exists instead of rebuilding the OpenVINO configuration
        each time.
        """
        if self.request is not None:
            return

        ov_config = {**self.parent_model.ov_config}
        # Default the compilation cache to `<model_save_dir>/model_cache` for GPU devices, unless a
        # cache directory is already configured or the model lives under the temporary directory
        if (
            "CACHE_DIR" not in ov_config.keys()
            and not str(self.parent_model.model_save_dir).startswith(gettempdir())
            and "gpu" in self._device.lower()
        ):
            cache_dir = Path(self.parent_model.model_save_dir).joinpath("model_cache")
            ov_config["CACHE_DIR"] = str(cache_dir)

        logger.info(f"Compiling the decoder to {self._device} ...")
        compiled_model = core.compile_model(self.model, self._device, ov_config)
        self.request = compiled_model.create_infer_request()
        # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html
        if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2:
            logger.info(f"{self._device} SUPPORTED_PROPERTIES:")
            _print_compiled_model_properties(compiled_model)
@add_start_docstrings(
    """
    Pix2Struct model with a language modeling head for OpenVINO inference.
    """,
    INPUTS_DOCSTRING,
)
class OVModelForPix2Struct(OVModelForSeq2SeqLM):
    # Same pipeline as `OVModelForSeq2SeqLM`, with `flattened_patches` as the encoder input.
    auto_model_class = Pix2StructForConditionalGeneration
    main_input_name = "flattened_patches"
    export_feature = "image-to-text"

    def prepare_inputs_for_generation(
        self,
        input_ids,
        flattened_patches: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ) -> Dict:
        """
        Assemble the keyword arguments of `forward` for one generation step.

        `input_ids` holds the decoder tokens. A missing `decoder_attention_mask` is replaced by a tensor
        of ones shaped like `input_ids`.
        """
        if decoder_attention_mask is None:
            decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)

        return dict(
            flattened_patches=flattened_patches,
            decoder_input_ids=input_ids,
            past_key_values=past_key_values,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
        )

    @add_start_docstrings_to_model_forward(
        PIX2STRUCT_MODEL_DOCSTRING.format("batch_size, sequence_length")
        + PIX2STRUCT_EXAMPLE.format(
            processor_class=_PROCESSOR_FOR_DOC,
            model_class="OVModelForPix2Struct",
            checkpoint="google/pix2struct-ai2d-base",
        )
    )
    def forward(
        self,
        flattened_patches: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        # The parent implementation names its encoder input `input_ids`; the patches travel through it
        parent_forward = super().forward
        return parent_forward(
            input_ids=flattened_patches,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            **kwargs,
        )

    def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
        """
        Reshape the inputs of `model` and return it.

        For a decoder, the first dimension of every input is made dynamic, together with dimension 2 of
        the inputs whose name starts with `past_key_values` and dimension 1 of the other inputs whose
        name does not start with `encoder`. Otherwise only the first dimension is set to `batch_size`.
        `sequence_length` is not used by this implementation.
        """
        new_shapes = {}
        for port in model.inputs:
            partial_shape = port.get_partial_shape()
            if is_decoder:
                partial_shape[0] = -1
                port_name = port.get_any_name()
                if port_name.startswith("past_key_values"):
                    partial_shape[2] = -1
                elif not port_name.startswith("encoder"):
                    partial_shape[1] = -1
            else:
                partial_shape[0] = batch_size
            new_shapes[port] = partial_shape

        model.reshape(new_shapes)
        return model
@add_start_docstrings(
    """
    Speech Sequence-to-sequence model with a language modeling head for OpenVINO inference. This class officially supports whisper, speech_to_text.
    """,
    INPUTS_DOCSTRING,
)
class OVModelForSpeechSeq2Seq(OVModelForSeq2SeqLM):
    # Same pipeline as `OVModelForSeq2SeqLM`, with `input_features` as the encoder input. Whisper
    # checkpoints are loaded as `_OVModelForWhisper` (see `_from_pretrained`).
    auto_model_class = AutoModelForSpeechSeq2Seq
    main_input_name = "input_features"
    export_feature = "automatic-speech-recognition"

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ) -> Dict:
        """
        Assemble the keyword arguments of `forward` for one generation step.

        `input_ids` holds the decoder tokens. A missing `decoder_attention_mask` is replaced by a tensor
        of ones shaped like `input_ids`.
        """
        if decoder_attention_mask is None:
            decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    @add_start_docstrings_to_model_forward(
        SPEECH_SEQ2SEQ_MODEL_DOCSTRING
        + AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE.format(
            processor_class=_PROCESSOR_FOR_DOC,
            model_class="OVModelForSpeechSeq2Seq",
            checkpoint="openai/whisper-tiny",
        )
    )
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        # The parent implementation names its encoder input `input_ids`; the features travel through it
        return super().forward(
            input_ids=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            **kwargs,
        )

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        **kwargs,
    ):
        """
        Load the model, dispatching to `_OVModelForWhisper` when the config declares a Whisper architecture.

        Arguments:
            model_id (`str` or `Path`):
                Identifier or path of the model, forwarded unchanged.
            config (`PretrainedConfig`):
                The model configuration. Its `architectures` list selects the class to instantiate.
            kwargs:
                Forwarded unchanged to the selected `_from_pretrained`.
        """
        # `architectures` may be missing or `None` (e.g. a config that was not saved from a model);
        # treat that as "no declared architecture" instead of failing on the membership test
        architectures = getattr(config, "architectures", None) or ()
        if "WhisperForConditionalGeneration" in architectures:
            return _OVModelForWhisper._from_pretrained(model_id, config, **kwargs)
        return super()._from_pretrained(model_id, config, **kwargs)
class _OVModelForWhisper(OVModelForSpeechSeq2Seq):
"""
Whisper implements its own generate() method.
"""
auto_model_class = WhisperForConditionalGeneration
@classmethod
def _from_pretrained(
    cls,
    model_id: Union[str, Path],
    config: "PretrainedConfig",
    **kwargs,
):
    """
    Load the Whisper model with the loader defined above `OVModelForSpeechSeq2Seq` in the class hierarchy.

    `OVModelForSpeechSeq2Seq._from_pretrained` routes Whisper configurations to this method, so the
    lookup deliberately starts after `OVModelForSpeechSeq2Seq` in the MRO (`super(OVModelForSpeechSeq2Seq, cls)`)
    instead of calling that dispatcher again.
    """
    return super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs)
# Adapted from transformers.models.whisper.modeling_whisper
def generate(
self,
input_features: Optional[torch.Tensor] = None,
generation_config=None,
logits_processor=None,
stopping_criteria=None,
prefix_allowed_tokens_fn=None,
synced_gpus=False,
return_timestamps=None,
task=None,
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
num_segment_frames: Optional[int] = None,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
attention_mask: Optional[torch.Tensor] = None,
time_precision: int = 0.02,
return_dict_in_generate: Optional[bool] = None,
**kwargs,
):
if "inputs" in kwargs:
input_features = kwargs.pop("inputs")
logging.warn(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)
return_dict_in_generate = (
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
if generation_config is None:
generation_config = copy.deepcopy(self.generation_config)
input_stride = (
1 * 2
) # NOTE: replaced from `self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]`
if num_segment_frames is None:
num_segment_frames = input_stride * self.config.max_source_positions
# 1. Check whether we're in shortform or longform mode
if input_features is not None:
total_input_frames = input_features.shape[-1]
elif "encoder_outputs" in kwargs:
encoder_outputs_shape = (
kwargs["encoder_outputs"][0].shape
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
total_input_frames = encoder_outputs_shape[1] * input_stride
else:
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
is_shortform = total_input_frames <= num_segment_frames
# 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
if return_timestamps is True:
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
elif not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the generation config to have `no_timestamps_token_id` correctly. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
"or make sure to pass no more than 3000 mel input features."
)
logger.info("Setting `return_timestamps=True` for long-form generation.")
generation_config.return_timestamps = True
else:
generation_config.return_timestamps = False
# 3. Make sure to correctly set language-related parameters
if is_multilingual is not None:
if not hasattr(generation_config, "is_multilingual"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
"to `generate`. Please update the generation config as per the instructions "
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.is_multilingual = is_multilingual
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
if task is not None or language is not None:
raise ValueError(
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
)
if language is not None:
if not hasattr(generation_config, "lang_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `language` argument "
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
language = language.lower()
generation_config.language = language
if task is not None:
if not hasattr(generation_config, "task_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `task` argument "
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.task = task
# 4. Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps`
forced_decoder_ids = None
# Legacy code for backward compatibility
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
else:
forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
forced_decoder_ids = []
if hasattr(generation_config, "language"):
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
elif generation_config.language in TO_LANGUAGE_CODE.values():
language_token = f"<|{generation_config.language}|>"
else:
is_language_code = len(generation_config.language) == 2
raise ValueError(
f"Unsupported language: {generation_config.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else:
forced_decoder_ids.append((1, None)) # automatically detect the language
if hasattr(generation_config, "task"):
if generation_config.task in TASK_IDS:
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
raise ValueError(
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
)
elif hasattr(generation_config, "task_to_id"):
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
if forced_decoder_ids is not None:
generation_config.forced_decoder_ids = forced_decoder_ids
if prompt_ids is not None:
if kwargs.get("decoder_start_token_id") is not None:
raise ValueError(
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
)
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
text_prompt_ids = text_prompt_ids[-self.config.max_target_positions // 2 - 1 :]
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})
# If the user passes `max_new_tokens`, increase its number to account for the prompt
if kwargs.get("max_new_tokens", None) is not None:
kwargs["max_new_tokens"] += len(text_prompt_ids)
if kwargs["max_new_tokens"] >= self.config.max_target_positions:
raise ValueError(
f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` "
f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced "
f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the "
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
f"so that their combined length is less that {self.config.max_target_positions}."
)
# Reformat the forced_decoder_ids to incorporate the prompt
non_prompt_forced_decoder_ids = (
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
)
forced_decoder_ids = [
*text_prompt_ids,
generation_config.decoder_start_token_id,
*[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids
if return_token_timestamps:
kwargs["output_attentions"] = True
return_dict_in_generate = True
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
if not hasattr(generation_config, "alignment_heads"):
raise ValueError(
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)
if kwargs.get("num_frames") is not None:
generation_config.num_frames = kwargs.pop("num_frames")
if generation_config.return_timestamps is True:
last_forced_decoder_ids = (
generation_config.forced_decoder_ids[-1][-1]
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids
else None
)
if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id:
# remove no_timestamp to be forcefully generated if we want to return timestamps
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
forced_decoder_ids = generation_config.forced_decoder_ids[:-1]
# Make sure that if list is empty we set it to None
generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids
timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
logits_processor = (
timestamp_processor if logits_processor is None else timestamp_processor + logits_processor