Makes automatic optimization a model attribute (#4602)
* Makes automatic optimization a model attribute

* Update trainer.py

* remove setting property in model

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta <[email protected]>

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Rohit Gupta <[email protected]>

* Update trainer.py

Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Jeff Yang <[email protected]>
5 people authored Nov 14, 2020
1 parent 144a5c9 commit e04e7c9
Showing 3 changed files with 22 additions and 2 deletions.
7 changes: 7 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -164,6 +164,13 @@ def on_gpu(self):
"""
return self.device.type == "cuda"

@property
def automatic_optimization(self) -> bool:
"""
If False you are responsible for calling .backward, .step, zero_grad.
"""
return True

def print(self, *args, **kwargs) -> None:
r"""
Prints only from process 0. Use this in any distributed mode to log only once.
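With this change, a user opts out of automatic optimization by overriding the new property on their LightningModule and driving the optimizer by hand in training_step. The following is a minimal sketch, assuming the v1.1-era manual-optimization API (self.optimizers() and manual_backward(loss, opt)); the model, data shapes, and optimizer below are illustrative, not part of this commit.

import torch
import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        # Override the new LightningModule property instead of passing the
        # deprecated Trainer flag: Lightning will no longer call
        # .backward/.step/.zero_grad for this model.
        return False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.manual_backward(loss, opt)  # v1.1-era signature (assumption, see lead-in)
        opt.step()
        opt.zero_grad()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)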
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -35,6 +35,9 @@ def copy_trainer_model_properties(self, model):
        else:
            ref_model = model

        automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
        self.trainer.train_loop.automatic_optimization = automatic_optimization

        for m in [model, ref_model]:
            m.trainer = self.trainer
            m.logger = self.trainer.logger
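The connector combines the model property with the (deprecated but still accepted) trainer-side flag by logical AND, so automatic optimization stays on only if neither side turned it off. A tiny standalone sketch of that resolution rule; resolve_automatic_optimization is a hypothetical helper used only for illustration, not code from this commit.

def resolve_automatic_optimization(model_flag: bool, trainer_flag: bool) -> bool:
    # Mirrors: ref_model.automatic_optimization and trainer.train_loop.automatic_optimization
    return model_flag and trainer_flag


assert resolve_automatic_optimization(True, True) is True    # default: Lightning runs the loop
assert resolve_automatic_optimization(False, True) is False  # model property opts out
assert resolve_automatic_optimization(True, False) is False  # deprecated Trainer flag opts out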
14 changes: 12 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -137,7 +137,7 @@ def __init__(
        amp_backend: str = 'native',
        amp_level: str = 'O2',
        distributed_backend: Optional[str] = None,
        automatic_optimization: bool = True,
        automatic_optimization: Optional[bool] = None,
        move_metrics_to_cpu: bool = False,
    ):
        r"""
@@ -212,7 +212,9 @@ def __init__(
            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            automatic_optimization: If False, you are responsible for calling .backward, .step, and .zero_grad.
                Meant to be used with multiple optimizers by advanced users.
                If False, you are responsible for calling .backward, .step, and .zero_grad in LightningModule.
                This argument has been moved to LightningModule. It is deprecated here in v1.1 and
                will be removed in v1.3.

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
@@ -355,6 +357,14 @@ def __init__(
        )

        # init train loop related flags
        # TODO: deprecate in 1.2.0
        if automatic_optimization is None:
            automatic_optimization = True
        else:
            rank_zero_warn(
                "Disabling automatic optimization with the trainer flag is deprecated and will be removed in v1.3.0! "
                "Please use the property on the LightningModule to disable automatic optimization."
            )
        self.train_loop.on_trainer_init(
            max_epochs,
            min_epochs,
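In user code the migration looks roughly like this; a usage sketch based on the deprecation path above, not code from the commit.

import pytorch_lightning as pl

# Deprecated: still accepted in v1.1 but triggers the rank_zero_warn above
# and is slated for removal in v1.3.0.
trainer = pl.Trainer(automatic_optimization=False)

# Preferred: leave the Trainer argument at its default and override the
# automatic_optimization property on the LightningModule instead (see the
# ManualOptimModel sketch earlier).
trainer = pl.Trainer()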
