Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Adds a "duplicate()" method on instances and fields #4294

Merged
merged 6 commits into from
May 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Fixed
- Nothing yet

- A bug where `TextField`s could not be duplicated since some tokenizers cannot be deep-copied.
See https://github.com/allenai/allennlp/issues/4270.

### Added
- Nothing yet

- A `duplicate()` method on `Instance`s and `Field`s, to be used instead of `copy.deepcopy()`.

### Changed

- Nothing yet

## [v1.0.0rc5](https://github.com/allenai/allennlp/releases/tag/v1.0.0rc5) - 2020-05-26
Expand Down
4 changes: 4 additions & 0 deletions allennlp/data/fields/field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Dict, Generic, List, TypeVar

import torch
Expand Down Expand Up @@ -120,3 +121,6 @@ def __eq__(self, other) -> bool:

def __len__(self):
raise NotImplementedError

def duplicate(self):
    """
    Return a copy of this object produced with `copy.deepcopy`.

    This is the default copying strategy; the `TextField` override of
    `duplicate` shown in this change replaces it where a full deep copy
    is not appropriate.
    """
    clone = deepcopy(self)
    return clone
15 changes: 15 additions & 0 deletions allennlp/data/fields/text_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
standard word vectors, or pass through an LSTM.
"""
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List, Optional, Iterator
import textwrap

Expand Down Expand Up @@ -153,3 +154,17 @@ def __getitem__(self, idx: int) -> Token:

def __len__(self) -> int:
return len(self.tokens)

@overrides
def duplicate(self):
    """
    Return a copy of this field that shares its token indexers with the original.

    The tokens and `_indexed_tokens` are deep-copied, whereas the token indexers
    are only placed in a fresh dictionary, so the new mapping refers to the very
    same indexer objects.

    Not only would it be extremely inefficient to deep-copy the token indexers,
    but it also fails in many cases since some tokenizers (like those used in
    the 'transformers' lib) cannot actually be deep-copied.
    """
    copied_tokens = deepcopy(self.tokens)
    # Shallow copy on purpose: new dict, same indexer instances.
    shared_indexers = dict(self._token_indexers)
    clone = TextField(copied_tokens, shared_indexers)
    clone._indexed_tokens = deepcopy(self._indexed_tokens)
    return clone
5 changes: 5 additions & 0 deletions allennlp/data/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,8 @@ def __str__(self) -> str:
return " ".join(
[base_string] + [f"\t {name}: {field} \n" for name, field in self.fields.items()]
)

def duplicate(self) -> "Instance":
    """
    Return a new `Instance` whose fields are duplicates of this one's fields.

    Each field is copied through its own `duplicate()` method, and the
    `indexed` attribute of this instance is carried over to the copy.
    """
    copied_fields = {}
    for name, field in self.fields.items():
        copied_fields[name] = field.duplicate()
    clone = Instance(copied_fields)
    clone.indexed = self.indexed
    return clone
3 changes: 1 addition & 2 deletions allennlp/predictors/sentence_tagger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List, Dict
from copy import deepcopy

from overrides import overrides
import numpy
Expand Down Expand Up @@ -105,7 +104,7 @@ def predictions_to_labeled_instances(
# Creates a new instance for each contiguous tag
instances = []
for labels in predicted_spans:
new_instance = deepcopy(instance)
new_instance = instance.duplicate()
text_field: TextField = instance["tokens"] # type: ignore
new_instance.add_field(
"tags", SequenceLabelField(labels, text_field), self._model.vocab
Expand Down
3 changes: 1 addition & 2 deletions allennlp/predictors/text_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
from typing import List, Dict

from overrides import overrides
Expand Down Expand Up @@ -42,7 +41,7 @@ def _json_to_instance(self, json_dict: JsonDict) -> Instance:
def predictions_to_labeled_instances(
    self, instance: Instance, outputs: Dict[str, numpy.ndarray]
) -> List[Instance]:
    """
    Build a labeled copy of `instance` from the model outputs.

    The label is the index of the largest entry of `outputs["probs"]`. It is
    attached as a `LabelField` named "label" (with `skip_indexing=True`) to a
    duplicate of `instance`, not to `instance` itself.

    # Parameters

    instance : `Instance`
        The instance the predictions were made for.
    outputs : `Dict[str, numpy.ndarray]`
        Model outputs; only the "probs" entry is read here.

    # Returns

    `List[Instance]`
        A single-element list holding the labeled duplicate.
    """
    # Use `duplicate()` rather than `copy.deepcopy()`: the deep copy was redundant
    # (its result was immediately overwritten) and it fails for text fields whose
    # tokenizers cannot be deep-copied.
    new_instance = instance.duplicate()
    label = numpy.argmax(outputs["probs"])
    new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
    return [new_instance]
20 changes: 20 additions & 0 deletions tests/data/instance_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data import Instance
from allennlp.data.fields import TextField, LabelField
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import Token


Expand All @@ -20,3 +21,22 @@ def test_instance_implements_mutable_mapping(self):
values = [v for k, v in instance.items()]
assert words_field in values
assert label_field in values

def test_duplicate(self):
    """
    Check that `duplicate()` works when a `TextField` uses a
    `PretrainedTransformerIndexer`, and that the copy is independent of the
    original. See https://github.com/allenai/allennlp/issues/4270.
    """
    indexer = PretrainedTransformerIndexer("bert-base-uncased")
    words_field = TextField([Token("hello")], {"tokens": indexer})
    instance = Instance({"words": words_field})

    other = instance.duplicate()
    assert other == instance

    # Adding new fields to the original instance should not affect the duplicate.
    instance.add_field("labels", LabelField("some_label"))
    assert "labels" not in other.fields
    # Sanity check on the '__eq__' method.
    assert other != instance