diff --git a/CHANGELOG.md b/CHANGELOG.md
index b8e84f5410e..69463b7ce59 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added an `Evaluator` class to make comparing source, target, and predictions easier.
 - Added a way to resize the vocabulary in the T5 module
 - Added an argument `reinit_modules` to `cached_transformers.get()` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings.
 - Added attribute `_should_validate_this_epoch` to `GradientDescentTrainer` that controls whether validation is run at the end of each epoch.
diff --git a/allennlp/commands/evaluate.py b/allennlp/commands/evaluate.py
index 8c562cc3c11..2b9403a417f 100644
--- a/allennlp/commands/evaluate.py
+++ b/allennlp/commands/evaluate.py
@@ -7,17 +7,17 @@
 import argparse
 import json
 import logging
-from typing import Any, Dict
-
+from pathlib import Path
+from os import PathLike
+from typing import Union, Dict, Any, Optional
 from copy import deepcopy
-
 from allennlp.commands.subcommand import Subcommand
 from allennlp.common import logging as common_logging
 from allennlp.common.util import prepare_environment
 from allennlp.data import DataLoader
 from allennlp.models.archival import load_archive
-from allennlp.training.util import evaluate
+from allennlp.evaluation import Evaluator
 
 logger = logging.getLogger(__name__)
 
@@ -35,28 +35,23 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
         subparser.add_argument(
             "input_file",
             type=str,
-            help=(
-                "path to the file containing the evaluation data"
-                ' (for mutiple files, put "," between filenames e.g., input1.txt,input2.txt)'
-            ),
+            help="path to the file containing the evaluation data (for multiple "
+            'files, put "," between filenames e.g., input1.txt,input2.txt)',
         )
 
         subparser.add_argument(
             "--output-file",
             type=str,
-            help=(
-                "optional path to write the metrics to as JSON"
-                ' (for mutiple files, put "," between filenames e.g., output1.txt,output2.txt)'
-            ),
+            help="optional path to write the metrics to as JSON (for multiple "
+            'files, put "," between filenames e.g., output1.txt,output2.txt)',
         )
 
         subparser.add_argument(
             "--predictions-output-file",
             type=str,
-            help=(
-                "optional path to write the predictions to as JSON lines"
-                ' (for mutiple files, put "," between filenames e.g., output1.jsonl,output2.jsonl)'
-            ),
+            help="optional path to write the predictions to as JSON lines "
+            '(for multiple files, put "," between filenames e.g., '
+            "output1.jsonl,output2.jsonl)",
         )
 
         subparser.add_argument(
@@ -116,13 +111,126 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
             help="outputs tqdm status on separate lines and slows tqdm refresh rate",
         )
 
+        subparser.add_argument(
+            "--auto-names",
+            default="NONE",
+            help="Automatically create output names for each evaluation file. "
+            "`NONE` will not automatically generate a file name for either "
+            "the metrics or the predictions; in this case you will need to "
+            "pass both `--output-file` and `--predictions-output-file`. "
+            "`METRICS` will only automatically create a file name for the "
+            "metrics file. `PREDS` will only automatically create a file "
+            "name for the predictions outputs. `ALL` will create a "
+            "filename for both the metrics and the predictions.",
+            choices=["NONE", "METRICS", "PREDS", "ALL"],
+        )
+
         subparser.set_defaults(func=evaluate_from_args)
 
         return subparser
 
 
 def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
-    common_logging.FILE_FRIENDLY_LOGGING = args.file_friendly_logging
+    return evaluate_from_archive(
+        archive_file=args.archive_file,
+        input_file=args.input_file,
+        metrics_output_file=args.output_file,
+        predictions_output_file=args.predictions_output_file,
+        batch_size=args.batch_size,
+        cmd_overrides=args.overrides,
+        cuda_device=args.cuda_device,
+        embedding_sources_mapping=args.embedding_sources_mapping,
+        extend_vocab=args.extend_vocab,
+        weights_file=args.weights_file,
+        file_friendly_logging=args.file_friendly_logging,
+        batch_weight_key=args.batch_weight_key,
+        auto_names=args.auto_names,
+    )
+
+
+def evaluate_from_archive(
+    archive_file: Union[str, PathLike],
+    input_file: str,
+    metrics_output_file: Optional[str] = None,
+    predictions_output_file: Optional[str] = None,
+    batch_size: Optional[int] = None,
+    cmd_overrides: Union[str, Dict[str, Any]] = "",
+    cuda_device: int = -1,
+    embedding_sources_mapping: str = None,
+    extend_vocab: bool = False,
+    weights_file: str = None,
+    file_friendly_logging: bool = False,
+    batch_weight_key: str = None,
+    auto_names: str = "NONE",
+) -> Dict[str, Any]:
+    """
+    Loads a model from an archive and evaluates it on one or more input files,
+    optionally writing the metrics and per-batch predictions to disk.
+
+    # Parameters
+
+    archive_file: `Union[str, PathLike]`
+        Path to an archived trained model.
+
+    input_file: `str`
+        Path to the file containing the evaluation data (for multiple files,
+        put "," between filenames e.g., input1.txt,input2.txt).
+
+    metrics_output_file: `str`, optional (default=`None`)
+        Optional path to write the metrics to as JSON (for multiple files, put
+        "," between filenames e.g., output1.txt,output2.txt).
+
+    predictions_output_file: `str`, optional (default=`None`)
+        Optional path to write the predictions to as JSON lines (for multiple
+        files, put "," between filenames e.g., output1.jsonl,output2.jsonl).
+
+    batch_size: `int`, optional (default=`None`)
+        If non-empty, the batch size to use during evaluation.
+
+    cmd_overrides: `str`, optional (default=`""`)
+        A json(net) structure used to override the experiment configuration,
+        e.g., '{\"iterator.batch_size\": 16}'. Nested parameters can be
+        specified either with nested dictionaries or with dot syntax.
+
+    cuda_device: `int`, optional (default=`-1`)
+        Id of the GPU to use (if any).
+
+    embedding_sources_mapping: `str`, optional (default=`None`)
+        A JSON dict defining mapping from embedding module path to embedding
+        pretrained-file used during training. If not passed, and embedding
+        needs to be extended, we will try to use the original file paths used
+        during training. If they are not available we will use random vectors
+        for embedding extension.
+
+    extend_vocab: `bool`, optional (default=`False`)
+        If specified, we will use the instances in your new dataset to extend
+        your vocabulary. If pretrained-file was used to initialize embedding
+        layers, you may also need to pass --embedding-sources-mapping.
+
+    weights_file: `str`, optional (default=`None`)
+        A path that overrides which weights file to use.
+
+    file_friendly_logging : `bool`, optional (default=`False`)
+        If `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow
+        down tqdm's output to only once every 10 seconds.
+
+    batch_weight_key: `str`, optional (default=`None`)
+        If non-empty, name of metric used to weight the loss on a per-batch basis.
+
+    auto_names: `str`, optional (default=`"NONE"`)
+        Automatically create output names for each evaluation file. `NONE` will
+        not automatically generate a file name for either the metrics or the
+        predictions; in this case you will need to pass in both
+        `metrics_output_file` and `predictions_output_file`. `METRICS` will only
+        automatically create a file name for the metrics file. `PREDS` will only
+        automatically create a file name for the predictions outputs. `ALL`
+        will create a filename for both the metrics and the predictions.
+
+    # Returns
+
+    all_metrics: `Dict[str, Any]`
+        The metrics from every evaluation file passed.
+
+    """
+    common_logging.FILE_FRIENDLY_LOGGING = file_friendly_logging
 
     # Disable some of the more verbose logging statements
     logging.getLogger("allennlp.common.params").disabled = True
@@ -131,77 +239,124 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
 
     # Load from archive
     archive = load_archive(
-        args.archive_file,
-        weights_file=args.weights_file,
-        cuda_device=args.cuda_device,
-        overrides=args.overrides,
+        archive_file,
+        weights_file=weights_file,
+        cuda_device=cuda_device,
+        overrides=cmd_overrides,
     )
     config = deepcopy(archive.config)
     prepare_environment(config)
     model = archive.model
     model.eval()
 
+    # Load the evaluator from the config key `evaluation`
+    evaluator_params = config.pop("evaluation", {})
+    evaluator_params["cuda_device"] = cuda_device
+    evaluator = Evaluator.from_params(evaluator_params)
+
     # Load the evaluation data
     dataset_reader = archive.validation_dataset_reader
 
     # split files
-    evaluation_data_path_list = args.input_file.split(",")
-    if args.output_file is not None:
-        output_file_list = args.output_file.split(",")
-        assert len(output_file_list) == len(
-            evaluation_data_path_list
-        ), "The number of `output_file` paths must be equal to the number of datasets being evaluated."
-    if args.predictions_output_file is not None:
-        predictions_output_file_list = args.predictions_output_file.split(",")
-        assert len(predictions_output_file_list) == len(evaluation_data_path_list), (
-            "The number of `predictions_output_file` paths must be equal"
-            + "to the number of datasets being evaluated. "
-        )
+    evaluation_data_path_list = input_file.split(",")
+
+    # TODO(gabeorlanski): Is it safe to always default to .outputs and .preds?
+    # TODO(gabeorlanski): Add in way to save to specific output directory
+    if metrics_output_file is not None:
+        if auto_names == "METRICS" or auto_names == "ALL":
+            logger.warning(
+                f"Passed output files will be ignored, auto_names is set to {auto_names}"
+            )
+
+            # Keep the path of the parent otherwise it will write to the CWD
+            output_file_list = [
+                p.parent.joinpath(f"{p.stem}.outputs") for p in map(Path, evaluation_data_path_list)
+            ]
+        else:
+            output_file_list = metrics_output_file.split(",")  # type: ignore
+            assert len(output_file_list) == len(evaluation_data_path_list), (
+                "The number of `metrics_output_file` paths must be equal to the number "
+                "of datasets being evaluated."
+ ) + if predictions_output_file is not None: + if auto_names == "PREDS" or auto_names == "ALL": + logger.warning( + f"Passed predictions files will be ignored, auto_names is" f" set to {auto_names}" + ) + + # Keep the path of the parent otherwise it will write to the CWD + predictions_output_file_list = [ + p.parent.joinpath(f"{p.stem}.preds") for p in map(Path, evaluation_data_path_list) + ] + else: + predictions_output_file_list = predictions_output_file.split(",") # type: ignore + assert len(predictions_output_file_list) == len(evaluation_data_path_list), ( + "The number of `predictions_output_file` paths must be equal" + + "to the number of datasets being evaluated. " + ) # output file output_file_path = None predictions_output_file_path = None # embedding sources - if args.extend_vocab: + if extend_vocab: logger.info("Vocabulary is being extended with embedding sources.") embedding_sources = ( - json.loads(args.embedding_sources_mapping) if args.embedding_sources_mapping else {} + json.loads(embedding_sources_mapping) if embedding_sources_mapping else {} ) + all_metrics = {} for index in range(len(evaluation_data_path_list)): config = deepcopy(archive.config) evaluation_data_path = evaluation_data_path_list[index] - if args.output_file is not None: + + # Get the eval file name so we can save each metric by file name in the + # output dictionary. + eval_file_name = Path(evaluation_data_path).stem + + if metrics_output_file is not None: + # noinspection PyUnboundLocalVariable output_file_path = output_file_list[index] - if args.predictions_output_file is not None: + + if predictions_output_file is not None: + # noinspection PyUnboundLocalVariable predictions_output_file_path = predictions_output_file_list[index] logger.info("Reading evaluation data from %s", evaluation_data_path) data_loader_params = config.get("validation_data_loader", None) if data_loader_params is None: data_loader_params = config.get("data_loader") - if args.batch_size: - data_loader_params["batch_size"] = args.batch_size + if batch_size: + data_loader_params["batch_size"] = batch_size data_loader = DataLoader.from_params( params=data_loader_params, reader=dataset_reader, data_path=evaluation_data_path ) - if args.extend_vocab: + if extend_vocab: logger.info("Vocabulary is being extended with test instances.") model.vocab.extend_from_instances(instances=data_loader.iter_instances()) + # noinspection PyUnboundLocalVariable model.extend_embedder_vocab(embedding_sources) data_loader.index_with(model.vocab) - metrics = evaluate( + metrics = evaluator( model, data_loader, - args.cuda_device, - args.batch_weight_key, - output_file=output_file_path, + batch_weight_key, + metrics_output_file=output_file_path, predictions_output_file=predictions_output_file_path, ) + + # Add the metric prefixed by the file it came from. 
+ for name, value in metrics.items(): + if len(evaluation_data_path_list) > 1: + key = f"{eval_file_name}_" + else: + key = "" + all_metrics[f"{key}{name}"] = value + logger.info("Finished evaluating.") - return metrics + return all_metrics diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py index 8c96a18c806..46b91798aaf 100644 --- a/allennlp/commands/train.py +++ b/allennlp/commands/train.py @@ -238,6 +238,11 @@ def train_model( training_util.create_serialization_dir(params, serialization_dir, recover, force) params.to_file(os.path.join(serialization_dir, CONFIG_NAME)) + # Change Author: Gabe Orlanski + # Placeholder for the time being to make sure no errors are raised b/c of + # the evaluator. + params.pop("evaluation", None) + meta = Meta.new() meta.to_file(os.path.join(serialization_dir, META_NAME)) diff --git a/allennlp/common/testing/test_case.py b/allennlp/common/testing/test_case.py index 9f466e8ee6b..b3c4cbd00ab 100644 --- a/allennlp/common/testing/test_case.py +++ b/allennlp/common/testing/test_case.py @@ -32,7 +32,6 @@ def setup_method(self): logging.getLogger("allennlp.modules.token_embedders.embedding").setLevel(logging.INFO) logging.getLogger("urllib3.connectionpool").disabled = True log_pytorch_version_info() - self.TEST_DIR = pathlib.Path(TEST_DIR) os.makedirs(self.TEST_DIR, exist_ok=True) diff --git a/allennlp/evaluation/__init__.py b/allennlp/evaluation/__init__.py new file mode 100644 index 00000000000..feb03dd961e --- /dev/null +++ b/allennlp/evaluation/__init__.py @@ -0,0 +1,2 @@ +from allennlp.evaluation.evaluator import Evaluator, SimpleEvaluator +from allennlp.evaluation.serializers.serializers import Serializer diff --git a/allennlp/evaluation/evaluator.py b/allennlp/evaluation/evaluator.py new file mode 100644 index 00000000000..9c7ef7b6703 --- /dev/null +++ b/allennlp/evaluation/evaluator.py @@ -0,0 +1,241 @@ +""" +Evaluator class for evaluating a model with a given dataset +""" +from typing import Union, Dict, Any, Optional +from os import PathLike +from pathlib import Path +import torch +import logging + +from allennlp.common.checks import check_for_gpu +from allennlp.common.tqdm import Tqdm +from allennlp.common.util import dump_metrics, int_to_device +from allennlp.nn import util as nn_util +from allennlp.common import Registrable +from allennlp.models import Model +from allennlp.data import DataLoader +from allennlp.evaluation.serializers.serializers import Serializer, SimpleSerializer + +logger = logging.getLogger(__name__) + + +class Evaluator(Registrable): + """ + Evaluation Base class + + # Parameters + + batch_postprocessor: `Postprocessor`, optional (default=`SimplePostprocessor`) + The postprocessor to use for turning both the batches and the outputs + of the model into human readable data. + + cuda_device : `Union[int, torch.device]`, optional (default=`-1`) + The cuda device to use for this evaluation. The model is assumed to + already be using this device; this parameter is only used for moving + the input data to the correct device. + + postprocessor_fn_name: `str`, optional (default=`"make_output_human_readable"`) + Function name of the model's postprocessing function. 
+ """ + + default_implementation = "simple" + + def __init__( + self, + batch_serializer: Optional[Serializer] = None, + cuda_device: Union[int, torch.device] = -1, + postprocessor_fn_name: str = "make_output_human_readable", + ): + self.batch_serializer = batch_serializer or SimpleSerializer() + self.cuda_device = cuda_device + self.postprocessor_fn_name = postprocessor_fn_name + + def __call__( + self, + model: Model, + data_loader: DataLoader, + batch_weight_key: str = None, + metrics_output_file: Union[str, PathLike] = None, + predictions_output_file: Union[str, PathLike] = None, + ) -> Dict[str, Any]: + """ + Evaluate a single data source. + + # Parameters + + model : `Model` + The model to evaluate + data_loader : `DataLoader` + The `DataLoader` that will iterate over the evaluation data (data loaders already contain + their data). + batch_weight_key : `str`, optional (default=`None`) + If given, this is a key in the output dictionary for each batch that specifies how to weight + the loss for that batch. If this is not given, we use a weight of 1 for every batch. + metrics_output_file : `Union[str, PathLike]`, optional (default=`None`) + Optional path to write the final metrics to. + + predictions_output_file : `Union[str, PathLike]`, optional (default=`None`) + Optional path to write the predictions to. If passed the + postprocessor will be called and its output will be written as lines. + + + # Returns + + metrics: `Dict[str, Any]` + The metrics from evaluating the file. + """ + raise NotImplementedError("__call__") + + +@Evaluator.register("simple") +class SimpleEvaluator(Evaluator): + """ + Simple evaluator implementation. Uses the vanilla evaluation code. + + # Parameters + + batch_postprocessor: `Postprocessor`, optional (default=`SimplePostprocessor`) + The postprocessor to use for turning both the batches and the outputs + of the model into human readable data. + + cuda_device : `Union[int, torch.device]`, optional (default=`-1`) + The cuda device to use for this evaluation. The model is assumed to + already be using this device; this parameter is only used for moving + the input data to the correct device. + + postprocessor_fn_name: `str`, optional (default=`"make_output_human_readable"`) + Function name of the model's postprocessing function. + """ + + def __init__( + self, + batch_serializer: Optional[Serializer] = None, + cuda_device: Union[int, torch.device] = -1, + postprocessor_fn_name: str = "make_output_human_readable", + ): + super(SimpleEvaluator, self).__init__(batch_serializer, cuda_device, postprocessor_fn_name) + + def __call__( + self, + model: Model, + data_loader: DataLoader, + batch_weight_key: str = None, + metrics_output_file: Union[str, PathLike] = None, + predictions_output_file: Union[str, PathLike] = None, + ): + """ + Evaluate a single data source. + + # Parameters + + model : `Model` + The model to evaluate + data_loader : `DataLoader` + The `DataLoader` that will iterate over the evaluation data (data loaders already contain + their data). + batch_weight_key : `str`, optional (default=`None`) + If given, this is a key in the output dictionary for each batch that specifies how to weight + the loss for that batch. If this is not given, we use a weight of 1 for every batch. + metrics_output_file : `Union[str, PathLike]`, optional (default=`None`) + Optional path to write the final metrics to. + predictions_output_file : `Union[str, PathLike]`, optional (default=`None`) + Optional path to write the predictions to. 
+ + # Returns + + metrics: `Dict[str, Any]` + The metrics from evaluating the file. + """ + check_for_gpu(self.cuda_device) + data_loader.set_target_device(int_to_device(self.cuda_device)) + metrics_output_file = Path(metrics_output_file) if metrics_output_file is not None else None + if predictions_output_file is not None: + predictions_file = Path(predictions_output_file).open("w", encoding="utf-8") + else: + predictions_file = None # type: ignore + + model_postprocess_function = getattr(model, self.postprocessor_fn_name, None) + + with torch.no_grad(): + model.eval() + + iterator = iter(data_loader) + logger.info("Iterating over dataset") + generator_tqdm = Tqdm.tqdm(iterator) + # Number of batches in instances. + batch_count = 0 + # Number of batches where the model produces a loss. + loss_count = 0 + # Cumulative weighted loss + total_loss = 0.0 + # Cumulative weight across all batches. + total_weight = 0.0 + + for batch in generator_tqdm: + batch_count += 1 + batch = nn_util.move_to_device(batch, self.cuda_device) + output_dict = model(**batch) + loss = output_dict.get("loss") + + metrics = model.get_metrics() + + if loss is not None: + loss_count += 1 + if batch_weight_key: + weight = output_dict[batch_weight_key].item() + else: + weight = 1.0 + + total_weight += weight + total_loss += loss.item() * weight + # Report the average loss so far. + metrics["loss"] = total_loss / total_weight + + description = ( + ", ".join( + [ + "%s: %.2f" % (name, value) + for name, value in metrics.items() + if not name.startswith("_") + ] + ) + + " ||" + ) + generator_tqdm.set_description(description, refresh=False) + + # TODO(gabeorlanski): Add in postprocessing the batch for token + # metrics + if predictions_file is not None: + predictions_file.write( + self.batch_serializer( + batch, + output_dict, + data_loader, + output_postprocess_function=model_postprocess_function, + ) + + "\n" + ) + + if predictions_file is not None: + predictions_file.close() + + final_metrics = model.get_metrics(reset=True) + if loss_count > 0: + # Sanity check + if loss_count != batch_count: + raise RuntimeError( + "The model you are trying to evaluate only sometimes produced a loss!" 
+ ) + final_metrics["loss"] = total_loss / total_weight + + if metrics_output_file is not None: + dump_metrics(str(metrics_output_file), final_metrics, log=True) + + return final_metrics + + def _to_params(self) -> Dict[str, Any]: + return { + "type": "simple", + "cuda_device": self.cuda_device, + "batch_postprocessor": self.batch_serializer.to_params(), + } diff --git a/allennlp/evaluation/postprocessors/__init__.py b/allennlp/evaluation/postprocessors/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/allennlp/evaluation/serializers/__init__.py b/allennlp/evaluation/serializers/__init__.py new file mode 100644 index 00000000000..a9aef4a6b22 --- /dev/null +++ b/allennlp/evaluation/serializers/__init__.py @@ -0,0 +1 @@ +from allennlp.evaluation.serializers.serializers import SimpleSerializer diff --git a/allennlp/evaluation/serializers/serializers.py b/allennlp/evaluation/serializers/serializers.py new file mode 100644 index 00000000000..2fbcfa79f6e --- /dev/null +++ b/allennlp/evaluation/serializers/serializers.py @@ -0,0 +1,106 @@ +from typing import Optional, Dict, Any, Callable +import logging +import json + +from allennlp.common.util import sanitize +from allennlp.data.fields import TensorField +from allennlp.common import Registrable +from allennlp.data import DataLoader + +logger = logging.getLogger(__name__) + + +class Serializer(Registrable): + """ + General serializer class for turning batches into human readable data + """ + + def __call__( + self, + batch: Dict[str, TensorField], + output_dict: Dict, + data_loader: DataLoader, + output_postprocess_function: Optional[Callable] = None, + ) -> str: + """ + Postprocess a batch. + + # Parameters + + batch: `Dict[str, TensorField]` + The batch that was passed to the model's forward function. + + output_dict: `Dict` + The output of the model's forward function on the batch + + data_loader: `DataLoader` + The dataloader to be used. + + output_postprocess_function: `Callable`, optional (default=`None`) + If you have a function to preprocess only the outputs ( + i.e. `model.make_human_readable`), use this parameter to have it + called on the output dict. + + # Returns + + postprocessed: `str` + The postprocessed batches as strings + """ + raise NotImplementedError("__call__") + + default_implementation = "simple" + + +@Serializer.register("simple") +class SimpleSerializer(Serializer): + """ + Very simple serializer. Only sanitizes the batches and outputs. Will use + a passed serializer function for the outputs if it exists. + """ + + def _to_params(self) -> Dict[str, Any]: + return {"type": "simple"} + + def __call__( + self, + batch: Dict[str, TensorField], + output_dict: Dict, + data_loader: DataLoader, + output_postprocess_function: Optional[Callable] = None, + ): + """ + Serializer a batch. + + # Parameters + + batch: `Dict[str, TensorField]` + The batch that was passed to the model's forward function. + + output_dict: `Dict` + The output of the model's forward function on the batch + + data_loader: `DataLoader` + The dataloader to be used. + + output_postprocess_function: `Callable`, optional (default=`None`) + If you have a function to preprocess only the outputs ( + i.e. `model.make_human_readable`), use this parameter to have it + called on the output dict. 
+ + # Returns + + serialized: `str` + The serialized batches as strings + """ + if batch is None: + raise ValueError("Serializer got a batch that is None") + if output_dict is None: + raise ValueError("Serializer got an output_dict that is None") + + serialized = sanitize(batch) + if output_postprocess_function is not None: + serialized.update(sanitize(output_postprocess_function(output_dict))) + else: + serialized.update(sanitize(output_dict)) + + return json.dumps(serialized) diff --git a/tests/commands/evaluate_test.py b/tests/commands/evaluate_test.py index 1cbaf8147cd..eebf7753453 100644 --- a/tests/commands/evaluate_test.py +++ b/tests/commands/evaluate_test.py @@ -1,12 +1,13 @@ import argparse import json +from pathlib import Path from typing import Iterator, List, Dict - +from shutil import copyfile +import pytest import torch from flaky import flaky -import pytest -from allennlp.commands.evaluate import evaluate_from_args, Evaluate, evaluate +from allennlp.commands.evaluate import evaluate_from_args, Evaluate from allennlp.common.testing import AllenNlpTestCase from allennlp.data.data_loaders import TensorDict from allennlp.models import Model @@ -43,25 +44,6 @@ def setup_method(self): subparsers = self.parser.add_subparsers(title="Commands", metavar="") Evaluate().add_subparser(subparsers) - def test_evaluate_calculates_average_loss(self): - losses = [7.0, 9.0, 8.0] - outputs = [{"loss": torch.Tensor([loss])} for loss in losses] - data_loader = DummyDataLoader(outputs) - metrics = evaluate(DummyModel(), data_loader, -1, "") - assert metrics["loss"] == pytest.approx(8.0) - - def test_evaluate_calculates_average_loss_with_weights(self): - losses = [7.0, 9.0, 8.0] - weights = [10, 2, 1.5] - inputs = zip(losses, weights) - outputs = [ - {"loss": torch.Tensor([loss]), "batch_weight": torch.Tensor([weight])} - for loss, weight in inputs - ] - data_loader = DummyDataLoader(outputs) - metrics = evaluate(DummyModel(), data_loader, -1, "batch_weight") - assert metrics["loss"] == pytest.approx((70 + 18 + 12) / 13.5) - @flaky def test_evaluate_from_args(self): kebab_args = [ @@ -114,34 +96,55 @@ def test_output_file_evaluate_from_args(self): assert "tags" in prediction def test_multiple_output_files_evaluate_from_args(self): - output_file = str(self.TEST_DIR / "metrics.json") - predictions_output_file = str(self.TEST_DIR / "predictions.jsonl") + data_file = Path(self.FIXTURES_ROOT / "data" / "conll2003.txt") + paths = [] + out_paths = [] + pred_paths = [] + for i in range(3): + tmp_path = self.TEST_DIR.joinpath(f"TEST{i}.txt") + + # Need to create paths to check when they do not exist + out_paths.append(tmp_path.parent.joinpath(f"OUTPUTS{i}.json")) + pred_paths.append(tmp_path.parent.joinpath(f"PREDS{i}.txt")) + + copyfile(data_file, tmp_path) + paths.append(tmp_path) + kebab_args = [ "evaluate", str( self.FIXTURES_ROOT / "simple_tagger_with_span_f1" / "serialization" / "model.tar.gz" ), - str(self.FIXTURES_ROOT / "data" / "conll2003.txt") - + "," - + str(self.FIXTURES_ROOT / "data" / "conll2003.txt"), + ",".join(map(str, paths)), "--cuda-device", "-1", "--output-file", - output_file + "," + output_file, + ",".join(map(str, out_paths)), "--predictions-output-file", - predictions_output_file + "," + predictions_output_file, + ",".join(map(str, pred_paths)), ] args = self.parser.parse_args(kebab_args) computed_metrics = evaluate_from_args(args) + computed_by_file = {} + for k, v in computed_metrics.items(): + fn, *metric_name = k.split("_") + if fn not in computed_by_file: + 
computed_by_file[fn] = {} + computed_by_file[fn]["_".join(metric_name)] = v - with open(output_file, "r") as file: - saved_metrics = json.load(file) - assert computed_metrics == saved_metrics + assert len(computed_by_file) == len(paths) + expected_input_data = data_file.read_text("utf-8") - with open(predictions_output_file, "r") as file: - for line in file: - prediction = json.loads(line.strip()) - assert "tags" in prediction + for i, p in enumerate(paths): + # Make sure it was not modified + assert p.read_text("utf-8") == expected_input_data + + assert p.stem in computed_by_file, f"paths[{i}]={p.stem}" + + assert out_paths[i].exists(), f"paths[{i}]={p.stem}" + saved_metrics = json.loads(out_paths[i].read_text("utf-8")) + assert saved_metrics == computed_by_file[p.stem], f"paths[{i}]={p.stem}" + assert pred_paths[i].exists(), f"paths[{i}]={p.stem}" def test_evaluate_works_with_vocab_expansion(self): archive_path = str( @@ -175,3 +178,57 @@ def test_evaluate_works_with_vocab_expansion(self): ) assert metrics_1 != metrics_2 assert metrics_2 != metrics_3 + + @pytest.mark.parametrize("auto_names", ["NONE", "METRICS", "PREDS", "ALL"]) + def test_auto_names_creates_files(self, auto_names): + data_file = Path(self.FIXTURES_ROOT / "data" / "conll2003.txt") + paths = [] + out_paths = [] + pred_paths = [] + for i in range(5): + tmp_path = self.TEST_DIR.joinpath(f"TEST{i}.txt") + + # Need to create paths to check when they do not exist + out_paths.append(tmp_path.parent.joinpath(f"OUTPUTS{i}.json")) + pred_paths.append(tmp_path.parent.joinpath(f"PREDS{i}.txt")) + + copyfile(data_file, tmp_path) + paths.append(tmp_path) + + kebab_args = [ + "evaluate", + str( + self.FIXTURES_ROOT / "simple_tagger_with_span_f1" / "serialization" / "model.tar.gz" + ), + ",".join(map(str, paths)), + "--cuda-device", + "-1", + "--output-file", + ",".join(map(str, out_paths)), + "--predictions-output-file", + ",".join(map(str, pred_paths)), + "--auto-names", + auto_names, + ] + + args = self.parser.parse_args(kebab_args) + _ = evaluate_from_args(args) + + expected_input_data = data_file.read_text("utf-8") + + for i, p in enumerate(paths): + # Make sure it was not modified + assert p.read_text("utf-8") == expected_input_data + + if auto_names == "METRICS" or auto_names == "ALL": + assert not out_paths[i].exists() + assert p.parent.joinpath(f"{p.stem}.outputs").exists() + else: + assert out_paths[i].exists() + assert not p.parent.joinpath(f"{p.stem}.outputs").exists() + if auto_names == "PREDS" or auto_names == "ALL": + assert not pred_paths[i].exists() + assert p.parent.joinpath(f"{p.stem}.preds").exists() + else: + assert pred_paths[i].exists() + assert not p.parent.joinpath(f"{p.stem}.preds").exists() diff --git a/tests/evaluation/__init__.py b/tests/evaluation/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/evaluation/evaluator_tests.py b/tests/evaluation/evaluator_tests.py new file mode 100644 index 00000000000..57ec28e6bfe --- /dev/null +++ b/tests/evaluation/evaluator_tests.py @@ -0,0 +1,64 @@ +from typing import Iterator, List, Dict + +import torch +import pytest + +from allennlp.common.testing import AllenNlpTestCase +from allennlp.data.data_loaders import TensorDict +from allennlp.models import Model +from allennlp.evaluation import Evaluator +from allennlp.common import Params + + +class DummyDataLoader: + def __init__(self, outputs: List[TensorDict]) -> None: + super().__init__() + self._outputs = outputs + + def __iter__(self) -> Iterator[TensorDict]: + yield from 
self._outputs + + def __len__(self): + return len(self._outputs) + + def set_target_device(self, _): + pass + + +class DummyModel(Model): + def __init__(self) -> None: + super().__init__(None) # type: ignore + + def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore + return kwargs + + +class TestEvaluator(AllenNlpTestCase): + def setup_method(self): + self.evaluator = Evaluator.from_params(Params({"batch_postprocessor": "simple"})) + + def test_evaluate_calculates_average_loss(self): + losses = [7.0, 9.0, 8.0] + outputs = [{"loss": torch.Tensor([loss])} for loss in losses] + data_loader = DummyDataLoader(outputs) + metrics = self.evaluator(DummyModel(), data_loader, "") # type: ignore + assert metrics["loss"] == pytest.approx(8.0) + + def test_evaluate_calculates_average_loss_with_weights(self): + losses = [7.0, 9.0, 8.0] + weights = [10, 2, 1.5] + inputs = zip(losses, weights) + outputs = [ + {"loss": torch.Tensor([loss]), "batch_weight": torch.Tensor([weight])} + for loss, weight in inputs + ] + data_loader = DummyDataLoader(outputs) + metrics = self.evaluator(DummyModel(), data_loader, "batch_weight") # type: ignore + assert metrics["loss"] == pytest.approx((70 + 18 + 12) / 13.5) + + def test_to_params(self): + assert self.evaluator.to_params() == { + "type": "simple", + "cuda_device": -1, + "batch_postprocessor": {"type": "simple"}, + } diff --git a/tests/evaluation/serializers/__init__.py b/tests/evaluation/serializers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/evaluation/serializers/serializer_test.py b/tests/evaluation/serializers/serializer_test.py new file mode 100644 index 00000000000..684f01c8a5b --- /dev/null +++ b/tests/evaluation/serializers/serializer_test.py @@ -0,0 +1,76 @@ +from typing import Iterator, List +import torch +import pytest +import json + +from allennlp.common.testing import AllenNlpTestCase +from allennlp.common import Params +from allennlp.common.util import sanitize +from allennlp.data.data_loaders import TensorDict +from allennlp.evaluation import Serializer +from allennlp.evaluation.serializers import SimpleSerializer + + +class DummyDataLoader: + def __init__(self, outputs: List[TensorDict]) -> None: + super().__init__() + self._outputs = outputs + + def __iter__(self) -> Iterator[TensorDict]: + yield from self._outputs + + def __len__(self): + return len(self._outputs) + + def set_target_device(self, _): + pass + + +class TestSerializer(AllenNlpTestCase): + def setup_method(self): + super(TestSerializer, self).setup_method() + self.postprocessor = Serializer.from_params(Params({})) + + def test_postprocessor_default_implementation(self): + assert self.postprocessor.to_params().params == {"type": "simple"} + assert isinstance(self.postprocessor, SimpleSerializer) + + @pytest.mark.parametrize( + "batch", + [ + { + "Do you want ants?": "Because that's how you get ants.", + "testing": torch.tensor([[1, 2, 3]]), + }, + {}, + None, + ], + ids=["TestBatch", "EmptyBatch", "None"], + ) + @pytest.mark.parametrize( + "output_dict", + [{"You're": ["Not", [["My"]], "Supervisor"]}, {}, None], + ids=["TestOutput", "EmptyOutput", "None"], + ) + @pytest.mark.parametrize( + "postprocess_func", + [lambda x: {k.upper(): v for k, v in x.items()}, None], + ids=["PassedFunction", "NoPassedFunction"], + ) + def test_simple_postprocessor_call(self, batch, output_dict, postprocess_func): + data_loader = DummyDataLoader([]) + if batch is None or output_dict is None: + with pytest.raises(ValueError): + 
self.postprocessor(batch, output_dict, data_loader) # type: ignore + return + + expected = json.dumps( + sanitize( + {**batch, **(postprocess_func(output_dict) if postprocess_func else output_dict)} + ) + ) + + result = self.postprocessor( + batch, output_dict, data_loader, postprocess_func # type: ignore + ) + assert result == expected
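
A minimal usage sketch of the new `Evaluator` API outside the `evaluate` command, based on the signatures added in `allennlp/evaluation/evaluator.py` and the loading steps in `evaluate_from_archive`. The archive path, data path, and output file names below are placeholders for illustration, not files from this repository, and the data-loader setup is simplified (the command additionally checks for a `validation_data_loader` config key).

# Sketch: programmatic evaluation with the new Evaluator API (placeholder paths).
from allennlp.data import DataLoader
from allennlp.evaluation import SimpleEvaluator
from allennlp.models.archival import load_archive

# Load a trained model archive and put the model in eval mode.
archive = load_archive("model.tar.gz", cuda_device=-1)
model = archive.model
model.eval()

# Build a data loader for the evaluation file, reusing the archive's
# validation dataset reader, as evaluate_from_archive does.
data_loader = DataLoader.from_params(
    params=archive.config.get("data_loader"),
    reader=archive.validation_dataset_reader,
    data_path="dev.conll",
)
data_loader.index_with(model.vocab)

# SimpleEvaluator is the registered default ("simple"). When the optional
# output paths are given, it writes the final metrics as JSON and one
# serialized JSON line per batch of predictions.
evaluator = SimpleEvaluator()
metrics = evaluator(
    model,
    data_loader,
    metrics_output_file="dev.metrics.json",
    predictions_output_file="dev.preds.jsonl",
)
print(metrics)

In the command-line flow the same object is built with `Evaluator.from_params(config.pop("evaluation", {}))`, so an `"evaluation"` block in the training config selects and configures the evaluator.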