GH-1492: clean up code
alanakbik committed Apr 1, 2020
1 parent e787065 commit 745cc4f
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions flair/embeddings.py
@@ -2210,13 +2210,16 @@ def __init__(
fine_tune: bool = False
):
"""
- Bidirectional transformer embeddings of words, as proposed in Devlin et al., 2018.
- :param bert_model_or_path: name of BERT model ('') or directory path containing custom model, configuration file
- and vocab file (names of three files should be - config.json, pytorch_model.bin/model.chkpt, vocab.txt)
- :param layers: string indicating which layers to take for embedding
- :param pooling_operation: how to get from token piece embeddings to token embedding. Either pool them and take
- the average ('mean') or use first word piece embedding as token embedding ('first)
- :param document_only: set only document (sentence) emebddings
+ Bidirectional transformer embeddings of words from various transformer architectures.
+ :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for
+ options)
+ :param layers: string indicating which layers to take for embedding (-1 is topmost layer)
+ :param pooling_operation: how to get from token piece embeddings to token embedding. Either take the first
+ subtoken ('first'), the last subtoken ('last'), both first and last ('first_last') or a mean over all ('mean')
+ :param batch_size: How many sentences to push through transformer at once. Set to 1 by default since transformer
+ models tend to be huge.
+ :param use_scalar_mix: If True, uses a scalar mix of layers as embedding
+ :param fine_tune: If True, allows transformers to be fine-tuned during training
"""
super().__init__()
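For context: the revised docstring above documents a word-level transformer embedding class whose name is not visible in this hunk. Below is a minimal usage sketch, assuming the class is flair's TransformerWordEmbeddings and that its constructor takes the parameters documented above; the model name and example sentence are illustrative, not taken from the commit.

from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings  # assumed class name

# instantiate with the parameters documented above (values are illustrative)
embedding = TransformerWordEmbeddings(
    model="bert-base-uncased",   # any model from the huggingface list
    layers="-1",                 # topmost layer only
    pooling_operation="first",   # use the first subtoken as the token embedding
    batch_size=1,                # one sentence per forward pass
    use_scalar_mix=False,
    fine_tune=False,
)

sentence = Sentence("Berlin is a city .")
embedding.embed(sentence)

# each token now carries a vector taken from the transformer
for token in sentence:
    print(token.text, token.embedding.shape)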

@@ -2254,22 +2257,20 @@ def __init__(
self.begin_offset = 0

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
"""Add embeddings to all words in a list of sentences. If embeddings are already added,
updates only if embeddings are non-static."""
"""Add embeddings to all words in a list of sentences."""

- # using list comprehension
+ # split into micro batches of size self.batch_size before pushing through transformer
sentence_batches = [sentences[i * self.batch_size:(i + 1) * self.batch_size]
for i in range((len(sentences) + self.batch_size - 1) // self.batch_size)]

+ # embed each micro-batch
for batch in sentence_batches:

self._add_embeddings_to_sentences(batch)

return sentences

def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
"""Add embeddings to all words in a list of sentences. If embeddings are already added,
updates only if embeddings are non-static."""
"""Match subtokenization to Flair tokenization and extract embeddings from transformers for each token."""

# first, subtokenize each sentence and find out into how many subtokens each token was divided
subtokenized_sentences = []
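The micro-batching comprehension in this hunk splits the sentence list into chunks of self.batch_size using ceiling division, so a trailing partial batch is kept. A standalone sketch of the same pattern (the helper name and example values are illustrative, not from the commit):

from typing import List

def make_batches(items: List, batch_size: int) -> List[List]:
    # ceiling division yields (len(items) + batch_size - 1) // batch_size batches
    n_batches = (len(items) + batch_size - 1) // batch_size
    return [items[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]

# e.g. 5 items with batch_size=2 -> [[a, b], [c, d], [e]]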
@@ -2423,13 +2424,14 @@ def __init__(
use_scalar_mix: bool = False,
):
"""
- Bidirectional transformer embeddings of words, as proposed in Devlin et al., 2018.
- :param bert_model_or_path: name of BERT model ('') or directory path containing custom model, configuration file
- and vocab file (names of three files should be - config.json, pytorch_model.bin/model.chkpt, vocab.txt)
- :param layers: string indicating which layers to take for embedding
- :param pooling_operation: how to get from token piece embeddings to token embedding. Either pool them and take
- the average ('mean') or use first word piece embedding as token embedding ('first)
- :param document_only: set only document (sentence) emebddings
+ Bidirectional transformer embeddings of words from various transformer architectures.
+ :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for
+ options)
+ :param fine_tune: If True, allows transformers to be fine-tuned during training
+ :param batch_size: How many sentences to push through transformer at once. Set to 1 by default since transformer
+ models tend to be huge.
+ :param layers: string indicating which layers to take for embedding (-1 is topmost layer)
+ :param use_scalar_mix: If True, uses a scalar mix of layers as embedding
"""
super().__init__()
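The second docstring above documents a document-level variant; its class name is likewise not shown in the hunk. A hedged usage sketch, assuming the class is flair's TransformerDocumentEmbeddings and that it accepts the parameters listed above; model name and sentence are illustrative.

from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings  # assumed class name

document_embedding = TransformerDocumentEmbeddings(
    model="bert-base-uncased",
    fine_tune=False,
    batch_size=1,
    layers="-1",
    use_scalar_mix=False,
)

sentence = Sentence("Berlin is the capital of Germany .")
document_embedding.embed(sentence)

# the sentence as a whole now carries a single embedding (e.g. derived from the CLS token)
print(sentence.embedding.shape)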

@@ -2458,8 +2460,7 @@ def __init__(
self.initial_cls_token = True

def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
"""Add embeddings to all words in a list of sentences. If embeddings are already added,
updates only if embeddings are non-static."""
"""Add embeddings to all words in a list of sentences."""

# using list comprehension
sentence_batches = [sentences[i * self.batch_size:(i + 1) * self.batch_size]
@@ -2472,6 +2473,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
return sentences

def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
"""Extract sentence embedding from CLS token or similar and add to Sentence object."""

# gradients are enabled if fine-tuning is enabled
gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
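The gradient_context line above switches between gradient tracking and no-grad inference depending on whether fine-tuning is active during training. A minimal standalone illustration of the same pattern (function and argument names are illustrative, not from the commit):

import torch

def encode(model: torch.nn.Module, inputs: torch.Tensor, fine_tune: bool, training: bool) -> torch.Tensor:
    # track gradients only when the transformer is being fine-tuned during training
    gradient_context = torch.enable_grad() if (fine_tune and training) else torch.no_grad()
    with gradient_context:
        return model(inputs)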
