Fix trainer.fit_loop.split_idx reference (#8601)
* Fix split idx reference

* Update CHANGELOG

* Add comment
carmocca authored and awaelchli committed Jul 31, 2021
1 parent c7f8c8c commit c623757
Showing 3 changed files with 31 additions and 29 deletions.
pytorch_lightning/loops/fit_loop.py (2 changes: 1 addition & 1 deletion)

@@ -74,7 +74,7 @@ def batch_idx(self) -> int:
     @property
     def split_idx(self) -> int:
         """Returns the index of the current batch split (within the current batch) for bptt"""
-        return self.epoch_loop.split_idx
+        return self.epoch_loop.batch_loop.split_idx
 
     @property
     def min_steps(self) -> int:
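The fixed property is pure delegation: FitLoop does not track the split counter itself, it forwards to the loop that does. A minimal sketch of the chain with hypothetical stand-in classes (not Lightning's real ones):

# Hypothetical stand-ins for the FitLoop -> TrainingEpochLoop -> TrainingBatchLoop
# chain; only the innermost loop owns the tbptt split counter.
class _BatchLoop:
    def __init__(self) -> None:
        self.split_idx = 0  # advanced once per tbptt split


class _EpochLoop:
    def __init__(self) -> None:
        self.batch_loop = _BatchLoop()
        # no split_idx here: the old `self.epoch_loop.split_idx` reference
        # stopped one level too high in the chain


class _FitLoop:
    def __init__(self) -> None:
        self.epoch_loop = _EpochLoop()

    @property
    def split_idx(self) -> int:
        return self.epoch_loop.batch_loop.split_idx  # the fixed reference


loop = _FitLoop()
loop.epoch_loop.batch_loop.split_idx = 3
assert loop.split_idx == 3  # what user code reads as trainer.fit_loop.split_idx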
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py (6 changes: 3 additions & 3 deletions)

@@ -241,9 +241,9 @@ def on_batch_start(self) -> None:
         self._epoch_end_reached = False
 
     def epoch_end_reached(self):
-        self.trainer.logger_connector._epoch_end_reached = True
-        self.trainer.logger_connector._batch_idx = None
-        self.trainer.logger_connector._split_idx = None
+        self._epoch_end_reached = True
+        self._batch_idx = None
+        self._split_idx = None
 
     def on_epoch_end(self) -> None:
         assert self._epoch_end_reached
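This hunk changes no behavior: inside the connector, `self.trainer.logger_connector` resolves back to `self`, so the round trip through the trainer was redundant. A small runnable sketch of that identity, again with hypothetical stand-in classes:

# Hypothetical stand-ins showing why `self.trainer.logger_connector._x = ...`
# inside the connector is just `self._x = ...` with extra steps.
class _LoggerConnector:
    def __init__(self, trainer):
        self.trainer = trainer
        self._epoch_end_reached = False

    def epoch_end_reached(self):
        assert self.trainer.logger_connector is self  # same object
        self._epoch_end_reached = True  # so assign on self directly


class _Trainer:
    def __init__(self):
        self.logger_connector = _LoggerConnector(self)


trainer = _Trainer()
trainer.logger_connector.epoch_end_reached()
assert trainer.logger_connector._epoch_end_reached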
tests/trainer/logging_/test_train_loop_logging.py (52 changes: 27 additions & 25 deletions)

@@ -214,62 +214,62 @@ def validation_step(self, batch, batch_idx):


 def test_tbptt_log(tmpdir):
-    """
-    Tests that only training_step can be used
-    """
     truncated_bptt_steps = 2
-    sequence_size = 30
-    batch_size = 30
-
-    x_seq = torch.rand(batch_size, sequence_size, 1)
-    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
+    N, T, F = 32, 15, 1  # batches x timesteps (sequence size) x features
+    batch_size = 10
+    assert T % truncated_bptt_steps != 0, "Should test leftover time steps"
 
     class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __getitem__(self, i):
-            return x_seq, y_seq_list
+        def __init__(self):
+            self.x_seq = torch.randn(N, T, F)
+            self.y_seq = torch.randn(N, T, F)
+
+        def __getitem__(self, index):
+            return self.x_seq[index], self.y_seq[index]
 
         def __len__(self):
-            return 1
+            return N
 
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
             self.test_hidden = None
-            self.layer = torch.nn.Linear(2, 2)
+            self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)
 
         def training_step(self, batch, batch_idx, hiddens):
             assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
             if hiddens is not None:
                 assert hiddens.grad_fn is None
-            self.test_hidden = torch.tensor(2.0, requires_grad=True).pow(2)
+            split_idx = self.trainer.fit_loop.split_idx
+            self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)
 
-            x_tensor, y_list = batch
-            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
+            x, y = batch
+            if self.trainer.fit_loop.epoch_loop.batch_loop.done:
+                # last split idx, not aligned
+                assert x.shape[1] == T % truncated_bptt_steps
+                assert y.shape[1] == T % truncated_bptt_steps
+            else:
+                assert x.shape[1] == truncated_bptt_steps
+                assert y.shape[1] == truncated_bptt_steps
 
-            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
-            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
-
-            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
-            loss = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
+            pred, _ = self(x)
+            loss = torch.nn.functional.mse_loss(pred, y)
 
             self.log("a", loss, on_epoch=True)
 
             return {"loss": loss, "hiddens": self.test_hidden}
 
-        def on_train_epoch_start(self) -> None:
+        def on_train_batch_start(self, *args, **kwargs) -> None:
             self.test_hidden = None
 
         def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
-            )
+            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)
 
     model = TestModel()
     model.training_epoch_end = None
 
     trainer = Trainer(
         default_root_dir=tmpdir,
-        limit_train_batches=10,
         limit_val_batches=0,
         truncated_bptt_steps=truncated_bptt_steps,
         max_epochs=2,
@@ -278,6 +278,8 @@ def train_dataloader(self):
     )
     trainer.fit(model)
 
+    assert trainer.fit_loop.batch_idx == N // batch_size
+    assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
     assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}


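The new assertions at the end of the test pin down the loop bookkeeping: with N = 32 samples in batches of 10, the last batch index is N // batch_size, and with T = 15 timesteps split every truncated_bptt_steps = 2, the last split index is T // truncated_bptt_steps. A plain-Python sketch of that arithmetic (no Lightning involved):

# The arithmetic behind the test's final assertions.
N, batch_size = 32, 10
T, k = 15, 2  # timesteps, truncated_bptt_steps

# 4 batches of sizes 10, 10, 10, 2 -> last batch_idx is ceil(N / bs) - 1 = 3
# (equal to N // batch_size because N is not a multiple of batch_size)
assert -(-N // batch_size) - 1 == N // batch_size == 3

# each batch is cut along time into ceil(T / k) tbptt splits
splits = [list(range(t, min(t + k, T))) for t in range(0, T, k)]
assert [len(s) for s in splits] == [2, 2, 2, 2, 2, 2, 2, 1]  # leftover of T % k
assert len(splits) - 1 == T // k == 7  # last split_idx seen in training_step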
