diff --git a/.dockerignore b/.dockerignore
index 7183ef15ace..a7e3b8ecea1 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -3,3 +3,6 @@
 **/__pycache__
 .gitignore
 .git
+.coverage
+.benchmarks
+.mypy_cache
diff --git a/.gitignore b/.gitignore
index e40fe621816..6917232047e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -44,6 +44,7 @@ __pycache__
 .coverage
 .pytest_cache/
+.benchmarks

 # documentation build artifacts
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 46b1b03d52e..449973a6eb0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ### Fixed

+- Fixed many instances where tensors were first created and then sent to a device
+  with `.to(device)`. These tensors are now created directly on the target device.
 - Fixed issue with `GradientDescentTrainer` when constructed with `validation_data_loader=None` and `learning_rate_scheduler!=None`.
 - Fixed a bug when removing all handlers in root logger.
 - `ShardedDatasetReader` now inherits parameters from `base_reader` when required.
diff --git a/allennlp/interpret/saliency_interpreters/smooth_gradient.py b/allennlp/interpret/saliency_interpreters/smooth_gradient.py
index 67aad524b41..a1c5980ec31 100644
--- a/allennlp/interpret/saliency_interpreters/smooth_gradient.py
+++ b/allennlp/interpret/saliency_interpreters/smooth_gradient.py
@@ -58,7 +58,7 @@ def _register_forward_hook(self, stdev: float):
         def forward_hook(module, inputs, output):
             # Random noise = N(0, stdev * (max-min))
             scale = output.detach().max() - output.detach().min()
-            noise = torch.randn(output.shape).to(output.device) * stdev * scale
+            noise = torch.randn(output.shape, device=output.device) * stdev * scale

             # Add the random noise
             output.add_(noise)
diff --git a/allennlp/modules/sampled_softmax_loss.py b/allennlp/modules/sampled_softmax_loss.py
index c500fd2ab66..88e82975c25 100644
--- a/allennlp/modules/sampled_softmax_loss.py
+++ b/allennlp/modules/sampled_softmax_loss.py
@@ -155,7 +155,7 @@ def forward(
         if embeddings.shape[0] == 0:
             # empty batch
-            return torch.tensor(0.0).to(embeddings.device)
+            return torch.tensor(0.0, device=embeddings.device)

         if not self.training:
             return self._forward_eval(embeddings, targets)
diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py
index c103f84b71c..56793958e04 100644
--- a/allennlp/nn/util.py
+++ b/allennlp/nn/util.py
@@ -1548,7 +1548,6 @@ def add_sentence_boundary_token_ids(
         The new mask for the tensor, taking into account the appended tokens
         marking the beginning and end of the sentence.
     """
-    # TODO: matthewp, profile this transfer
     sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
     tensor_shape = list(tensor.data.shape)
     new_shape = list(tensor_shape)
@@ -1603,7 +1602,6 @@ def remove_sentence_boundaries(
     new_mask : `torch.BoolTensor`
         The new mask for the tensor of shape `(batch_size, timesteps - 2)`.
""" - # TODO: matthewp, profile this transfer sequence_lengths = mask.sum(dim=1).detach().cpu().numpy() tensor_shape = list(tensor.data.shape) new_shape = list(tensor_shape) diff --git a/allennlp/training/metrics/attachment_scores.py b/allennlp/training/metrics/attachment_scores.py index c628f50a9c3..078cd3822d0 100644 --- a/allennlp/training/metrics/attachment_scores.py +++ b/allennlp/training/metrics/attachment_scores.py @@ -88,8 +88,8 @@ def __call__( # type: ignore dist.all_reduce(unlabeled_exact_match, op=dist.ReduceOp.SUM) dist.all_reduce(correct_labels_and_indices, op=dist.ReduceOp.SUM) dist.all_reduce(labeled_exact_match, op=dist.ReduceOp.SUM) - total_sentences = torch.tensor(total_sentences).to(device) - total_words = torch.tensor(total_words).to(device) + total_sentences = torch.tensor(total_sentences, device=device) + total_words = torch.tensor(total_words, device=device) dist.all_reduce(total_sentences, op=dist.ReduceOp.SUM) dist.all_reduce(total_words, op=dist.ReduceOp.SUM) total_sentences = total_sentences.item() diff --git a/allennlp/training/metrics/average.py b/allennlp/training/metrics/average.py index 56d8b71d184..6d1738f2918 100644 --- a/allennlp/training/metrics/average.py +++ b/allennlp/training/metrics/average.py @@ -32,8 +32,8 @@ def __call__(self, value): _count = 1 if is_distributed(): device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu") - count = torch.tensor(_count).to(device) - total_value = torch.tensor(_total_value).to(device) + count = torch.tensor(_count, device=device) + total_value = torch.tensor(_total_value, device=device) dist.all_reduce(count, op=dist.ReduceOp.SUM) dist.all_reduce(total_value, op=dist.ReduceOp.SUM) _count = count.item() diff --git a/allennlp/training/metrics/bleu.py b/allennlp/training/metrics/bleu.py index 2415407412e..8b1a48f2b2a 100644 --- a/allennlp/training/metrics/bleu.py +++ b/allennlp/training/metrics/bleu.py @@ -127,8 +127,8 @@ def __call__( predictions, gold_targets, ngram_size ) if is_distributed(): - _precision_matches = torch.tensor(precision_matches).to(device) - _precision_totals = torch.tensor(precision_totals).to(device) + _precision_matches = torch.tensor(precision_matches, device=device) + _precision_totals = torch.tensor(precision_totals, device=device) dist.all_reduce(_precision_matches, op=dist.ReduceOp.SUM) dist.all_reduce(_precision_totals, op=dist.ReduceOp.SUM) precision_matches = _precision_matches.item() / world_size @@ -150,8 +150,8 @@ def __call__( _reference_lengths = valid_gold_targets_mask.sum().item() if is_distributed(): - prediction_lengths = torch.tensor(_prediction_lengths).to(device) - reference_lengths = torch.tensor(_reference_lengths).to(device) + prediction_lengths = torch.tensor(_prediction_lengths, device=device) + reference_lengths = torch.tensor(_reference_lengths, device=device) dist.all_reduce(prediction_lengths, op=dist.ReduceOp.SUM) dist.all_reduce(reference_lengths, op=dist.ReduceOp.SUM) _prediction_lengths = prediction_lengths.item() diff --git a/allennlp/training/metrics/covariance.py b/allennlp/training/metrics/covariance.py index 180810d6fb4..8877f4f2027 100644 --- a/allennlp/training/metrics/covariance.py +++ b/allennlp/training/metrics/covariance.py @@ -111,10 +111,10 @@ def __call__( # # Note: this gives an approximate aggregation of the covariance. 
         # device = gold_labels.device
-        # delta_mean_prediction = torch.tensor(delta_mean_prediction).to(device)
-        # delta_mean_label = torch.tensor(delta_mean_label).to(device)
-        # delta_co_moment = torch.tensor(delta_co_moment).to(device)
-        # _total_count = torch.tensor(updated_count).to(device)
+        # delta_mean_prediction = torch.tensor(delta_mean_prediction, device=device)
+        # delta_mean_label = torch.tensor(delta_mean_label, device=device)
+        # delta_co_moment = torch.tensor(delta_co_moment, device=device)
+        # _total_count = torch.tensor(updated_count, device=device)
         # dist.all_reduce(delta_mean_prediction, op=dist.ReduceOp.SUM)
         # dist.all_reduce(delta_mean_label, op=dist.ReduceOp.SUM)
         # dist.all_reduce(delta_co_moment, op=dist.ReduceOp.SUM)
diff --git a/allennlp/training/metrics/entropy.py b/allennlp/training/metrics/entropy.py
index fe326e47f5b..997d72db652 100644
--- a/allennlp/training/metrics/entropy.py
+++ b/allennlp/training/metrics/entropy.py
@@ -43,7 +43,7 @@ def __call__(
         _count = 1

         if is_distributed():
-            count = torch.tensor(_count).to(device)
+            count = torch.tensor(_count, device=device)
             dist.all_reduce(_entropy, op=dist.ReduceOp.SUM)
             dist.all_reduce(count, op=dist.ReduceOp.SUM)
             _count = count.item()
diff --git a/allennlp/training/metrics/evalb_bracketing_scorer.py b/allennlp/training/metrics/evalb_bracketing_scorer.py
index 8b13469501d..074b154acae 100644
--- a/allennlp/training/metrics/evalb_bracketing_scorer.py
+++ b/allennlp/training/metrics/evalb_bracketing_scorer.py
@@ -155,9 +155,9 @@ def __call__(self, predicted_trees: List[Tree], gold_trees: List[Tree]) -> None:

         if is_distributed():
             device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
-            correct_predicted_brackets = torch.tensor(_correct_predicted_brackets).to(device)
-            predicted_brackets = torch.tensor(_predicted_brackets).to(device)
-            gold_brackets = torch.tensor(_gold_brackets).to(device)
+            correct_predicted_brackets = torch.tensor(_correct_predicted_brackets, device=device)
+            predicted_brackets = torch.tensor(_predicted_brackets, device=device)
+            gold_brackets = torch.tensor(_gold_brackets, device=device)
             dist.all_reduce(correct_predicted_brackets, op=dist.ReduceOp.SUM)
             dist.all_reduce(predicted_brackets, op=dist.ReduceOp.SUM)
             dist.all_reduce(gold_brackets, op=dist.ReduceOp.SUM)
diff --git a/allennlp/training/metrics/fbeta_measure.py b/allennlp/training/metrics/fbeta_measure.py
index 5bb70dc9f36..beb0bd95798 100644
--- a/allennlp/training/metrics/fbeta_measure.py
+++ b/allennlp/training/metrics/fbeta_measure.py
@@ -142,7 +142,7 @@ def __call__(
         # Watch it:
         # The total numbers of true positives under all _predicted_ classes are zeros.
         if true_positives_bins.shape[0] == 0:
-            true_positive_sum = torch.zeros(num_classes, device=predictions.device)
+            true_positive_sum = torch.zeros(num_classes, device=device)
         else:
             true_positive_sum = torch.bincount(
                 true_positives_bins.long(), minlength=num_classes
@@ -154,7 +154,7 @@ def __call__(
         if pred_bins.shape[0] != 0:
             pred_sum = torch.bincount(pred_bins, minlength=num_classes).float()
         else:
-            pred_sum = torch.zeros(num_classes, device=predictions.device)
+            pred_sum = torch.zeros(num_classes, device=device)

         gold_labels_bins = gold_labels[mask].long()
         if gold_labels.shape[0] != 0:
@@ -165,9 +165,7 @@ def __call__(
         self._total_sum += mask.sum().to(torch.float)

         if is_distributed():
-            true_positive_sum = torch.tensor(true_positive_sum).to(device)
-            pred_sum = torch.tensor(pred_sum).to(device)
-            true_sum = torch.tensor(true_sum).to(device)
+            true_positive_sum = torch.tensor(true_positive_sum, device=device)
             dist.all_reduce(true_positive_sum, op=dist.ReduceOp.SUM)
             dist.all_reduce(pred_sum, op=dist.ReduceOp.SUM)
             dist.all_reduce(true_sum, op=dist.ReduceOp.SUM)
diff --git a/allennlp/training/metrics/fbeta_multi_label_measure.py b/allennlp/training/metrics/fbeta_multi_label_measure.py
index ca925672b58..c95a757e5e5 100644
--- a/allennlp/training/metrics/fbeta_multi_label_measure.py
+++ b/allennlp/training/metrics/fbeta_multi_label_measure.py
@@ -156,9 +156,9 @@ def __call__(
         self._total_sum += mask.expand_as(gold_labels).sum().to(torch.float)

         if is_distributed():
-            true_positive_sum = torch.tensor(true_positive_sum).to(device)
-            pred_sum = torch.tensor(pred_sum).to(device)
-            true_sum = torch.tensor(true_sum).to(device)
+            true_positive_sum = torch.tensor(true_positive_sum, device=device)
+            pred_sum = torch.tensor(pred_sum, device=device)
+            true_sum = torch.tensor(true_sum, device=device)
             dist.all_reduce(true_positive_sum, op=dist.ReduceOp.SUM)
             dist.all_reduce(pred_sum, op=dist.ReduceOp.SUM)
             dist.all_reduce(true_sum, op=dist.ReduceOp.SUM)
diff --git a/allennlp/training/metrics/mean_absolute_error.py b/allennlp/training/metrics/mean_absolute_error.py
index f094a8aea02..b7217c134d2 100644
--- a/allennlp/training/metrics/mean_absolute_error.py
+++ b/allennlp/training/metrics/mean_absolute_error.py
@@ -47,8 +47,8 @@ def __call__(
         _absolute_error = torch.sum(absolute_errors)

         if is_distributed():
-            absolute_error = torch.tensor(_absolute_error).to(device)
-            total_count = torch.tensor(_total_count).to(device)
+            absolute_error = torch.tensor(_absolute_error, device=device)
+            total_count = torch.tensor(_total_count, device=device)
             dist.all_reduce(absolute_error, op=dist.ReduceOp.SUM)
             dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
             _absolute_error = absolute_error.item()
diff --git a/allennlp/training/metrics/rouge.py b/allennlp/training/metrics/rouge.py
index d10de4d8835..cceaa6eacef 100644
--- a/allennlp/training/metrics/rouge.py
+++ b/allennlp/training/metrics/rouge.py
@@ -113,7 +113,7 @@ def _get_rouge_l_score(

         if is_distributed():
             device = predicted_tokens.device
-            _total_f1 = torch.tensor(total_f1).to(device)
+            _total_f1 = torch.tensor(total_f1, device=device)
             dist.all_reduce(_total_f1, op=dist.ReduceOp.SUM)
             total_f1 = _total_f1.item()

@@ -162,9 +162,9 @@ def _get_rouge_n_stats(

         if is_distributed():
             device = predicted_tokens.device
-            _total_recall = torch.tensor(total_recall).to(device)
-            _total_precision = torch.tensor(total_precision).to(device)
-            _total_f1 = torch.tensor(total_f1).to(device)
+            _total_recall = torch.tensor(total_recall, device=device)
+            _total_precision = torch.tensor(total_precision, device=device)
+            _total_f1 = torch.tensor(total_f1, device=device)
             dist.all_reduce(_total_recall, op=dist.ReduceOp.SUM)
             dist.all_reduce(_total_precision, op=dist.ReduceOp.SUM)
             dist.all_reduce(_total_f1, op=dist.ReduceOp.SUM)
@@ -209,7 +209,7 @@ def __call__(
         sequence_count = len(predictions)
         if is_distributed():
             device = predictions.device
-            _sequence_count = torch.tensor(sequence_count).to(device)
+            _sequence_count = torch.tensor(sequence_count, device=device)
             dist.all_reduce(_sequence_count, op=dist.ReduceOp.SUM)
             sequence_count = _sequence_count.item()
         self._total_sequence_count += sequence_count
diff --git a/allennlp/training/metrics/sequence_accuracy.py b/allennlp/training/metrics/sequence_accuracy.py
index a2963e23ff0..46fdde6b31d 100644
--- a/allennlp/training/metrics/sequence_accuracy.py
+++ b/allennlp/training/metrics/sequence_accuracy.py
@@ -73,8 +73,8 @@ def __call__(
         _correct_count = correct

         if is_distributed():
-            correct_count = torch.tensor(_correct_count).to(device)
-            total_count = torch.tensor(_total_count).to(device)
+            correct_count = torch.tensor(_correct_count, device=device)
+            total_count = torch.tensor(_total_count, device=device)
             dist.all_reduce(correct_count, op=dist.ReduceOp.SUM)
             dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
             _correct_count = correct_count.item()
diff --git a/allennlp/training/metrics/unigram_recall.py b/allennlp/training/metrics/unigram_recall.py
index 477bfe2627e..994e841960d 100644
--- a/allennlp/training/metrics/unigram_recall.py
+++ b/allennlp/training/metrics/unigram_recall.py
@@ -83,8 +83,8 @@ def __call__(
         _total_count = predictions.size()[0]

         if is_distributed():
-            correct_count = torch.tensor(_correct_count).to(device)
-            total_count = torch.tensor(_total_count).to(device)
+            correct_count = torch.tensor(_correct_count, device=device)
+            total_count = torch.tensor(_total_count, device=device)
             dist.all_reduce(correct_count, op=dist.ReduceOp.SUM)
             dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
             _correct_count = correct_count.item()
diff --git a/benchmarks/nn/util_bench.py b/benchmarks/nn/util_bench.py
new file mode 100644
index 00000000000..9f3d31e6a47
--- /dev/null
+++ b/benchmarks/nn/util_bench.py
@@ -0,0 +1,46 @@
+import torch
+
+from allennlp.nn import util
+from allennlp.common.testing import requires_gpu
+
+
+@requires_gpu
+def bench_add_sentence_boundary_token_ids(benchmark):
+    device = torch.device("cuda")
+    # shape: (32, 50)
+    tensor = torch.tensor([[3] * 50] * 32, device=device)
+    # shape: (32, 50)
+    mask = torch.tensor([[True] * 50, [True] * 30 + [False] * 20] * 16, device=device)
+    begin_token = 1
+    end_token = 2
+    benchmark(util.add_sentence_boundary_token_ids, tensor, mask, begin_token, end_token)
+
+
+@requires_gpu
+def bench_remove_sentence_boundaries(benchmark):
+    device = torch.device("cuda")
+    # shape: (32, 50, 1)
+    tensor = torch.tensor([[3] * 50] * 32, device=device).unsqueeze(-1)
+    # shape: (32, 50)
+    mask = torch.tensor([[True] * 50, [True] * 30 + [False] * 20] * 16, device=device)
+    benchmark(util.remove_sentence_boundaries, tensor, mask)
+
+
+@requires_gpu
+def bench_create_tensor_then_send_to_device(benchmark):
+    device = torch.device("cuda:0")
+
+    def create_tensor():
+        return torch.rand((32, 50)).to(device)
+
+    benchmark(create_tensor)
+
+
+@requires_gpu
+def bench_create_tensor_directly_on_device(benchmark):
+    device = torch.device("cuda:0")
+
+    def create_tensor():
+        return torch.rand((32, 50), device=device)
+
+    benchmark(create_tensor)
diff --git a/benchmarks/pytest.ini b/benchmarks/pytest.ini
index e465b28aad3..1a5dab6e5ea 100644
--- a/benchmarks/pytest.ini
+++ b/benchmarks/pytest.ini
@@ -6,3 +6,5 @@ python_files = *_bench.py
 python_functions = bench_* *_bench
 python_classes =
+markers =
+    gpu: marks tests that need at least one GPU
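
The two `bench_create_tensor_*` benchmarks added above capture the motivation for this patch. To reproduce the comparison outside the pytest-benchmark suite, the standalone sketch below times both allocation patterns with CUDA events. It is illustrative only and not part of the diff: `time_fn`, the tensor shapes, and the iteration counts are arbitrary choices.

import torch


def time_fn(fn, iters: int = 1000) -> float:
    # Average GPU time per call, in milliseconds, measured with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):  # warm-up so allocator and launch overheads settle
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # Old pattern: allocate on the CPU, then copy host-to-device.
    indirect = time_fn(lambda: torch.rand((32, 50)).to(device))
    # New pattern: allocate directly on the GPU, no host-to-device copy.
    direct = time_fn(lambda: torch.rand((32, 50), device=device))
    print(f"via .to(device): {indirect:.4f} ms/call, direct: {direct:.4f} ms/call")

The direct form skips the intermediate host allocation and the synchronous host-to-device copy, which is why the `device=` keyword is preferred everywhere in this patch.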