
Several micro optimizations (#4833)
* benchmark transfers

* create tensors directly on device when possible

* fix
epwalsh authored Dec 2, 2020
1 parent 48a4865 commit cec9209
Showing 20 changed files with 89 additions and 39 deletions.
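
The pattern applied throughout this commit: instead of allocating a tensor on the CPU and then copying it to the target device with `.to(device)`, the `device=` argument is passed to the factory function so the tensor is allocated on the target device from the start, skipping the host-to-device transfer. A minimal sketch of the two variants (the shape is illustrative, mirroring the benchmark added at the bottom of this diff):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Before: allocate on the CPU, then copy to the target device.
# Pays for a CPU allocation plus a host-to-device transfer.
x = torch.rand((32, 50)).to(device)

# After: allocate directly on the target device: one allocation, no transfer.
y = torch.rand((32, 50), device=device)
```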
3 changes: 3 additions & 0 deletions .dockerignore
@@ -3,3 +3,6 @@
**/__pycache__
.gitignore
.git
+.coverage
+.benchmarks
+.mypy_cache
1 change: 1 addition & 0 deletions .gitignore
@@ -44,6 +44,7 @@ __pycache__

.coverage
.pytest_cache/
+.benchmarks

# documentation build artifacts

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

+- Fixed a lot of instances where tensors were first created and then sent to a device
+  with `.to(device)`. Instead, these tensors are now created directly on the target device.
- Fixed issue with `GradientDescentTrainer` when constructed with `validation_data_loader=None` and `learning_rate_scheduler!=None`.
- Fixed a bug when removing all handlers in root logger.
- `ShardedDatasetReader` now inherits parameters from `base_reader` when required.
@@ -58,7 +58,7 @@ def _register_forward_hook(self, stdev: float):
def forward_hook(module, inputs, output):
# Random noise = N(0, stdev * (max-min))
scale = output.detach().max() - output.detach().min()
-noise = torch.randn(output.shape).to(output.device) * stdev * scale
+noise = torch.randn(output.shape, device=output.device) * stdev * scale

# Add the random noise
output.add_(noise)
2 changes: 1 addition & 1 deletion allennlp/modules/sampled_softmax_loss.py
@@ -155,7 +155,7 @@ def forward(

if embeddings.shape[0] == 0:
# empty batch
-return torch.tensor(0.0).to(embeddings.device)
+return torch.tensor(0.0, device=embeddings.device)

if not self.training:
return self._forward_eval(embeddings, targets)
2 changes: 0 additions & 2 deletions allennlp/nn/util.py
@@ -1548,7 +1548,6 @@ def add_sentence_boundary_token_ids(
The new mask for the tensor, taking into account the appended tokens
marking the beginning and end of the sentence.
"""
-# TODO: matthewp, profile this transfer
sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
tensor_shape = list(tensor.data.shape)
new_shape = list(tensor_shape)
@@ -1603,7 +1602,6 @@ def remove_sentence_boundaries(
new_mask : `torch.BoolTensor`
The new mask for the tensor of shape `(batch_size, timesteps - 2)`.
"""
-# TODO: matthewp, profile this transfer
sequence_lengths = mask.sum(dim=1).detach().cpu().numpy()
tensor_shape = list(tensor.data.shape)
new_shape = list(tensor_shape)
4 changes: 2 additions & 2 deletions allennlp/training/metrics/attachment_scores.py
@@ -88,8 +88,8 @@ def __call__(  # type: ignore
dist.all_reduce(unlabeled_exact_match, op=dist.ReduceOp.SUM)
dist.all_reduce(correct_labels_and_indices, op=dist.ReduceOp.SUM)
dist.all_reduce(labeled_exact_match, op=dist.ReduceOp.SUM)
-total_sentences = torch.tensor(total_sentences).to(device)
-total_words = torch.tensor(total_words).to(device)
+total_sentences = torch.tensor(total_sentences, device=device)
+total_words = torch.tensor(total_words, device=device)
dist.all_reduce(total_sentences, op=dist.ReduceOp.SUM)
dist.all_reduce(total_words, op=dist.ReduceOp.SUM)
total_sentences = total_sentences.item()
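
Most of the metric changes below follow the same distributed-aggregation pattern as this one: counts accumulated as Python scalars must be wrapped in tensors before `dist.all_reduce` can sum them across workers, and with the NCCL backend those tensors have to live on the GPU. Creating them with `device=device` avoids a CPU allocation followed by a transfer. A minimal sketch of the pattern, assuming an already-initialized process group (setup omitted):

```python
import torch
import torch.distributed as dist

def aggregate_count(local_count: int) -> int:
    # NCCL reduces GPU tensors; gloo reduces CPU tensors.
    device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
    # Wrap the Python scalar directly on the reduction device.
    count = torch.tensor(local_count, device=device)
    # In-place sum of this worker's count with all others'.
    dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return count.item()
```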
4 changes: 2 additions & 2 deletions allennlp/training/metrics/average.py
@@ -32,8 +32,8 @@ def __call__(self, value):
_count = 1
if is_distributed():
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
-count = torch.tensor(_count).to(device)
-total_value = torch.tensor(_total_value).to(device)
+count = torch.tensor(_count, device=device)
+total_value = torch.tensor(_total_value, device=device)
dist.all_reduce(count, op=dist.ReduceOp.SUM)
dist.all_reduce(total_value, op=dist.ReduceOp.SUM)
_count = count.item()
8 changes: 4 additions & 4 deletions allennlp/training/metrics/bleu.py
@@ -127,8 +127,8 @@ def __call__(
predictions, gold_targets, ngram_size
)
if is_distributed():
-_precision_matches = torch.tensor(precision_matches).to(device)
-_precision_totals = torch.tensor(precision_totals).to(device)
+_precision_matches = torch.tensor(precision_matches, device=device)
+_precision_totals = torch.tensor(precision_totals, device=device)
dist.all_reduce(_precision_matches, op=dist.ReduceOp.SUM)
dist.all_reduce(_precision_totals, op=dist.ReduceOp.SUM)
precision_matches = _precision_matches.item() / world_size
@@ -150,8 +150,8 @@ def __call__(
_reference_lengths = valid_gold_targets_mask.sum().item()

if is_distributed():
-prediction_lengths = torch.tensor(_prediction_lengths).to(device)
-reference_lengths = torch.tensor(_reference_lengths).to(device)
+prediction_lengths = torch.tensor(_prediction_lengths, device=device)
+reference_lengths = torch.tensor(_reference_lengths, device=device)
dist.all_reduce(prediction_lengths, op=dist.ReduceOp.SUM)
dist.all_reduce(reference_lengths, op=dist.ReduceOp.SUM)
_prediction_lengths = prediction_lengths.item()
8 changes: 4 additions & 4 deletions allennlp/training/metrics/covariance.py
@@ -111,10 +111,10 @@ def __call__(

# # Note: this gives an approximate aggregation of the covariance.
# device = gold_labels.device
-# delta_mean_prediction = torch.tensor(delta_mean_prediction).to(device)
-# delta_mean_label = torch.tensor(delta_mean_label).to(device)
-# delta_co_moment = torch.tensor(delta_co_moment).to(device)
-# _total_count = torch.tensor(updated_count).to(device)
+# delta_mean_prediction = torch.tensor(delta_mean_prediction, device=device)
+# delta_mean_label = torch.tensor(delta_mean_label, device=device)
+# delta_co_moment = torch.tensor(delta_co_moment, device=device)
+# _total_count = torch.tensor(updated_count, device=device)
# dist.all_reduce(delta_mean_prediction, op=dist.ReduceOp.SUM)
# dist.all_reduce(delta_mean_label, op=dist.ReduceOp.SUM)
# dist.all_reduce(delta_co_moment, op=dist.ReduceOp.SUM)
2 changes: 1 addition & 1 deletion allennlp/training/metrics/entropy.py
@@ -43,7 +43,7 @@ def __call__(
_count = 1

if is_distributed():
-count = torch.tensor(_count).to(device)
+count = torch.tensor(_count, device=device)
dist.all_reduce(_entropy, op=dist.ReduceOp.SUM)
dist.all_reduce(count, op=dist.ReduceOp.SUM)
_count = count.item()
6 changes: 3 additions & 3 deletions allennlp/training/metrics/evalb_bracketing_scorer.py
@@ -155,9 +155,9 @@ def __call__(self, predicted_trees: List[Tree], gold_trees: List[Tree]) -> None:

if is_distributed():
device = torch.device("cuda" if dist.get_backend() == "nccl" else "cpu")
-correct_predicted_brackets = torch.tensor(_correct_predicted_brackets).to(device)
-predicted_brackets = torch.tensor(_predicted_brackets).to(device)
-gold_brackets = torch.tensor(_gold_brackets).to(device)
+correct_predicted_brackets = torch.tensor(_correct_predicted_brackets, device=device)
+predicted_brackets = torch.tensor(_predicted_brackets, device=device)
+gold_brackets = torch.tensor(_gold_brackets, device=device)
dist.all_reduce(correct_predicted_brackets, op=dist.ReduceOp.SUM)
dist.all_reduce(predicted_brackets, op=dist.ReduceOp.SUM)
dist.all_reduce(gold_brackets, op=dist.ReduceOp.SUM)
8 changes: 3 additions & 5 deletions allennlp/training/metrics/fbeta_measure.py
@@ -142,7 +142,7 @@ def __call__(
# Watch it:
# The total numbers of true positives under all _predicted_ classes are zeros.
if true_positives_bins.shape[0] == 0:
-true_positive_sum = torch.zeros(num_classes, device=predictions.device)
+true_positive_sum = torch.zeros(num_classes, device=device)
else:
true_positive_sum = torch.bincount(
true_positives_bins.long(), minlength=num_classes
@@ -154,7 +154,7 @@ def __call__(
if pred_bins.shape[0] != 0:
pred_sum = torch.bincount(pred_bins, minlength=num_classes).float()
else:
-pred_sum = torch.zeros(num_classes, device=predictions.device)
+pred_sum = torch.zeros(num_classes, device=device)

gold_labels_bins = gold_labels[mask].long()
if gold_labels.shape[0] != 0:
@@ -165,9 +165,7 @@ def __call__(
self._total_sum += mask.sum().to(torch.float)

if is_distributed():
-true_positive_sum = torch.tensor(true_positive_sum).to(device)
-pred_sum = torch.tensor(pred_sum).to(device)
-true_sum = torch.tensor(true_sum).to(device)
+true_positive_sum = torch.tensor(true_positive_sum, device=device)
dist.all_reduce(true_positive_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(pred_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(true_sum, op=dist.ReduceOp.SUM)
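
The asymmetry in the last hunk is deliberate: `pred_sum` and `true_sum` are already tensors on the right device (they come out of `torch.bincount` on device-resident inputs, visible in the earlier hunks), so re-wrapping them with `torch.tensor(...)` forced a needless copy; those two wrappers are dropped entirely rather than converted. A small illustration of the redundancy (the shape and class count are made up):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels = torch.randint(0, 10, (1000,), device=device)
pred_sum = torch.bincount(labels, minlength=10).float()  # already on `device`

# Redundant: torch.tensor() on an existing tensor makes a full copy
# (newer PyTorch versions warn and suggest clone().detach() instead).
copied = torch.tensor(pred_sum, device=device)

# Sufficient: the tensor can be used as-is, e.g. passed to dist.all_reduce.
ready = pred_sum
```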
6 changes: 3 additions & 3 deletions allennlp/training/metrics/fbeta_multi_label_measure.py
@@ -156,9 +156,9 @@ def __call__(
self._total_sum += mask.expand_as(gold_labels).sum().to(torch.float)

if is_distributed():
-true_positive_sum = torch.tensor(true_positive_sum).to(device)
-pred_sum = torch.tensor(pred_sum).to(device)
-true_sum = torch.tensor(true_sum).to(device)
+true_positive_sum = torch.tensor(true_positive_sum, device=device)
+pred_sum = torch.tensor(pred_sum, device=device)
+true_sum = torch.tensor(true_sum, device=device)
dist.all_reduce(true_positive_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(pred_sum, op=dist.ReduceOp.SUM)
dist.all_reduce(true_sum, op=dist.ReduceOp.SUM)
4 changes: 2 additions & 2 deletions allennlp/training/metrics/mean_absolute_error.py
@@ -47,8 +47,8 @@ def __call__(
_absolute_error = torch.sum(absolute_errors)

if is_distributed():
-absolute_error = torch.tensor(_absolute_error).to(device)
-total_count = torch.tensor(_total_count).to(device)
+absolute_error = torch.tensor(_absolute_error, device=device)
+total_count = torch.tensor(_total_count, device=device)
dist.all_reduce(absolute_error, op=dist.ReduceOp.SUM)
dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
_absolute_error = absolute_error.item()
10 changes: 5 additions & 5 deletions allennlp/training/metrics/rouge.py
@@ -113,7 +113,7 @@ def _get_rouge_l_score(

if is_distributed():
device = predicted_tokens.device
-_total_f1 = torch.tensor(total_f1).to(device)
+_total_f1 = torch.tensor(total_f1, device=device)
dist.all_reduce(_total_f1, op=dist.ReduceOp.SUM)
total_f1 = _total_f1.item()

@@ -162,9 +162,9 @@ def _get_rouge_n_stats(

if is_distributed():
device = predicted_tokens.device
-_total_recall = torch.tensor(total_recall).to(device)
-_total_precision = torch.tensor(total_precision).to(device)
-_total_f1 = torch.tensor(total_f1).to(device)
+_total_recall = torch.tensor(total_recall, device=device)
+_total_precision = torch.tensor(total_precision, device=device)
+_total_f1 = torch.tensor(total_f1, device=device)
dist.all_reduce(_total_recall, op=dist.ReduceOp.SUM)
dist.all_reduce(_total_precision, op=dist.ReduceOp.SUM)
dist.all_reduce(_total_f1, op=dist.ReduceOp.SUM)
@@ -209,7 +209,7 @@ def __call__(
sequence_count = len(predictions)
if is_distributed():
device = predictions.device
-_sequence_count = torch.tensor(sequence_count).to(device)
+_sequence_count = torch.tensor(sequence_count, device=device)
dist.all_reduce(_sequence_count, op=dist.ReduceOp.SUM)
sequence_count = _sequence_count.item()
self._total_sequence_count += sequence_count
4 changes: 2 additions & 2 deletions allennlp/training/metrics/sequence_accuracy.py
@@ -73,8 +73,8 @@ def __call__(
_correct_count = correct

if is_distributed():
-correct_count = torch.tensor(_correct_count).to(device)
-total_count = torch.tensor(_total_count).to(device)
+correct_count = torch.tensor(_correct_count, device=device)
+total_count = torch.tensor(_total_count, device=device)
dist.all_reduce(correct_count, op=dist.ReduceOp.SUM)
dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
_correct_count = correct_count.item()
4 changes: 2 additions & 2 deletions allennlp/training/metrics/unigram_recall.py
@@ -83,8 +83,8 @@ def __call__(
_total_count = predictions.size()[0]

if is_distributed():
-correct_count = torch.tensor(_correct_count).to(device)
-total_count = torch.tensor(_total_count).to(device)
+correct_count = torch.tensor(_correct_count, device=device)
+total_count = torch.tensor(_total_count, device=device)
dist.all_reduce(correct_count, op=dist.ReduceOp.SUM)
dist.all_reduce(total_count, op=dist.ReduceOp.SUM)
_correct_count = correct_count.item()
46 changes: 46 additions & 0 deletions benchmarks/nn/util_bench.py
@@ -0,0 +1,46 @@
+import torch
+
+from allennlp.nn import util
+from allennlp.common.testing import requires_gpu
+
+
+@requires_gpu
+def bench_add_sentence_boundary_token_ids(benchmark):
+    device = torch.device("cuda")
+    # shape: (32, 50)
+    tensor = torch.tensor([[3] * 50] * 32, device=device)
+    # shape: (32, 50)
+    mask = torch.tensor([[True] * 50, [True] * 30 + [False] * 20] * 16, device=device)
+    begin_token = 1
+    end_token = 2
+    benchmark(util.add_sentence_boundary_token_ids, tensor, mask, begin_token, end_token)
+
+
+@requires_gpu
+def bench_remove_sentence_boundaries(benchmark):
+    device = torch.device("cuda")
+    # shape: (32, 50, 1)
+    tensor = torch.tensor([[3] * 50] * 32, device=device).unsqueeze(-1)
+    # shape: (32, 50)
+    mask = torch.tensor([[True] * 50, [True] * 30 + [False] * 20] * 16, device=device)
+    benchmark(util.remove_sentence_boundaries, tensor, mask)
+
+
+@requires_gpu
+def bench_create_tensor_then_send_to_device(benchmark):
+    device = torch.device("cuda:0")
+
+    def create_tensor():
+        return torch.rand((32, 50)).to(device)
+
+    benchmark(create_tensor)
+
+
+@requires_gpu
+def bench_create_tensor_directly_on_device(benchmark):
+    device = torch.device("cuda:0")
+
+    def create_tensor():
+        return torch.rand((32, 50), device=device)
+
+    benchmark(create_tensor)
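
These functions use the `benchmark` fixture, which points to the pytest-benchmark plugin (the `.benchmarks` directory added to `.dockerignore` and `.gitignore` above is that plugin's default storage location). Assuming the plugin is installed, they should be runnable on a GPU machine with `pytest benchmarks/nn/util_bench.py`; the last two functions measure exactly the transfer-versus-direct-allocation difference this commit targets.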
2 changes: 2 additions & 0 deletions benchmarks/pytest.ini
@@ -6,3 +6,5 @@
python_files = *_bench.py
python_functions = bench_* *_bench
python_classes =
+markers =
+    gpu: marks tests that need at least one GPU
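
The new `gpu` marker pairs with the `requires_gpu` decorator imported by the benchmark file above. That decorator's definition is not part of this diff; a hypothetical sketch of how such a marker-plus-skip helper could be written (the wiring below is an assumption, not AllenNLP's actual code):

```python
import pytest
import torch

def requires_gpu(test_func):
    # Hypothetical: apply the `gpu` marker declared in pytest.ini, and
    # skip the test when no CUDA device is available.
    marked = pytest.mark.gpu(test_func)
    return pytest.mark.skipif(
        torch.cuda.device_count() < 1, reason="needs at least one GPU"
    )(marked)
```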
