Fix trainer.fit_loop.split_idx reference #8601

Merged (3 commits) on Jul 29, 2021
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -80,8 +80,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


-

- Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))

-

1 change: 1 addition & 0 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -47,6 +47,7 @@ def __init__(self) -> None:
         self.accumulated_loss: Optional[Tensor] = None
         self.batch_outputs: Optional[List[List[STEP_OUTPUT]]] = None
         self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
+        # the current split index when the batch gets split into chunks in truncated backprop through time
         self.split_idx: Optional[int] = None
         self.optim_progress = OptimizationProgress()

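For context on the comment added above: in truncated backprop through time, each batch is cut along its time dimension into smaller chunks, and `split_idx` is the position of the chunk currently being processed. The snippet below is only an illustration of that splitting with plain `torch.split`, not the loop's actual implementation; the shapes and `truncated_bptt_steps=2` are taken from the updated test further down.

```python
import torch

# Illustrative only: split a (batch, time, features) tensor into tbptt chunks.
# `split_idx` names the position of the chunk currently being processed.
truncated_bptt_steps = 2
x = torch.randn(10, 15, 1)  # batch_size=10, T=15 timesteps, 1 feature

splits = list(torch.split(x, truncated_bptt_steps, dim=1))
print(len(splits))          # 8 -> split_idx runs from 0 to 7 (= 15 // 2)
print(splits[-1].shape[1])  # 1 -> the leftover split, since 15 % 2 != 0
```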
2 changes: 0 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -43,8 +43,6 @@ def __init__(self, min_steps: int, max_steps: int):
         self.global_step: int = 0
         # the total batch index across all epochs
         self.total_batch_idx: int = 0
-        # the current split index when the batch gets split into chunks in truncated backprop through time
-        self.split_idx: Optional[int] = None
         self.is_last_batch: Optional[bool] = None
         self.batch_progress = Progress()
         self.scheduler_progress = SchedulerProgress()
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -74,7 +74,7 @@ def batch_idx(self) -> int:
     @property
     def split_idx(self) -> int:
         """Returns the index of the current batch split (within the current batch) for bptt"""
-        return self.epoch_loop.split_idx
+        return self.epoch_loop.batch_loop.split_idx

     @property
     def min_steps(self) -> int:
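To make the one-line change above concrete, here is a stripped-down sketch using toy classes (hypothetical names, not Lightning code) of the delegation chain: the batch loop is the object whose `split_idx` actually gets updated during tbptt, so the `FitLoop` property has to read it from there. The old `self.epoch_loop.split_idx` pointed at a separate attribute that was never written to, which is why the property always came back `None`, per the CHANGELOG entry above.

```python
from typing import Optional


class BatchLoop:
    def __init__(self) -> None:
        # updated by the loop as each tbptt split of the batch is processed
        self.split_idx: Optional[int] = None


class EpochLoop:
    def __init__(self) -> None:
        self.batch_loop = BatchLoop()


class FitLoop:
    def __init__(self) -> None:
        self.epoch_loop = EpochLoop()

    @property
    def split_idx(self) -> Optional[int]:
        # the fix: read the attribute that actually gets written to
        return self.epoch_loop.batch_loop.split_idx


fit_loop = FitLoop()
fit_loop.epoch_loop.batch_loop.split_idx = 7
print(fit_loop.split_idx)  # 7 (a stale epoch-loop attribute would have kept this at None)
```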
@@ -241,9 +241,9 @@ def on_batch_start(self) -> None:
         self._epoch_end_reached = False

     def epoch_end_reached(self):
-        self.trainer.logger_connector._epoch_end_reached = True
-        self.trainer.logger_connector._batch_idx = None
-        self.trainer.logger_connector._split_idx = None
+        self._epoch_end_reached = True
+        self._batch_idx = None
+        self._split_idx = None

     def on_epoch_end(self) -> None:
         assert self._epoch_end_reached
52 changes: 27 additions & 25 deletions tests/trainer/logging_/test_train_loop_logging.py
@@ -214,62 +214,62 @@ def validation_step(self, batch, batch_idx):


 def test_tbptt_log(tmpdir):
-    """
-    Tests that only training_step can be used
-    """
     truncated_bptt_steps = 2
-    sequence_size = 30
-    batch_size = 30
-
-    x_seq = torch.rand(batch_size, sequence_size, 1)
-    y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
+    N, T, F = 32, 15, 1  # batches x timesteps (sequence size) x features
+    batch_size = 10
+    assert T % truncated_bptt_steps != 0, "Should test leftover time steps"

     class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __getitem__(self, i):
-            return x_seq, y_seq_list
+        def __init__(self):
+            self.x_seq = torch.randn(N, T, F)
+            self.y_seq = torch.randn(N, T, F)

+        def __getitem__(self, index):
+            return self.x_seq[index], self.y_seq[index]
+
         def __len__(self):
-            return 1
+            return N

     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
             self.test_hidden = None
-            self.layer = torch.nn.Linear(2, 2)
+            self.layer = torch.nn.LSTM(input_size=F, hidden_size=T, batch_first=True)

         def training_step(self, batch, batch_idx, hiddens):
             assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            self.test_hidden = torch.tensor(2.0, requires_grad=True).pow(2)
+            if hiddens is not None:
+                assert hiddens.grad_fn is None
+            split_idx = self.trainer.fit_loop.split_idx
+            self.test_hidden = torch.tensor(split_idx, requires_grad=True, dtype=torch.float).pow(2)

-            x_tensor, y_list = batch
-            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
+            x, y = batch
+            if self.trainer.fit_loop.epoch_loop.batch_loop.done:
+                # last split idx, not aligned
+                assert x.shape[1] == T % truncated_bptt_steps
+                assert y.shape[1] == T % truncated_bptt_steps
+            else:
+                assert x.shape[1] == truncated_bptt_steps
+                assert y.shape[1] == truncated_bptt_steps

-            y_tensor = torch.tensor(y_list, dtype=x_tensor.dtype)
-            assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed"
-
-            pred = self(x_tensor.view(batch_size, truncated_bptt_steps))
-            loss = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps))
+            pred, _ = self(x)
+            loss = torch.nn.functional.mse_loss(pred, y)

             self.log("a", loss, on_epoch=True)

             return {"loss": loss, "hiddens": self.test_hidden}

-        def on_train_epoch_start(self) -> None:
+        def on_train_batch_start(self, *args, **kwargs) -> None:
             self.test_hidden = None

         def train_dataloader(self):
-            return torch.utils.data.DataLoader(
-                dataset=MockSeq2SeqDataset(), batch_size=batch_size, shuffle=False, sampler=None
-            )
+            return torch.utils.data.DataLoader(dataset=MockSeq2SeqDataset(), batch_size=batch_size)

     model = TestModel()
     model.training_epoch_end = None

     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=10,
         limit_val_batches=0,
         truncated_bptt_steps=truncated_bptt_steps,
         max_epochs=2,
@@ -278,6 +278,8 @@ def train_dataloader(self):
     )
     trainer.fit(model)

+    assert trainer.fit_loop.batch_idx == N // batch_size
+    assert trainer.fit_loop.split_idx == T // truncated_bptt_steps
     assert set(trainer.logged_metrics) == {"a_step", "a_epoch", "epoch"}
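A quick sanity check on the two new assertions: the dataset yields N = 32 samples in batches of 10, so with the DataLoader defaults the final batch index is 32 // 10 = 3, and each 15-step sequence is processed in splits of 2 timesteps, so the final split index is 15 // 2 = 7 (the eighth, leftover split of length 1).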

