#146 Fix issue where at most a single submodule was affected by Dropout (PR #147)

Merged 1 commit on Sep 10, 2021
2 changes: 1 addition & 1 deletion src/baal/bayesian/dropout.py
@@ -129,7 +129,7 @@ def _patch_dropout_layers(module: torch.nn.Module) -> bool:
             module.add_module(name, new_module)
 
         # recursively apply to child
-        changed = changed or _patch_dropout_layers(child)
+        changed |= _patch_dropout_layers(child)
     return changed


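The change is small but consequential: Python's `or` short-circuits, so once `changed` is `True` the recursive call on the right-hand side is never evaluated and `_patch_dropout_layers` stops descending into the remaining children. The in-place `|=` always evaluates the recursive call and then ORs its result into `changed`, so every submodule is visited. A minimal, self-contained sketch of the two behaviours (a simplified stand-in for illustration, not Baal's actual implementation; `torch.nn.Identity` is used only as a dummy replacement for Dropout):

import torch

def patch_buggy(module: torch.nn.Module) -> bool:
    """Old behaviour: `or` short-circuits, so once `changed` is True the
    recursion into later children never runs and their Dropouts survive."""
    changed = False
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Dropout):
            module.add_module(name, torch.nn.Identity())  # dummy replacement
            changed = True
        changed = changed or patch_buggy(child)  # skipped once changed is True
    return changed

def patch_fixed(module: torch.nn.Module) -> bool:
    """New behaviour: `|=` always evaluates the recursive call first."""
    changed = False
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Dropout):
            module.add_module(name, torch.nn.Identity())
            changed = True
        changed |= patch_fixed(child)  # recursion always happens
    return changed

def remaining_dropouts(module: torch.nn.Module) -> int:
    return sum(isinstance(m, torch.nn.Dropout) for m in module.modules())

def make_model() -> torch.nn.Module:
    # Two sibling submodules, each hiding a Dropout one level down.
    return torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Dropout(p=0.5), torch.nn.Linear(5, 5)),
        torch.nn.Sequential(torch.nn.Dropout(p=0.5), torch.nn.Linear(5, 2)),
    )

buggy, fixed = make_model(), make_model()
patch_buggy(buggy)
patch_fixed(fixed)
print(remaining_dropouts(buggy))  # 1 -> the second nested Dropout was missed
print(remaining_dropouts(fixed))  # 0 -> every Dropout was replaced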
51 changes: 29 additions & 22 deletions tests/bayesian/dropout_test.py
@@ -2,9 +2,25 @@
 
 import pytest
 import torch
 
 import baal.bayesian.dropout
 
+
+@pytest.fixture
+def a_model_with_dropout():
+    return torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Sequential(
+            torch.nn.Dropout(p=0.5),
+            torch.nn.Linear(5, 5),
+            torch.nn.ReLU(), ),
+        torch.nn.Sequential(
+            torch.nn.Dropout(p=0.5),
+            torch.nn.Linear(5, 2),
+        ))
+
+
 def test_1d_eval_remains_stochastic():
     dummy_input = torch.randn(8, 10)
     test_module = torch.nn.Sequential(
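The fixture's shape is the important part: the two Dropout layers live inside sibling nested Sequential blocks rather than at the top level, which is exactly the structure the old short-circuiting recursion failed to patch completely. As a note on mechanics, pytest injects this fixture into any test that declares a parameter with the same name, as the updated tests below do. A hypothetical example (not part of this diff) of using the fixture directly:

def test_fixture_has_two_nested_dropouts(a_model_with_dropout):
    # pytest matches the argument name against the fixture and injects it.
    dropouts = [m for m in a_model_with_dropout.modules()
                if isinstance(m, torch.nn.Dropout)]
    # One Dropout in each nested Sequential; the old recursion would have
    # replaced only the first of them.
    assert len(dropouts) == 2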
@@ -42,19 +58,11 @@ def test_2d_eval_remains_stochastic():
 
 
 @pytest.mark.parametrize("inplace", (True, False))
-def test_patch_module_replaces_all_dropout_layers(inplace):
-
-    test_module = torch.nn.Sequential(
-        torch.nn.Linear(10, 5),
-        torch.nn.ReLU(),
-        torch.nn.Dropout(p=0.5),
-        torch.nn.Linear(5, 2),
-    )
-
-    mc_test_module = baal.bayesian.dropout.patch_module(test_module, inplace=inplace)
+def test_patch_module_replaces_all_dropout_layers(inplace, a_model_with_dropout):
+    mc_test_module = baal.bayesian.dropout.patch_module(a_model_with_dropout, inplace=inplace)
 
     # objects should be the same if inplace is True and not otherwise:
-    assert (mc_test_module is test_module) == inplace
+    assert (mc_test_module is a_model_with_dropout) == inplace
     assert not any(
         isinstance(module, torch.nn.Dropout) for module in mc_test_module.modules()
     )
@@ -63,9 +71,9 @@ def test_patch_module_replaces_all_dropout_layers(inplace):
         for module in mc_test_module.modules()
     )
 
+
 @pytest.mark.parametrize("inplace", (True, False))
 def test_patch_module_raise_warnings(inplace):
-
     test_module = torch.nn.Sequential(
         torch.nn.Linear(10, 5),
         torch.nn.ReLU(),
@@ -78,26 +86,25 @@ def test_patch_module_raise_warnings(inplace):
         assert issubclass(w[-1].category, UserWarning)
         assert "No layer was modified by patch_module" in str(w[-1].message)
 
-def test_module_class_replaces_dropout_layers():
+
+def test_module_class_replaces_dropout_layers(a_model_with_dropout):
     dummy_input = torch.randn(8, 10)
-    test_module = torch.nn.Sequential(
-        torch.nn.Linear(10, 5),
-        torch.nn.ReLU(),
-        torch.nn.Dropout(p=0.5),
-        torch.nn.Linear(5, 2),
-    )
-    test_mc_module = baal.bayesian.dropout.MCDropoutModule(test_module)
+    test_mc_module = baal.bayesian.dropout.MCDropoutModule(a_model_with_dropout)
 
     assert not any(
-        isinstance(module, torch.nn.Dropout) for module in test_module.modules()
+        isinstance(module, torch.nn.Dropout) for module in a_model_with_dropout.modules()
     )
     assert any(
         isinstance(module, baal.bayesian.dropout.Dropout)
-        for module in test_module.modules()
+        for module in a_model_with_dropout.modules()
     )
     torch.manual_seed(2019)
     with torch.no_grad():
         assert not all(
             (test_mc_module(dummy_input) == test_mc_module(dummy_input)).all()
             for _ in range(10)
         )
 
 
 if __name__ == '__main__':
     pytest.main()
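Taken together, the tests pin down the intended behaviour: after patching, no plain torch.nn.Dropout remains anywhere in the module tree, and repeated forward passes stay stochastic even in eval mode. A sketch of typical downstream usage, based only on the API exercised above (baal.bayesian.dropout.patch_module with its inplace flag); treat it as an illustration rather than official documentation:

import torch
import baal.bayesian.dropout

# Toy model mirroring the fixture: a Dropout layer nested inside a submodule.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Sequential(torch.nn.Dropout(p=0.5), torch.nn.Linear(5, 2)),
)

# Replace every torch.nn.Dropout with Baal's Dropout, which stays active
# in eval mode so predictions can be sampled.
mc_model = baal.bayesian.dropout.patch_module(model, inplace=False)
mc_model.eval()

x = torch.randn(8, 10)
with torch.no_grad():
    # Each forward pass draws a different dropout mask.
    samples = torch.stack([mc_model(x) for _ in range(20)])

mean, std = samples.mean(dim=0), samples.std(dim=0)  # simple MC aggregation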