This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
* Max-pooling Extractor
* Changes according to coding conventions.
* Edit Changelog MaxPoolingExtractor
* Corrected Spelling

Co-authored-by: lk4239e <[email protected]>
Co-authored-by: Pete <[email protected]>
1 parent 515fe9b
commit a3d7125
Showing 4 changed files with 317 additions and 0 deletions.
allennlp/modules/span_extractors/max_pooling_span_extractor.py (131 additions, 0 deletions)
@@ -0,0 +1,131 @@
import torch

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
    SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util
from allennlp.nn.util import masked_max


@SpanExtractor.register("max_pooling")
class MaxPoolingSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
    """
    Represents spans through the application of a dimension-wise max-pooling operation.
    Given a span x_i, ..., x_j with i, j as span_start and span_end, each dimension d
    of the resulting span s is computed via s_d = max(x_id, ..., x_jd).
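    For example, max-pooling over a span covering the vectors [1, 4] and [3, 2]
    yields the representation [3, 4].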
    Elements masked out by sequence_mask are ignored when the max-pooling is computed.
    Span representations of spans masked out by span_indices_mask are set to 0.
    Registered as a `SpanExtractor` with name "max_pooling".
    # Parameters
    input_dim : `int`, required.
        The final dimension of the `sequence_tensor`.
    num_width_embeddings : `int`, optional (default = `None`).
        Specifies the number of buckets to use when representing
        span width features.
    span_width_embedding_dim : `int`, optional (default = `None`).
        The embedding size for the span_width features.
    bucket_widths : `bool`, optional (default = `False`).
        Whether to bucket the span widths into log-space buckets. If `False`,
        the raw span widths are used.
    # Returns
    max_pooling_text_embeddings : `torch.FloatTensor`.
        A tensor of shape (batch_size, num_spans, input_dim), where each span representation
        is the result of a max-pooling operation.
    """

    def __init__(
        self,
        input_dim: int,
        num_width_embeddings: int = None,
        span_width_embedding_dim: int = None,
        bucket_widths: bool = False,
    ) -> None:
        super().__init__(
            input_dim=input_dim,
            num_width_embeddings=num_width_embeddings,
            span_width_embedding_dim=span_width_embedding_dim,
            bucket_widths=bucket_widths,
        )

    def get_output_dim(self) -> int:
        if self._span_width_embedding is not None:
            return self._input_dim + self._span_width_embedding.get_output_dim()
        return self._input_dim

    def _embed_spans(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:

        if sequence_tensor.size(-1) != self._input_dim:
            raise ValueError(
                f"Dimension mismatch expected ({self._input_dim}) "
                f"received ({sequence_tensor.size(-1)})."
            )

        if sequence_tensor.shape[1] <= span_indices.max() or span_indices.min() < 0:
            raise IndexError(
                f"Span index out of range, max index ({span_indices.max()}) "
                f"or min index ({span_indices.min()}) "
                f"not valid for sequence of length ({sequence_tensor.shape[1]})."
            )

        if (span_indices[:, :, 0] > span_indices[:, :, 1]).any():
            raise IndexError(
                "Span start above span end",
            )

        # Calculate the maximum sequence length for each element in the batch.
        # If span_end indices are above this length, we adjust them in adapted_span_indices.
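        # For example (illustrative values), with sequence_lengths[b] == 4 a span (2, 7)
        # is adjusted to (2, 3) before pooling.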
        if sequence_mask is not None:
            # shape (batch_size)
            sequence_lengths = util.get_lengths_from_binary_sequence_mask(sequence_mask)
        else:
            # shape (batch_size), filled with the sequence length of sequence_tensor.
            sequence_lengths = torch.ones_like(
                sequence_tensor[:, 0, 0], dtype=torch.long
            ) * sequence_tensor.size(1)

        adapted_span_indices = torch.tensor(span_indices, device=span_indices.device)

        for b in range(sequence_lengths.shape[0]):
            adapted_span_indices[b, :, 1][adapted_span_indices[b, :, 1] >= sequence_lengths[b]] = (
                sequence_lengths[b] - 1
            )

        # Raise an error if span indices were completely masked out by the sequence mask.
        # We only adjust span_end to the last valid index, so if span_end is below span_start,
        # both were above the max index:
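        # For example (illustrative values), with only 4 unmasked positions a span (6, 8)
        # becomes (6, 3), so span_start > span_end and the error below is raised.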

        if (adapted_span_indices[:, :, 0] > adapted_span_indices[:, :, 1]).any():
            raise IndexError(
                "Span indices were masked out entirely by sequence mask",
            )

        # span_vals <- (batch x num_spans x max_span_length x dim)
        span_vals, span_mask = util.batched_span_select(sequence_tensor, adapted_span_indices)

        # The application of masked_max requires a mask of the same shape as span_vals.
        # We repeat the mask along the last dimension (the embedding dimension).
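        # For example (illustrative shapes), a span_mask of shape (2, 3, 5) is expanded
        # to (2, 3, 5, input_dim) so that it lines up with span_vals.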
        repeat_dim = len(span_vals.shape) - 1
        repeat_idx = [1] * repeat_dim + [span_vals.shape[-1]]

        # ext_span_mask <- (batch x num_spans x max_span_length x dim)
        # ext_span_mask is True for values in the span, False for masked-out values
        ext_span_mask = span_mask.unsqueeze(repeat_dim).repeat(repeat_idx)

        # max_out <- (batch x num_spans x dim)
        max_out = masked_max(span_vals, ext_span_mask, dim=-2)

        return max_out
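A minimal usage sketch for the new extractor (not part of the commit; the shapes and values below are illustrative assumptions):

    import torch
    from allennlp.modules.span_extractors.max_pooling_span_extractor import MaxPoolingSpanExtractor

    extractor = MaxPoolingSpanExtractor(input_dim=4)
    sequence_tensor = torch.randn(1, 5, 4)  # (batch_size, sequence_length, input_dim)
    # Two spans per batch element, given as inclusive (start, end) token indices.
    span_indices = torch.LongTensor([[[0, 2], [3, 4]]])
    span_embeddings = extractor(sequence_tensor, span_indices)
    print(span_embeddings.shape)  # torch.Size([1, 2, 4])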
tests/modules/span_extractors/max_pooling_span_extractor_test.py (184 additions, 0 deletions)
@@ -0,0 +1,184 @@
import pytest
import torch

from allennlp.common.params import Params
from allennlp.modules.span_extractors import SpanExtractor
from allennlp.modules.span_extractors.max_pooling_span_extractor import MaxPoolingSpanExtractor


class TestMaxPoolingSpanExtractor:
    def test_locally_span_extractor_can_build_from_params(self):
        params = Params(
            {
                "type": "max_pooling",
                "input_dim": 3,
                "num_width_embeddings": 5,
                "span_width_embedding_dim": 3,
            }
        )
        extractor = SpanExtractor.from_params(params)
        assert isinstance(extractor, MaxPoolingSpanExtractor)
        assert extractor.get_output_dim() == 6

    def test_max_values_extracted(self):
        # Test if max-pooling is applied correctly.
        # We use a high-dimensional random vector and assume that a coincidentally correct result is too unlikely.
        sequence_tensor = torch.randn([2, 10, 30])
        extractor = MaxPoolingSpanExtractor(30)

        indices = torch.LongTensor([[[1, 1], [2, 4], [9, 9]], [[0, 1], [4, 4], [0, 9]]])
        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 3, 30]
        assert extractor.get_output_dim() == 30
        assert extractor.get_input_dim() == 30

        # We iterate over the tensor to compare the span extractor's results
        # with the result of Python's max operation over each dimension for each span and each batch.
        # For each batch
        for batch, X in enumerate(indices):
            # For each defined span index
            for indices_ind, span_def in enumerate(X):

                # original features of the currently tested span
                # span_width x embedding dim (30)
                span_features_complete = sequence_tensor[batch][span_def[0] : span_def[1] + 1]

                # comparison for each dimension
                for i in range(extractor.get_output_dim()):
                    # get the features for dimension i of the current span
                    features_from_span = span_features_complete[:, i]
                    real_max_value = max(features_from_span)

                    extracted_max_value = span_representations[batch, indices_ind, i]

                    assert real_max_value == extracted_max_value, (
                        f"Error extracting max value for "
                        f"batch {batch}, span {indices_ind} on dimension {i}. "
                        f"Expected {real_max_value} "
                        f"but got {extracted_max_value}, which is "
                        f"not the maximum element."
                    )

    def test_sequence_mask_correct_excluded(self):
        # Check that positions masked out by the sequence mask are ignored when computing
        # the span representations. For this test span_start is valid, but span_end is masked out.

        sequence_tensor = torch.randn([2, 6, 30])

        extractor = MaxPoolingSpanExtractor(30)
        indices = torch.LongTensor([[[1, 1], [3, 5], [2, 5]], [[0, 0], [0, 3], [4, 5]]])
        # define the sequence mask
        seq_mask = torch.BoolTensor([[True] * 4 + [False] * 2, [True] * 5 + [False] * 1])

        span_representations = extractor(sequence_tensor, indices, sequence_mask=seq_mask)

        # After computing the representations we set the masked-out values to -inf
        # to compute the "real" max-pooling with Python's max function.
        sequence_tensor[torch.logical_not(seq_mask)] = float("-inf")

        # Comparison is similar to test_max_values_extracted
        for batch, X in enumerate(indices):
            for indices_ind, span_def in enumerate(X):

                span_features_complete = sequence_tensor[batch][span_def[0] : span_def[1] + 1]

                for i in range(extractor.get_output_dim()):
                    features_from_span = span_features_complete[:, i]
                    real_max_value = max(features_from_span)
                    extracted_max_value = span_representations[batch, indices_ind, i]

                    assert real_max_value == extracted_max_value, (
                        f"Error extracting max value for "
                        f"batch {batch}, span {indices_ind} on dimension {i}. "
                        f"Expected {real_max_value} "
                        f"but got {extracted_max_value}, which is "
                        f"not the maximum element."
                    )

    def test_span_mask_correct_excluded(self):
        # Representations of spans masked out by span_indices_mask should be 0.

        sequence_tensor = torch.randn([2, 6, 10])

        extractor = MaxPoolingSpanExtractor(10)
        indices = torch.LongTensor([[[1, 1], [3, 5], [2, 5]], [[0, 0], [0, 3], [4, 5]]])

        span_mask = torch.BoolTensor([[True] * 3, [False] * 3])

        span_representations = extractor(
            sequence_tensor,
            indices,
            span_indices_mask=span_mask,
        )

        # The span mask masks out all indices in the last batch.
        # We check whether all span representations for this batch are 0.
        X = indices[-1]
        batch = -1
        for indices_ind, _ in enumerate(X):

            for i in range(extractor.get_output_dim()):
                real_max_value = torch.FloatTensor([0.0])
                extracted_max_value = span_representations[batch, indices_ind, i]

                assert real_max_value == extracted_max_value, (
                    f"Error extracting max value for "
                    f"batch {batch}, span {indices_ind} on dimension {i}. "
                    f"Expected {real_max_value} "
                    f"but got {extracted_max_value}, which is "
                    f"not the masked-out value 0."
                )

    def test_inconsistent_extractor_dimension_throws_exception(self):

        sequence_tensor = torch.randn([2, 6, 10])
        indices = torch.LongTensor([[[1, 1], [2, 4], [9, 9]], [[0, 1], [4, 4], [0, 9]]])

        with pytest.raises(ValueError):
            extractor = MaxPoolingSpanExtractor(9)
            extractor(sequence_tensor, indices)

        with pytest.raises(ValueError):
            extractor = MaxPoolingSpanExtractor(11)
            extractor(sequence_tensor, indices)

    def test_span_indices_outside_sequence(self):

        sequence_tensor = torch.randn([2, 6, 10])
        indices = torch.LongTensor([[[6, 6], [2, 4]], [[0, 1], [4, 4]]])

        with pytest.raises(IndexError):
            extractor = MaxPoolingSpanExtractor(10)
            extractor(sequence_tensor, indices)

        indices = torch.LongTensor([[[5, 6], [2, 4]], [[0, 1], [4, 4]]])

        with pytest.raises(IndexError):
            extractor = MaxPoolingSpanExtractor(10)
            extractor(sequence_tensor, indices)

        indices = torch.LongTensor([[[-1, 0], [2, 4]], [[0, 1], [4, 4]]])

        with pytest.raises(IndexError):
            extractor = MaxPoolingSpanExtractor(10)
            extractor(sequence_tensor, indices)

    def test_span_start_below_span_end(self):

        sequence_tensor = torch.randn([2, 6, 10])
        indices = torch.LongTensor([[[4, 2], [2, 4], [1, 1]], [[0, 1], [4, 4], [1, 1]]])
        with pytest.raises(IndexError):
            extractor = MaxPoolingSpanExtractor(10)
            extractor(sequence_tensor, indices)

    def test_span_sequence_complete_masked(self):

        sequence_tensor = torch.randn([2, 6, 10])
        seq_mask = torch.BoolTensor([[True] * 2 + [False] * 4, [True] * 3 + [False] * 3])
        indices = torch.LongTensor([[[5, 5]], [[4, 5]]])
        with pytest.raises(IndexError):
            extractor = MaxPoolingSpanExtractor(10)
            extractor(sequence_tensor, indices, sequence_mask=seq_mask)