Skip to content

Commit

Permalink
Merge pull request #1095 from zalandoresearch/clone-optimization
Browse files Browse the repository at this point in the history
GH-1089: Optimize FlairEmbeddings depending on storage mode
  • Loading branch information
Alan Akbik authored Sep 13, 2019
2 parents 1bf72db + 1edf79a commit b9c0e44
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 17 deletions.
3 changes: 3 additions & 0 deletions flair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
else:
device = torch.device("cpu")

# global variable: embedding_storage_mode
embedding_storage_mode = "default"

from . import data
from . import models
from . import visual
Expand Down
6 changes: 5 additions & 1 deletion flair/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 +1878,11 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
if not self.fine_tune:
embedding = embedding.detach()

token.set_embedding(self.name, embedding.clone())
# only clone if the embedding storage mode is 'gpu'
if flair.embedding_storage_mode == "gpu":
embedding = embedding.clone()

token.set_embedding(self.name, embedding)

all_hidden_states_in_lm = all_hidden_states_in_lm.detach()
del all_hidden_states_in_lm
Expand Down
4 changes: 2 additions & 2 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def evaluate(
self,
data_loader: DataLoader,
out_path: Path = None,
embeddings_storage_mode: str = "none",
embedding_storage_mode: str = "none",
) -> (Result, float):

with torch.no_grad():
Expand Down Expand Up @@ -320,7 +320,7 @@ def evaluate(
else:
metric.add_tn(tag)

store_embeddings(batch, embeddings_storage_mode)
store_embeddings(batch, embedding_storage_mode)

eval_loss /= batch_no

Expand Down
6 changes: 3 additions & 3 deletions flair/models/similarity_learning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def evaluate(
self,
data_loader: DataLoader,
out_path: Path = None,
embeddings_storage_mode="none",
embedding_storage_mode="none",
) -> (Result, float):
# assumes that for each data pair there's at least one embedding per modality

Expand All @@ -281,7 +281,7 @@ def evaluate(
all_target_embeddings.append(
self._embed_target(target_inputs).to(self.eval_device)
)
store_embeddings(data_points, embeddings_storage_mode)
store_embeddings(data_points, embedding_storage_mode)
all_target_embeddings = torch.cat(all_target_embeddings, dim=0) # [n0, d0]
assert len(target_index) == all_target_embeddings.shape[0]

Expand Down Expand Up @@ -315,7 +315,7 @@ def evaluate(
]
ranks.extend(batch_gt_ranks.tolist())

store_embeddings(data_points, embeddings_storage_mode)
store_embeddings(data_points, embedding_storage_mode)

ranks = np.array(ranks)
median_rank = np.median(ranks)
Expand Down
4 changes: 2 additions & 2 deletions flair/models/text_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def evaluate(
self,
data_loader: DataLoader,
out_path: Path = None,
embeddings_storage_mode: str = "none",
embedding_storage_mode: str = "none",
) -> (Result, float):

with torch.no_grad():
Expand Down Expand Up @@ -238,7 +238,7 @@ def evaluate(
):
metric.add_tn(label)

store_embeddings(batch, embeddings_storage_mode)
store_embeddings(batch, embedding_storage_mode)

eval_loss /= batch_count

Expand Down
4 changes: 2 additions & 2 deletions flair/models/text_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def evaluate(
self,
data_loader: DataLoader,
out_path: Path = None,
embeddings_storage_mode: str = "none",
embedding_storage_mode: str = "none",
) -> (Result, float):

with torch.no_grad():
Expand Down Expand Up @@ -137,7 +137,7 @@ def evaluate(
)
lines.append(eval_line)

store_embeddings(batch, embeddings_storage_mode)
store_embeddings(batch, embedding_storage_mode)

eval_loss /= total_count

Expand Down
4 changes: 2 additions & 2 deletions flair/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def evaluate(
self,
data_loader: DataLoader,
out_path: Path = None,
embeddings_storage_mode: str = "none",
embedding_storage_mode: str = "none",
) -> (Result, float):
"""Evaluates the model. Returns a Result object containing evaluation
results and a loss value. Implement this to enable evaluation.
:param data_loader: DataLoader that iterates over dataset to be evaluated
:param out_path: Optional output path to store predictions
:param embeddings_storage_mode: One of 'none', 'cpu' or 'gpu'. 'none' means all embeddings are deleted and
:param embedding_storage_mode: One of 'none', 'cpu' or 'gpu'. 'none' means all embeddings are deleted and
freshly recomputed, 'cpu' means all embeddings are stored on CPU, or 'gpu' means all embeddings are stored on GPU
:return: Returns a Tuple consisting of a Result object and a loss float value
"""
Expand Down
10 changes: 5 additions & 5 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def train(
batch_size=eval_mini_batch_size,
num_workers=num_workers,
),
embeddings_storage_mode=embeddings_storage_mode,
embedding_storage_mode=embeddings_storage_mode,
)
result_line += f"\t{train_eval_result.log_line}"

Expand All @@ -341,7 +341,7 @@ def train(
batch_size=eval_mini_batch_size,
num_workers=num_workers,
),
embeddings_storage_mode=embeddings_storage_mode,
embedding_storage_mode=embeddings_storage_mode,
)
result_line += f"\t{dev_loss}\t{dev_eval_result.log_line}"
log.info(
Expand Down Expand Up @@ -371,7 +371,7 @@ def train(
num_workers=num_workers,
),
base_path / "test.tsv",
embeddings_storage_mode=embeddings_storage_mode,
embedding_storage_mode=embeddings_storage_mode,
)
result_line += f"\t{test_loss}\t{test_eval_result.log_line}"
log.info(
Expand Down Expand Up @@ -511,7 +511,7 @@ def final_test(
num_workers=num_workers,
),
out_path=base_path / "test.tsv",
embeddings_storage_mode="none",
embedding_storage_mode="none",
)

test_results: Result = test_results
Expand All @@ -530,7 +530,7 @@ def final_test(
num_workers=num_workers,
),
out_path=base_path / f"{subcorpus.name}-test.tsv",
embeddings_storage_mode="none",
embedding_storage_mode="none",
)

# get and return the final test score of best model
Expand Down
3 changes: 3 additions & 0 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,6 @@ def store_embeddings(sentences: List[Sentence], storage_mode: str):
pin_memory = False if str(flair.device) == "cpu" else True
for sentence in sentences:
sentence.to("cpu", pin_memory=pin_memory)

# record current embedding storage mode to allow optimization (for instance in FlairEmbeddings class)
flair.embedding_storage_mode = storage_mode

0 comments on commit b9c0e44

Please sign in to comment.