Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Rename sanity_checks to confidence_checks (#5201)
Browse files Browse the repository at this point in the history
* renaming sanity_checks to confidence_checks

* update changelog

* docs fix

* clean up
  • Loading branch information
AkshitaB authored May 14, 2021
1 parent db8ff67 commit cccb35d
Show file tree
Hide file tree
Showing 25 changed files with 96 additions and 53 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use `dist_reduce_sum` in distributed metrics.
- Allow Google Cloud Storage paths in `cached_path` ("gs://...").
- Print the first batch to the console by default.
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).

### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.
- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Added `allennlp.nn.util.load_state_dict` helper function.
- Added a way to avoid downloading and loading pretrained weights in modules that wrap transformers
Expand Down
6 changes: 3 additions & 3 deletions allennlp/commands/checklist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
The `checklist` subcommand allows you to sanity check your
model's predictions using a trained model and its
The `checklist` subcommand allows you to conduct behavioural
testing for your model's predictions using a trained model and its
[`Predictor`](../predictors/predictor.md#predictor) wrapper.
"""

Expand All @@ -15,7 +15,7 @@
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite


@Subcommand.register("checklist")
Expand Down
2 changes: 1 addition & 1 deletion allennlp/common/testing/checklist_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from checklist.test_suite import TestSuite
from checklist.test_types import MFT as MinimumFunctionalityTest
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite


@TaskSuite.register("fake-task-suite")
Expand Down
2 changes: 1 addition & 1 deletion allennlp/common/testing/model_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from allennlp.data.batch import Batch
from allennlp.models import load_archive, Model
from allennlp.training import GradientDescentTrainer
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification


class ModelTestCase(AllenNlpTestCase):
Expand Down
2 changes: 2 additions & 0 deletions allennlp/confidence_checks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from allennlp.confidence_checks.verification_base import VerificationBase
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from torch import nn as nn
from typing import Tuple, List, Callable
from allennlp.sanity_checks.verification_base import VerificationBase
from allennlp.confidence_checks.verification_base import VerificationBase
import logging

logger = logging.getLogger(__name__)
Expand Down
10 changes: 10 additions & 0 deletions allennlp/confidence_checks/task_checklists/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import (
SentimentAnalysisSuite,
)
from allennlp.confidence_checks.task_checklists.question_answering_suite import (
QuestionAnsweringSuite,
)
from allennlp.confidence_checks.task_checklists.textual_entailment_suite import (
TextualEntailmentSuite,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from checklist.test_suite import TestSuite
from checklist.test_types import MFT
from checklist.perturb import Perturb
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists import utils


def _crossproduct(template: CheckListTemplate):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from checklist.test_types import MFT, INV, DIR, Expect
from checklist.editor import Editor
from checklist.perturb import Perturb
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists import utils
from allennlp.data.instance import Instance


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from allennlp.common.registrable import Registrable
from allennlp.common.file_utils import cached_path
from allennlp.predictors.predictor import Predictor
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists import utils

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from checklist.test_suite import TestSuite
from checklist.test_types import MFT
from checklist.perturb import Perturb
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists import utils


def _wrap_apply_to_each(perturb_fn: Callable, both: bool = False, *args, **kwargs):
Expand Down
11 changes: 9 additions & 2 deletions allennlp/sanity_checks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from allennlp.sanity_checks.verification_base import VerificationBase
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.verification_base import VerificationBase
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification

import warnings

warnings.warn(
"Module 'sanity_checks' is deprecated, please use 'confidence_checks' instead.",
DeprecationWarning,
)
8 changes: 4 additions & 4 deletions allennlp/sanity_checks/task_checklists/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.sanity_checks.task_checklists.sentiment_analysis_suite import (
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import (
SentimentAnalysisSuite,
)
from allennlp.sanity_checks.task_checklists.question_answering_suite import (
from allennlp.confidence_checks.task_checklists.question_answering_suite import (
QuestionAnsweringSuite,
)
from allennlp.sanity_checks.task_checklists.textual_entailment_suite import (
from allennlp.confidence_checks.task_checklists.textual_entailment_suite import (
TextualEntailmentSuite,
)
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from allennlp.training.callbacks.callback import TrainerCallback
from allennlp.training.callbacks.console_logger import ConsoleLoggerCallback
from allennlp.training.callbacks.sanity_checks import SanityChecksCallback
from allennlp.training.callbacks.confidence_checks import ConfidenceChecksCallback
from allennlp.training.callbacks.tensorboard import TensorBoardCallback
from allennlp.training.callbacks.track_epoch import TrackEpochCallback
from allennlp.training.callbacks.wandb import WandBCallback
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@

from allennlp.training.callbacks.callback import TrainerCallback
from allennlp.data import TensorDict
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification


if TYPE_CHECKING:
from allennlp.training.trainer import GradientDescentTrainer


# `sanity_checks` is deprecated and will be removed.
@TrainerCallback.register("sanity_checks")
class SanityChecksCallback(TrainerCallback):
@TrainerCallback.register("confidence_checks")
class ConfidenceChecksCallback(TrainerCallback):
"""
Performs model sanity checks.
Performs model confidence checks.
Checks performed:
* `NormalizationBiasVerification` for detecting invalid combinations of
bias and normalization layers.
See `allennlp.sanity_checks.normalization_bias_verification` for more details.
See `allennlp.confidence_checks.normalization_bias_verification` for more details.
Note: Any new sanity checks should also be added to this callback.
Note: Any new confidence checks should also be added to this callback.
"""

def on_start(
Expand Down Expand Up @@ -54,18 +56,18 @@ def on_batch(
self._verification.destroy_hooks()
detected_pairs = self._verification.collect_detections()
if len(detected_pairs) > 0:
raise SanityCheckError(
raise ConfidenceCheckError(
"The NormalizationBiasVerification check failed. See logs for more details."
)


class SanityCheckError(Exception):
class ConfidenceCheckError(Exception):
"""
The error type raised when a sanity check fails.
The error type raised when a confidence check fails.
"""

def __init__(self, message) -> None:
super().__init__(
message
+ "\nYou can disable these checks by setting the trainer parameter `run_sanity_checks` to `False`."
+ "\nYou can disable these checks by setting the trainer parameter `run_confidence_checks` to `False`."
)
37 changes: 28 additions & 9 deletions allennlp/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import time
import traceback
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type

Expand All @@ -23,7 +24,11 @@
from allennlp.data import DataLoader, TensorDict
from allennlp.models.model import Model
from allennlp.training import util as training_util
from allennlp.training.callbacks import TrainerCallback, SanityChecksCallback, ConsoleLoggerCallback
from allennlp.training.callbacks import (
TrainerCallback,
ConfidenceChecksCallback,
ConsoleLoggerCallback,
)
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
Expand Down Expand Up @@ -263,10 +268,13 @@ class GradientDescentTrainer(Trainer):
addition to any other callbacks listed in the `callbacks` parameter.
When set to `False`, `DEFAULT_CALLBACKS` are not used.
run_confidence_checks : `bool`, optional (default = `True`)
Determines whether model confidence checks, such as
[`NormalizationBiasVerification`](../../confidence_checks/normalization_bias_verification/),
are run.
run_sanity_checks : `bool`, optional (default = `True`)
Determines whether model sanity checks, such as
[`NormalizationBiasVerification`](../../sanity_checks/normalization_bias_verification/),
are ran.
This parameter is deprecated. Please use `run_confidence_checks` instead.
"""

Expand Down Expand Up @@ -294,7 +302,8 @@ def __init__(
num_gradient_accumulation_steps: int = 1,
use_amp: bool = False,
enable_default_callbacks: bool = True,
run_sanity_checks: bool = True,
run_confidence_checks: bool = True,
**kwargs,
) -> None:
super().__init__(
serialization_dir=serialization_dir,
Expand All @@ -304,6 +313,13 @@ def __init__(
world_size=world_size,
)

if "run_sanity_checks" in kwargs:
warnings.warn(
"'run_sanity_checks' is deprecated, please use 'run_confidence_checks' instead.",
DeprecationWarning,
)
run_confidence_checks = kwargs["run_sanity_checks"]

# I am not calling move_to_gpu here, because if the model is
# not already on the GPU then the optimizer is going to be wrong.
self.model = model
Expand Down Expand Up @@ -345,8 +361,9 @@ def __init__(

self._callbacks = callbacks or []
default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else []
if run_sanity_checks:
default_callbacks.append(SanityChecksCallback)

if run_confidence_checks:
default_callbacks.append(ConfidenceChecksCallback)
for callback_cls in default_callbacks:
for callback in self._callbacks:
if callback.__class__ == callback_cls:
Expand Down Expand Up @@ -1014,7 +1031,8 @@ def from_partial_objects(
checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer),
callbacks: List[Lazy[TrainerCallback]] = None,
enable_default_callbacks: bool = True,
run_sanity_checks: bool = True,
run_confidence_checks: bool = True,
**kwargs,
) -> "Trainer":
"""
This method exists so that we can have a documented method to construct this class using
Expand Down Expand Up @@ -1106,7 +1124,8 @@ def from_partial_objects(
num_gradient_accumulation_steps=num_gradient_accumulation_steps,
use_amp=use_amp,
enable_default_callbacks=enable_default_callbacks,
run_sanity_checks=run_sanity_checks,
run_confidence_checks=run_confidence_checks,
**kwargs,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.testing.sanity_check_test import (
from allennlp.common.testing.confidence_check_test import (
FakeModelForTestingNormalizationBiasVerification,
)
from allennlp.sanity_checks.normalization_bias_verification import NormalizationBiasVerification
from allennlp.confidence_checks.normalization_bias_verification import NormalizationBiasVerification


class TestNormalizationBiasVerification(AllenNlpTestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from allennlp.sanity_checks.task_checklists.sentiment_analysis_suite import SentimentAnalysisSuite
from allennlp.confidence_checks.task_checklists.sentiment_analysis_suite import (
SentimentAnalysisSuite,
)
from allennlp.common.testing import AllenNlpTestCase, requires_gpu
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from allennlp.sanity_checks.task_checklists.task_suite import TaskSuite
from allennlp.confidence_checks.task_checklists.task_suite import TaskSuite
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.checks import ConfigurationError
from allennlp.models.archival import load_archive
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from allennlp.sanity_checks.task_checklists import utils
from allennlp.confidence_checks.task_checklists import utils
from allennlp.common.testing import AllenNlpTestCase


Expand Down
18 changes: 9 additions & 9 deletions tests/training/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
TrainerCallback,
TrackEpochCallback,
TensorBoardCallback,
SanityChecksCallback,
ConfidenceChecksCallback,
ConsoleLoggerCallback,
)
from allennlp.training.callbacks.sanity_checks import SanityCheckError
from allennlp.training.callbacks.confidence_checks import ConfidenceCheckError
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
from allennlp.training.learning_rate_schedulers import ExponentialLearningRateScheduler
from allennlp.training.momentum_schedulers import MomentumScheduler
Expand All @@ -49,7 +49,7 @@
TensorField,
)
from allennlp.training.optimizers import Optimizer
from allennlp.common.testing.sanity_check_test import (
from allennlp.common.testing.confidence_check_test import (
FakeModelForTestingNormalizationBiasVerification,
)

Expand Down Expand Up @@ -814,7 +814,7 @@ def test_trainer_can_log_learning_rates_tensorboard(self):

trainer.train()

def test_sanity_check_callback(self):
def test_confidence_check_callback(self):
model_with_bias = FakeModelForTestingNormalizationBiasVerification(use_bias=True)
inst = Instance({"x": TensorField(torch.rand(3, 1, 4))})
data_loader = SimpleDataLoader([inst, inst], 2)
Expand All @@ -824,12 +824,12 @@ def test_sanity_check_callback(self):
data_loader,
num_epochs=1,
serialization_dir=self.TEST_DIR,
callbacks=[SanityChecksCallback(serialization_dir=self.TEST_DIR)],
callbacks=[ConfidenceChecksCallback(serialization_dir=self.TEST_DIR)],
)
with pytest.raises(SanityCheckError):
with pytest.raises(ConfidenceCheckError):
trainer.train()

def test_sanity_check_default(self):
def test_confidence_check_default(self):
model_with_bias = FakeModelForTestingNormalizationBiasVerification(use_bias=True)
inst = Instance({"x": TensorField(torch.rand(3, 1, 4))})
data_loader = SimpleDataLoader([inst, inst], 2)
Expand All @@ -839,15 +839,15 @@ def test_sanity_check_default(self):
data_loader=data_loader,
num_epochs=1,
)
with pytest.raises(SanityCheckError):
with pytest.raises(ConfidenceCheckError):
trainer.train()

trainer = GradientDescentTrainer.from_partial_objects(
model_with_bias,
serialization_dir=self.TEST_DIR,
data_loader=data_loader,
num_epochs=1,
run_sanity_checks=False,
run_confidence_checks=False,
)

# Check is not run, so no failure.
Expand Down

0 comments on commit cccb35d

Please sign in to comment.