Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Adds a "duplicate()" method on instances and fields (#4294)
Browse files Browse the repository at this point in the history
* modify behavior of deepcopy on TextField

* update CHANGELOG

* make a little more robust

* add 'duplicate' method

* update CHANGELOG

* add a test
  • Loading branch information
epwalsh authored May 27, 2020
1 parent 8ff47d3 commit 79999ec
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 6 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Fixed
- Nothing yet

- A bug where `TextField`s could not be duplicated since some tokenizers cannot be deep-copied.
See https://github.com/allenai/allennlp/issues/4270.

### Added
- Nothing yet

- A `duplicate()` method on `Instance`s and `Field`s, to be used instead of `copy.deepcopy()`.

### Changed

- Nothing yet

## [v1.0.0rc5](https://github.com/allenai/allennlp/releases/tag/v1.0.0rc5) - 2020-05-26
Expand Down
4 changes: 4 additions & 0 deletions allennlp/data/fields/field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Dict, Generic, List, TypeVar

import torch
Expand Down Expand Up @@ -120,3 +121,6 @@ def __eq__(self, other) -> bool:

def __len__(self):
    """
    Placeholder for the field's length; this version always raises
    `NotImplementedError`.
    """
    raise NotImplementedError

def duplicate(self):
    """
    Return a deep copy of this object, produced with `copy.deepcopy`.
    """
    clone = deepcopy(self)
    return clone
15 changes: 15 additions & 0 deletions allennlp/data/fields/text_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
standard word vectors, or pass through an LSTM.
"""
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List, Optional, Iterator
import textwrap

Expand Down Expand Up @@ -153,3 +154,17 @@ def __getitem__(self, idx: int) -> Token:

def __len__(self) -> int:
    """
    Return the number of entries in `self.tokens`.
    """
    token_count = len(self.tokens)
    return token_count

@overrides
def duplicate(self):
    """
    Return a copy of this field that shares its token indexers with the original.

    The tokens and the already-indexed tokens are deep-copied, while the
    entries of `self._token_indexers` are placed into a new dictionary without
    being copied themselves. Deep-copying the token indexers would be extremely
    inefficient, and it also fails in many cases since some tokenizers (like
    those used in the 'transformers' lib) cannot actually be deep-copied.
    """
    copied_tokens = deepcopy(self.tokens)
    # A new mapping object that points at the very same indexer instances.
    shared_indexers = dict(self._token_indexers)
    field_copy = TextField(copied_tokens, shared_indexers)
    field_copy._indexed_tokens = deepcopy(self._indexed_tokens)
    return field_copy
5 changes: 5 additions & 0 deletions allennlp/data/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,8 @@ def __str__(self) -> str:
return " ".join(
[base_string] + [f"\t {name}: {field} \n" for name, field in self.fields.items()]
)

def duplicate(self) -> "Instance":
    """
    Build a new `Instance` whose fields are duplicates of this instance's fields.

    Every field is copied by calling its own `duplicate()` method, and the
    `indexed` attribute of this instance is carried over to the new one.
    """
    copied_fields = {}
    for name, field in self.fields.items():
        copied_fields[name] = field.duplicate()

    clone = Instance(copied_fields)
    clone.indexed = self.indexed
    return clone
3 changes: 1 addition & 2 deletions allennlp/predictors/sentence_tagger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List, Dict
from copy import deepcopy

from overrides import overrides
import numpy
Expand Down Expand Up @@ -105,7 +104,7 @@ def predictions_to_labeled_instances(
# Creates a new instance for each contiguous tag
instances = []
for labels in predicted_spans:
new_instance = deepcopy(instance)
new_instance = instance.duplicate()
text_field: TextField = instance["tokens"] # type: ignore
new_instance.add_field(
"tags", SequenceLabelField(labels, text_field), self._model.vocab
Expand Down
3 changes: 1 addition & 2 deletions allennlp/predictors/text_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
from typing import List, Dict

from overrides import overrides
Expand Down Expand Up @@ -42,7 +41,7 @@ def _json_to_instance(self, json_dict: JsonDict) -> Instance:
def predictions_to_labeled_instances(
self, instance: Instance, outputs: Dict[str, numpy.ndarray]
) -> List[Instance]:
new_instance = deepcopy(instance)
new_instance = instance.duplicate()
label = numpy.argmax(outputs["probs"])
new_instance.add_field("label", LabelField(int(label), skip_indexing=True))
return [new_instance]
20 changes: 20 additions & 0 deletions tests/data/instance_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from allennlp.common.testing import AllenNlpTestCase
from allennlp.data import Instance
from allennlp.data.fields import TextField, LabelField
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import Token


Expand All @@ -20,3 +21,22 @@ def test_instance_implements_mutable_mapping(self):
values = [v for k, v in instance.items()]
assert words_field in values
assert label_field in values

def test_duplicate(self):
    # Check that `duplicate()` works when a `TextField` holds a
    # `PretrainedTransformerIndexer`.
    # See https://github.com/allenai/allennlp/issues/4270.
    tokens = [Token("hello")]
    indexers = {"tokens": PretrainedTransformerIndexer("bert-base-uncased")}
    original = Instance({"words": TextField(tokens, indexers)})

    copy_of_original = original.duplicate()
    assert copy_of_original == original

    # Adding new fields to the original instance should not affect the duplicate.
    original.add_field("labels", LabelField("some_label"))
    assert "labels" not in copy_of_original.fields
    assert copy_of_original != original  # sanity check on the '__eq__' method.

0 comments on commit 79999ec

Please sign in to comment.