Fix the test case; disable assert error suppression
ceshine committed Apr 9, 2021
1 parent e35192d commit c69c5a1
Showing 2 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -633,11 +633,11 @@ def run_train(self) -> None:
             if not self.interrupted:
                 self.state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()
-        except (RuntimeError, AssertionError):
-            # if an exception is raised, the finally block is executed and can hide the actual exception
-            # that was initially raised if `on_train_end` also raises an exception. we want to avoid that
-            # for assertions and other runtime errors so we aren't misled while debugging
-            print_exc()
+        # except (RuntimeError, AssertionError):
+        #     # if an exception is raised, the finally block is executed and can hide the actual exception
+        #     # that was initially raised if `on_train_end` also raises an exception. we want to avoid that
+        #     # for assertions and other runtime errors so we aren't misled while debugging
+        #     print_exc()
         finally:
             # hook
             self.train_loop.on_train_end()
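Why the handler is commented out rather than fixed: the removed except clause called print_exc() but never re-raised, so a failing assert inside run_train was printed and then silently swallowed, which is the "assert error suppression" the commit message refers to. Below is a minimal standalone sketch of that interaction (not the Trainer code itself; the names and the SUPPRESS toggle are illustrative):

    from traceback import print_exc

    SUPPRESS = False  # True reproduces the suppression this commit disables

    def run_train():
        try:
            raise AssertionError("raised inside the training loop")
        except (RuntimeError, AssertionError):
            if SUPPRESS:
                # old behaviour: print the traceback but swallow the error,
                # so callers and test runners never see the failure
                print_exc()
            else:
                raise
        finally:
            # stands in for train_loop.on_train_end(); note that if this
            # hook itself raised, it would hide the original AssertionError
            print("on_train_end hook runs regardless")

    run_train()

With SUPPRESS = False the AssertionError propagates to the caller, which is what the tightened test below relies on.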
8 changes: 4 additions & 4 deletions tests/trainer/test_trainer.py
@@ -918,13 +918,13 @@ def test_gradient_clipping_by_value(tmpdir):
 
     model = BoringModel()
 
-    grad_clip_val = 0.0001
+    grad_clip_val = 1e-5
     trainer = Trainer(
         max_steps=10,
         max_epochs=1,
         gradient_clip_val=grad_clip_val,
         gradient_clip_algorithm='value',
-        default_root_dir=tmpdir,
+        default_root_dir=tmpdir
     )
 
     trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward
@@ -938,8 +938,8 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens):
         parameters = model.parameters()
         grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
         grad_max = torch.max(torch.stack(grad_max_list))
-        assert round(grad_max.item(), 6) <= grad_clip_val, \
-            f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ."
+        assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \
+            f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
 
         return ret_val

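For context on the assertion change: clipping by value clamps every gradient component to [-grad_clip_val, grad_clip_val], so once the raw gradients exceed the threshold, the post-clipping maximum should equal grad_clip_val exactly, not merely fall below it. The old <= check also passed when the gradients never reached the threshold in the first place, so it could not tell correct clipping from a vacuous run; lowering the threshold to 1e-5 makes it near-certain the raw gradients exceed it, and the equality check then verifies clipping actually engaged. A small self-contained sketch of that invariant (the gradient values are made up for illustration):

    import torch

    grad_clip_val = 1e-5

    # a toy parameter whose gradient components all exceed the threshold
    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.tensor([0.3, -2.0, 0.5, 1.2])

    # clamp each component to [-grad_clip_val, grad_clip_val]
    torch.nn.utils.clip_grad_value_([p], grad_clip_val)

    grad_max = torch.max(p.grad.detach().abs())
    # every component was clipped, so the max equals the threshold exactly
    assert abs(round(grad_max.item(), 6) - grad_clip_val) < 1e-6, \
        f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val}"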
