Un-balanced logging properly supported (#5119)
* resolve bug

* clean code

* resolve comments

* Update tests/trainer/optimization/test_multiple_optimizers.py

Co-authored-by: Rohit Gupta <[email protected]>

* resolve another bug

* add comments

* use abs to find diff

* update

* resolve flake8

Co-authored-by: Rohit Gupta <[email protected]>
2 people authored and awaelchli committed Dec 18, 2020
1 parent b3fc662 commit df8b676
Showing 2 changed files with 78 additions and 11 deletions.
@@ -91,11 +91,13 @@ def check_dataloader_idx(self, result: Result) -> bool:
         random_key = list(result.keys())[-1]
         return result["meta"][random_key]["dataloader_idx"] is not None

-    def get_latest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict:
+    def get_latest_from_func_name(self, latest_result_opt, func_name: str, *args, **kwargs) -> Dict:
         results = {}
-        add_dataloader_idx = self.check_dataloader_idx(latest_result)
-        func = getattr(latest_result, func_name)
-        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
+        for opt_idx in latest_result_opt:
+            latest_result = latest_result_opt[opt_idx]
+            add_dataloader_idx = self.check_dataloader_idx(latest_result)
+            func = getattr(latest_result, func_name)
+            results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
         return results

     def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]:
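What the new loop accomplishes, sketched outside the library: the latest per-optimizer results now live in a dict keyed by optimizer index, and `get_latest_from_func_name` merges the metrics gathered from each one. The toy dicts and the helper below are illustrative assumptions, not Lightning's actual `Result` objects or API.

    # Illustrative sketch only: plain dicts standing in for per-optimizer Result objects.
    latest_result_opt = {
        0: {"loss_1_step": 0.25},  # metrics last logged under optimizer 0
        1: {"loss_2_step": 0.75},  # metrics last logged under optimizer 1
    }

    def merge_latest_metrics(latest_result_opt):
        # Mirror the added loop: visit every optimizer index and merge its latest metrics.
        results = {}
        for opt_idx in latest_result_opt:
            results.update(latest_result_opt[opt_idx])
        return results

    assert merge_latest_metrics(latest_result_opt) == {"loss_1_step": 0.25, "loss_2_step": 0.75}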
@@ -156,6 +158,7 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
         assert isinstance(result, Result)
         if dataloader_idx is None:
             dataloader_idx = 0
+
         if extra_info is None:
             extra_info = {}

@@ -166,22 +169,27 @@ def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optio
             if dataloader_idx not in self._internals:
                 self._internals[dataloader_idx] = {}
                 self._internals_reduced[dataloader_idx] = defaultdict(dict)
+                self._latest_ref[dataloader_idx] = {}

             # extract infos
             opt_idx = extra_info["opt_idx"]
             batch_idx = extra_info["batch_idx"]

             self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)

-            self._latest_ref[dataloader_idx] = result
+            self._latest_ref[dataloader_idx][opt_idx] = result

         # [dataloader_idx] is a list
         else:
             self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP
             self._internals.setdefault(dataloader_idx, [])
             self._internals[dataloader_idx].append(result)

-            self._latest_ref[dataloader_idx] = result
+            if dataloader_idx not in self._latest_ref:
+                self._latest_ref[dataloader_idx] = {}
+                self._latest_ref[dataloader_idx][0] = {}
+
+            self._latest_ref[dataloader_idx][0] = result

     def auto_reduce_results_on_epoch_end(self) -> None:
         """
@@ -206,13 +214,9 @@ def auto_reduce_results_on_epoch_end(self) -> None:
                     # TODO: How to start training in middle of epoch
                     opt_outputs = epoch_metrics[opt_idx]

-                    num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1
-                    assert num_batch_idx >= 0
-                    batch_indexes = self._internals[dl_idx][num_opt_idx].keys()
-
                     # reduce across time first
                     time_reduced_outputs = []
-                    for batch_idx in batch_indexes:
+                    for batch_idx in opt_outputs.keys():
                         tbptt_outs = opt_outputs[batch_idx]
                         tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
                         if len(tbptt_outs) > 1:
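Why iterating `opt_outputs.keys()` matters: with unbalanced logging each optimizer accumulates its own set of batch indices, so borrowing the index set of a different optimizer (as the removed `batch_indexes` lookup, taken from the last optimizer, effectively did) can reference batches that were never logged for the optimizer currently being reduced. A toy illustration with plain dicts standing in for the per-batch outputs:

    # Hypothetical epoch storage: optimizer 0 only logged from batch 11 onward,
    # while optimizer 1 logged every batch.
    epoch_metrics = {
        0: {11: ["out_11"], 12: ["out_12"]},
        1: {0: ["out_0"], 1: ["out_1"], 2: ["out_2"]},
    }

    # Borrowing another optimizer's batch indices fails on the sparser one.
    try:
        [epoch_metrics[0][batch_idx] for batch_idx in epoch_metrics[1].keys()]
    except KeyError:
        pass  # batch 0 was never logged for optimizer 0

    # Iterating each optimizer's own keys reduces exactly what was logged.
    for opt_idx, opt_outputs in epoch_metrics.items():
        reduced = [opt_outputs[batch_idx] for batch_idx in opt_outputs.keys()]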
63 changes: 63 additions & 0 deletions tests/trainer/optimization/test_multiple_optimizers.py
@@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests to ensure that the behaviours related to multiple optimizers work
"""
import torch

import pytorch_lightning as pl
from tests.base.boring_model import BoringModel


def test_unbalanced_logging_with_multiple_optimizers(tmpdir):
    """
    This test ensures reduction works in unbalanced logging settings
    """
    class TestModel(BoringModel):

        loss_1 = []
        loss_2 = []

        def training_step(self, batch, batch_idx, optimizer_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            if optimizer_idx == 0 and self.trainer.global_step > 10:
                self.log("loss_1", loss, on_epoch=True, prog_bar=True)
                self.loss_1.append(loss.detach().clone())
            elif optimizer_idx == 1:
                self.log("loss_2", loss, on_epoch=True, prog_bar=True)
                self.loss_2.append(loss.detach().clone())
            return {"loss": loss}

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
            optimizer2 = torch.optim.SGD(self.layer.parameters(), lr=0.001)
            return [optimizer, optimizer2]

    model = TestModel()
    model.training_epoch_end = None

    # Initialize a trainer
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
    )

    trainer.fit(model)

    assert torch.equal(trainer.callback_metrics["loss_2_step"], model.loss_2[-1])
    assert torch.equal(trainer.callback_metrics["loss_1_step"], model.loss_1[-1])
    # test losses are properly reduced
    assert torch.abs(trainer.callback_metrics["loss_2_epoch"] - torch.FloatTensor(model.loss_2).mean()) < 1e-6
    assert torch.abs(trainer.callback_metrics["loss_1_epoch"] - torch.FloatTensor(model.loss_1).mean()) < 1e-6
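To exercise just this scenario locally, an invocation along the lines of `pytest tests/trainer/optimization/test_multiple_optimizers.py -k unbalanced_logging` should pick up the new test through pytest's `-k` keyword matching, though the exact command depends on how the repository's test environment is set up.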
