
Commit 1bf72db
Merge pull request #1093 from pommedeterresautee/remove_cat_call
Reduce the number of concatenations for a 10% inference time reduction
Alan Akbik authored Sep 13, 2019
2 parents 25861e1 + e767cd1 commit 1bf72db
Showing 3 changed files with 33 additions and 8 deletions.
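
For a sentence of n tokens and k embedding types, the old code issued one torch.cat per token (roughly n + 1 tensor ops per sentence), while the new code issues one torch.stack per embedding type (k + 1 ops). A minimal, self-contained timing sketch (hypothetical sizes and loop counts, not code from this commit) that mimics the two patterns and checks they build the same tensor:

    import time
    import torch

    tokens, emb_types, dim = 40, 3, 256          # hypothetical sizes
    # one small tensor per (token, embedding type), as flair stores them
    embs = [[torch.randn(dim) for _ in range(emb_types)] for _ in range(tokens)]

    start = time.perf_counter()
    for _ in range(1000):
        # old pattern: one cat per token, plus one cat per sentence
        old = torch.cat([torch.cat(tok).unsqueeze(0) for tok in embs], 0)
    print(f"per-token cat:  {time.perf_counter() - start:.3f}s")

    start = time.perf_counter()
    for _ in range(1000):
        # new pattern: one stack per embedding type, then a single cat
        by_type = [torch.stack([tok[i] for tok in embs]) for i in range(emb_types)]
        new = torch.cat(by_type, dim=1)
    print(f"per-type stack: {time.perf_counter() - start:.3f}s")

    assert torch.equal(old, new)   # both orderings produce the same sentence tensor
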
11 changes: 9 additions & 2 deletions flair/data.py
@@ -237,7 +237,9 @@ def set_embedding(self, name: str, vector: torch.tensor):
         device = flair.device
         if len(self._embeddings.keys()) > 0:
             device = next(iter(self._embeddings.values())).device
-        self._embeddings[name] = vector.to(device)
+        if device != vector.device:
+            vector = vector.to(device)
+        self._embeddings[name] = vector

     def to(self, device: str, pin_memory: bool = False):
         for name, vector in self._embeddings.items():
@@ -257,6 +259,9 @@ def clear_embeddings(self, embedding_names: List[str] = None):
                 if name in self._embeddings.keys():
                     del self._embeddings[name]

+    def get_each_embedding(self) -> torch.tensor:
+        return [self._embeddings[embed] for embed in sorted(self._embeddings.keys())]
+
     def get_embedding(self) -> torch.tensor:
         embeddings = [
             self._embeddings[embed] for embed in sorted(self._embeddings.keys())
@@ -642,7 +647,9 @@ def set_embedding(self, name: str, vector: torch.tensor):
         device = flair.device
         if len(self._embeddings.keys()) > 0:
             device = next(iter(self._embeddings.values())).device
-        self._embeddings[name] = vector.to(device, non_blocking=True)
+        if device != vector.device:
+            vector = vector.to(device)
+        self._embeddings[name] = vector

     def get_embedding(self) -> torch.tensor:
         embeddings = []
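
The two data.py changes are easiest to see in isolation: set_embedding now skips the .to(device) call when the vector is already on the target device, and the new get_each_embedding returns the per-type embeddings without concatenating them. A stripped-down sketch (MiniToken is a hypothetical stand-in, not flair's Token class):

    import torch

    class MiniToken:
        # hypothetical stand-in for flair's Token embedding storage
        def __init__(self):
            self._embeddings = {}

        def set_embedding(self, name, vector, device=torch.device("cpu")):
            if len(self._embeddings) > 0:
                device = next(iter(self._embeddings.values())).device
            if device != vector.device:   # skip a redundant copy when already placed
                vector = vector.to(device)
            self._embeddings[name] = vector

        def get_each_embedding(self):
            # per-type list, sorted by name so concatenation order is stable
            return [self._embeddings[k] for k in sorted(self._embeddings)]

        def get_embedding(self):
            return torch.cat(self.get_each_embedding())

    tok = MiniToken()
    tok.set_embedding("glove", torch.randn(100))
    tok.set_embedding("flair-forward", torch.randn(2048))
    print(len(tok.get_each_embedding()), tok.get_embedding().shape)  # 2 torch.Size([2148])
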
15 changes: 12 additions & 3 deletions flair/embeddings.py
@@ -2756,9 +2756,18 @@ def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):

         for s_id, sentence in enumerate(sentences):
             # fill values with word embeddings
-            sentence_tensor[s_id][: len(sentence)] = torch.cat(
-                [token.get_embedding().unsqueeze(0) for token in sentence], 0
-            )
+            all_embs = list()
+
+            for index_token, token in enumerate(sentence):
+                embs = token.get_each_embedding()
+                if not all_embs:
+                    all_embs = [list() for _ in range(len(embs))]
+                for index_emb, emb in enumerate(embs):
+                    all_embs[index_emb].append(emb)
+
+            concat_word_emb = [torch.stack(embs) for embs in all_embs]
+            concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
+            sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb

         # --------------------------------------------------------------------
         # FF PART
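
The grouping loop is worth tracing once with concrete shapes: each per-type stack is (sentence_length, d_i), and concatenating them on dim=1 recovers one row per token even when the embedding types have different widths. A standalone trace of that loop (hypothetical dimensions):

    import torch

    sentence_len, dims = 5, [100, 2048, 2048]    # e.g. GloVe plus two LM embeddings
    # token_embs[t] plays the role of token.get_each_embedding() for token t
    token_embs = [[torch.randn(d) for d in dims] for _ in range(sentence_len)]

    all_embs = []
    for embs in token_embs:
        if not all_embs:
            all_embs = [list() for _ in range(len(embs))]
        for index_emb, emb in enumerate(embs):
            all_embs[index_emb].append(emb)

    concat_word_emb = [torch.stack(e) for e in all_embs]      # k tensors of (5, d_i)
    concat_sentence_emb = torch.cat(concat_word_emb, dim=1)   # (5, sum(dims)) = (5, 4196)
    assert concat_sentence_emb.shape == (sentence_len, sum(dims))
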
15 changes: 12 additions & 3 deletions flair/models/sequence_tagger_model.py
@@ -457,9 +457,18 @@ def forward(self, sentences: List[Sentence]):
         )

         for s_id, sentence in enumerate(sentences):
-            # fill values with word embeddings
-            token_embeddings = [token.get_embedding() for token in sentence]
-            sentence_tensor[s_id][: len(sentence)] = torch.stack(token_embeddings)
+            all_embs = list()
+
+            for index_token, token in enumerate(sentence):
+                embs = token.get_each_embedding()
+                if not all_embs:
+                    all_embs = [list() for _ in range(len(embs))]
+                for index_emb, emb in enumerate(embs):
+                    all_embs[index_emb].append(emb)
+
+            concat_word_emb = [torch.stack(embs) for embs in all_embs]
+            concat_sentence_emb = torch.cat(concat_word_emb, dim=1)
+            sentence_tensor[s_id][: len(sentence)] = concat_sentence_emb

         # --------------------------------------------------------------------
         # FF PART
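
The tagger's forward() duplicates the grouping loop from embeddings.py verbatim; a shared helper (a hypothetical refactor, not part of this commit) would keep the two call sites in sync:

    import torch

    def stack_token_embeddings(sentence):
        # group each token's per-type embeddings, then stack once per type
        all_embs = []
        for token in sentence:
            embs = token.get_each_embedding()
            if not all_embs:
                all_embs = [list() for _ in range(len(embs))]
            for i, emb in enumerate(embs):
                all_embs[i].append(emb)
        return torch.cat([torch.stack(e) for e in all_embs], dim=1)

    # usage at either call site:
    #     sentence_tensor[s_id][: len(sentence)] = stack_token_embeddings(sentence)
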
