diff --git a/src/baal/bayesian/dropout.py b/src/baal/bayesian/dropout.py
index 98154bb4..4652d088 100644
--- a/src/baal/bayesian/dropout.py
+++ b/src/baal/bayesian/dropout.py
@@ -129,7 +129,7 @@ def _patch_dropout_layers(module: torch.nn.Module) -> bool:
             module.add_module(name, new_module)
 
         # recursively apply to child
-        changed = changed or _patch_dropout_layers(child)
+        changed |= _patch_dropout_layers(child)
     return changed
 
 
diff --git a/tests/bayesian/dropout_test.py b/tests/bayesian/dropout_test.py
index b453ba82..ff59d3b9 100644
--- a/tests/bayesian/dropout_test.py
+++ b/tests/bayesian/dropout_test.py
@@ -2,9 +2,25 @@
 
 import pytest
 import torch
+
 import baal.bayesian.dropout
 
 
+@pytest.fixture
+def a_model_with_dropout():
+    return torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Sequential(
+            torch.nn.Dropout(p=0.5),
+            torch.nn.Linear(5, 5),
+            torch.nn.ReLU(), ),
+        torch.nn.Sequential(
+            torch.nn.Dropout(p=0.5),
+            torch.nn.Linear(5, 2),
+        ))
+
+
 def test_1d_eval_remains_stochastic():
     dummy_input = torch.randn(8, 10)
     test_module = torch.nn.Sequential(
@@ -42,19 +58,11 @@ def test_2d_eval_remains_stochastic():
 
 
 @pytest.mark.parametrize("inplace", (True, False))
-def test_patch_module_replaces_all_dropout_layers(inplace):
-
-    test_module = torch.nn.Sequential(
-        torch.nn.Linear(10, 5),
-        torch.nn.ReLU(),
-        torch.nn.Dropout(p=0.5),
-        torch.nn.Linear(5, 2),
-    )
-
-    mc_test_module = baal.bayesian.dropout.patch_module(test_module, inplace=inplace)
+def test_patch_module_replaces_all_dropout_layers(inplace, a_model_with_dropout):
+    mc_test_module = baal.bayesian.dropout.patch_module(a_model_with_dropout, inplace=inplace)
 
     # objects should be the same if inplace is True and not otherwise:
-    assert (mc_test_module is test_module) == inplace
+    assert (mc_test_module is a_model_with_dropout) == inplace
     assert not any(
         isinstance(module, torch.nn.Dropout) for module in mc_test_module.modules()
     )
@@ -63,9 +71,9 @@ def test_patch_module_replaces_all_dropout_layers(inplace):
         for module in mc_test_module.modules()
     )
 
+
 @pytest.mark.parametrize("inplace", (True, False))
 def test_patch_module_raise_warnings(inplace):
-
     test_module = torch.nn.Sequential(
         torch.nn.Linear(10, 5),
         torch.nn.ReLU(),
@@ -78,22 +86,17 @@ def test_patch_module_raise_warnings(inplace):
         assert issubclass(w[-1].category, UserWarning)
         assert "No layer was modified by patch_module" in str(w[-1].message)
 
-def test_module_class_replaces_dropout_layers():
+
+def test_module_class_replaces_dropout_layers(a_model_with_dropout):
     dummy_input = torch.randn(8, 10)
-    test_module = torch.nn.Sequential(
-        torch.nn.Linear(10, 5),
-        torch.nn.ReLU(),
-        torch.nn.Dropout(p=0.5),
-        torch.nn.Linear(5, 2),
-    )
-    test_mc_module = baal.bayesian.dropout.MCDropoutModule(test_module)
+    test_mc_module = baal.bayesian.dropout.MCDropoutModule(a_model_with_dropout)
 
     assert not any(
-        isinstance(module, torch.nn.Dropout) for module in test_module.modules()
+        isinstance(module, torch.nn.Dropout) for module in a_model_with_dropout.modules()
     )
     assert any(
         isinstance(module, baal.bayesian.dropout.Dropout)
-        for module in test_module.modules()
+        for module in a_model_with_dropout.modules()
     )
     torch.manual_seed(2019)
     with torch.no_grad():
@@ -101,3 +104,7 @@ def test_module_class_replaces_dropout_layers():
         (test_mc_module(dummy_input) == test_mc_module(dummy_input)).all()
         for _ in range(10)
     )
+
+
+if __name__ == '__main__':
+    pytest.main()
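
For context on the src/baal/bayesian/dropout.py hunk: `changed = changed or _patch_dropout_layers(child)` short-circuits, so once `changed` is already True the recursive call is skipped and Dropout layers nested inside child modules (such as the inner Sequential blocks in the new fixture) are never replaced; `changed |= _patch_dropout_layers(child)` always evaluates the recursive call. A minimal standalone sketch of the difference, using hypothetical names that are not part of the patch:

    # `or` short-circuits: the right-hand side never runs once the left is True.
    calls = []

    def fake_patch(child):
        calls.append(child)
        return True

    changed = True
    changed = changed or fake_patch("child")  # fake_patch is never called here
    assert calls == []

    changed |= fake_patch("child")            # |= evaluates the call unconditionally
    assert calls == ["child"]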