Skip to content

Commit

Permalink
Fix model templates and use less than 119 chars (#9684)
Browse files Browse the repository at this point in the history
* Fix model templates and use less than 119 chars

* Missing new line
  • Loading branch information
sgugger authored Jan 19, 2021
1 parent 2ebbbf5 commit 7e662e6
Show file tree
Hide file tree
Showing 13 changed files with 32 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/blenderbot/modeling_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/electra/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/layoutlm/modeling_layoutlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/led/modeling_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,7 +1924,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/marian/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/pegasus/modeling_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/roberta/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,8 @@ def forward(

if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,16 @@ def forward(

layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:

if use_cache:
logger.warn(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
Expand Down

0 comments on commit 7e662e6

Please sign in to comment.