From 9d33c5d3c8b3ef668d187b0d5b641de38972bdec Mon Sep 17 00:00:00 2001
From: Frédéric Branchaud-Charron
Date: Fri, 8 Sep 2023 17:28:25 -0400
Subject: [PATCH] Add warnings when using replicate_in_memory=True and MCCachingModule (#273)

---
 baal/modelwrapper.py            |  3 +++
 baal/utils/warnings.py          | 19 +++++++++++++++++++
 notebooks/mccaching_layer.ipynb |  4 +++-
 tests/bayesian/test_caching.py  | 13 +++++++++++++
 4 files changed, 38 insertions(+), 1 deletion(-)
 create mode 100644 baal/utils/warnings.py

diff --git a/baal/modelwrapper.py b/baal/modelwrapper.py
index a2acb5ce..d1b36845 100644
--- a/baal/modelwrapper.py
+++ b/baal/modelwrapper.py
@@ -18,6 +18,7 @@
 from baal.utils.cuda_utils import to_cuda
 from baal.utils.iterutils import map_on_tensor
 from baal.utils.metrics import Loss
+from baal.utils.warnings import raise_warnings_cache_replicated
 
 log = structlog.get_logger("ModelWrapper")
 
@@ -49,6 +50,8 @@ def __init__(self, model, criterion, replicate_in_memory=True):
         self.replicate_in_memory = replicate_in_memory
         self._active_dataset_size = -1
 
+        raise_warnings_cache_replicated(self.model, replicate_in_memory=replicate_in_memory)
+
     def train_on_dataset(
         self,
         dataset,
diff --git a/baal/utils/warnings.py b/baal/utils/warnings.py
new file mode 100644
index 00000000..2f767746
--- /dev/null
+++ b/baal/utils/warnings.py
@@ -0,0 +1,19 @@
+import warnings
+
+from torch import nn
+
+from baal.bayesian.caching_utils import LRUCacheModule
+
+WARNING_CACHE_REPLICATED = """
+To use MCCachingModule at maximum efficiency, we recommend using
+`replicate_in_memory=False`, but it is currently set to `True`.
+"""
+
+
+def raise_warnings_cache_replicated(module, replicate_in_memory):
+    if (
+        isinstance(module, nn.Module)
+        and replicate_in_memory
+        and any(isinstance(m, LRUCacheModule) for m in module.modules())
+    ):
+        warnings.warn(WARNING_CACHE_REPLICATED, UserWarning)
diff --git a/notebooks/mccaching_layer.ipynb b/notebooks/mccaching_layer.ipynb
index a7e3e568..ff9708c4 100644
--- a/notebooks/mccaching_layer.ipynb
+++ b/notebooks/mccaching_layer.ipynb
@@ -83,7 +83,9 @@
    "source": [
     "## Introducing MCCachingModule!\n",
     "\n",
-    "By simply wrapping the module with `MCCachingModule` we run the same inference 70% faster!"
+    "By simply wrapping the module with `MCCachingModule`, we run the same inference 70% faster!\n",
+    "\n",
+    "**NOTE**: Always use `ModelWrapper(..., replicate_in_memory=False)` in combination with `MCCachingModule`."
   ],
   "metadata": {
    "collapsed": false
diff --git a/tests/bayesian/test_caching.py b/tests/bayesian/test_caching.py
index 9e267f89..9c74f8cb 100644
--- a/tests/bayesian/test_caching.py
+++ b/tests/bayesian/test_caching.py
@@ -1,7 +1,10 @@
+import warnings
+
 import pytest
 import torch
 from torch.nn import Sequential, Linear
 
+from baal import ModelWrapper
 from baal.bayesian.caching_utils import MCCachingModule
 
 
@@ -50,3 +53,13 @@ def test_caching(my_model):
     assert LinearMocked.call_count == 20
 
 
+def test_caching_warnings(my_model):
+    my_model = MCCachingModule(my_model)
+    with warnings.catch_warnings(record=True) as tape:
+        ModelWrapper(my_model, criterion=None, replicate_in_memory=True)
+    assert len(tape) == 1 and "MCCachingModule" in str(tape[0].message)
+
+    with warnings.catch_warnings(record=True) as tape:
+        ModelWrapper(my_model, criterion=None, replicate_in_memory=False)
+    assert len(tape) == 0
+
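
For context, here is a minimal sketch of the usage this warning steers users toward: wrap the model in `MCCachingModule` and construct `ModelWrapper` with `replicate_in_memory=False`. The toy model, layer sizes, and criterion below are hypothetical; `ModelWrapper`, `MCCachingModule`, and the warning behaviour come from the patch itself.

```python
import warnings

from torch.nn import CrossEntropyLoss, Linear, ReLU, Sequential

from baal import ModelWrapper
from baal.bayesian.caching_utils import MCCachingModule

# Hypothetical toy model; any nn.Module works here.
model = Sequential(Linear(32, 64), ReLU(), Linear(64, 10))

# Wrap the model so module outputs can be cached across MC iterations.
cached_model = MCCachingModule(model)

# Recommended pairing: replicate_in_memory=False, so no warning is raised
# and the cache actually pays off.
wrapper = ModelWrapper(cached_model, criterion=CrossEntropyLoss(), replicate_in_memory=False)

# The discouraged combination raises WARNING_CACHE_REPLICATED at construction:
with warnings.catch_warnings(record=True) as tape:
    ModelWrapper(cached_model, criterion=CrossEntropyLoss(), replicate_in_memory=True)
assert any("MCCachingModule" in str(w.message) for w in tape)
```

Raising the warning in `ModelWrapper.__init__` surfaces the misconfiguration at construction time, rather than letting the user discover the slowdown only after a long prediction run.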