use tensor forward in dependency parser model
helpmefindaname committed Feb 21, 2022
1 parent c51deaa commit 5d02677
Showing 2 changed files with 18 additions and 14 deletions.
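
Note: the core change splits the old forward(sentences) into two steps: _prepare_tensors embeds the sentences and pads them into a single batch tensor, and a new tensor-only forward computes the arc and relation scores. A minimal sketch of the new call pattern (parser and sentences are placeholders for a trained DependencyParser and a list of Sentence objects):

    sentence_tensor, lengths = parser._prepare_tensors(sentences)
    score_arc, score_rel = parser.forward(sentence_tensor, lengths)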
30 changes: 17 additions & 13 deletions flair/models/dependency_parser_model.py
@@ -106,10 +106,9 @@ def __init__(
 
         self.to(flair.device)
 
-    def forward(self, sentences: List[Sentence]):
+    def _prepare_tensors(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, torch.IntTensor]:
         self.token_embeddings.embed(sentences)
         batch_size = len(sentences)
-
         lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
         seq_len: int = max(lengths)
 
@@ -119,7 +118,6 @@ def forward(self, sentences: List[Sentence]):
             device=flair.device,
         )
 
-        # embed sentences
         all_embs = list()
         for sentence in sentences:
             all_embs += [emb for token in sentence for emb in token.get_each_embedding()]
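
For context, the surrounding (partly collapsed) code scatters each sentence's token embeddings into one zero-initialized tensor of shape (batch_size, seq_len, embedding_length). A toy illustration of that padding scheme, with made-up names and shapes rather than the exact flair code:

    import torch

    embedding_length = 4
    # two "sentences" of 3 and 5 tokens, already embedded
    sentence_embs = [torch.randn(3, embedding_length), torch.randn(5, embedding_length)]
    lengths = [t.size(0) for t in sentence_embs]
    seq_len = max(lengths)

    sentence_tensor = torch.zeros(len(sentence_embs), seq_len, embedding_length)
    for i, embs in enumerate(sentence_embs):
        sentence_tensor[i, : lengths[i]] = embs  # real tokens first, zeros as padding

    print(sentence_tensor.shape)  # torch.Size([2, 5, 4])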
@@ -136,16 +134,20 @@ def forward(self, sentences: List[Sentence]):
                 self.token_embeddings.embedding_length,
             ]
         )
+        return sentence_tensor, torch.IntTensor(lengths)
 
+    def forward(  # type: ignore[override]
+        self, sentence_tensor: torch.Tensor, lengths: torch.IntTensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Main model implementation drops words and tags (independently), instead, we use word dropout!
         if self.use_word_dropout:
             sentence_tensor = self.word_dropout(sentence_tensor)
 
         if self.use_rnn:
-            sentence_sequence = pack_padded_sequence(sentence_tensor, torch.IntTensor(lengths), True, False)
+            sentence_sequence = pack_padded_sequence(sentence_tensor, lengths, True, False)
 
             sentence_sequence, _ = self.lstm(sentence_sequence)
-            sentence_tensor, _ = pad_packed_sequence(sentence_sequence, True, total_length=seq_len)
+            sentence_tensor, _ = pad_packed_sequence(sentence_sequence, True, total_length=sentence_tensor.size(1))
 
         # apply MLPs for arc and relations to the BiLSTM output states
         arc_h = self.mlp_arc_h(sentence_tensor)
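
The pack_padded_sequence changes are why _prepare_tensors returns lengths as a torch.IntTensor: the call no longer rebuilds the tensor inline, and total_length is read from the input tensor because the seq_len local no longer exists in this method. A self-contained round-trip example (toy shapes; the positional True, False arguments mean batch_first=True, enforce_sorted=False, as in the diff):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    sentence_tensor = torch.randn(2, 4, 3)  # (batch, seq_len, embedding)
    lengths = torch.IntTensor([4, 2])       # true lengths; padded positions are skipped

    packed = pack_padded_sequence(sentence_tensor, lengths, True, False)
    lstm = torch.nn.LSTM(input_size=3, hidden_size=5, batch_first=True)
    output, _ = lstm(packed)

    # total_length pins the unpacked size back to the input's seq_len dimension
    unpacked, _ = pad_packed_sequence(output, True, total_length=sentence_tensor.size(1))
    print(unpacked.shape)  # torch.Size([2, 4, 5])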
@@ -161,10 +163,10 @@ def forward(self, sentences: List[Sentence]):
 
         return score_arc, score_rel
 
-    def forward_loss(self, data_points: List[Sentence]) -> torch.Tensor:
-
-        score_arc, score_rel = self.forward(data_points)
-        loss_arc, loss_rel = self._calculate_loss(score_arc, score_rel, data_points)
+    def forward_loss(self, sentences: List[Sentence]) -> torch.Tensor:
+        sentence_tensor, lengths = self._prepare_tensors(sentences)
+        score_arc, score_rel = self.forward(sentence_tensor, lengths)
+        loss_arc, loss_rel = self._calculate_loss(score_arc, score_rel, sentences)
         main_loss = loss_arc + loss_rel
 
         return main_loss
@@ -225,15 +227,16 @@ def predict(
 
         for batch in data_loader:
             with torch.no_grad():
-                score_arc, score_rel = self.forward(batch)
+                sentence_tensor, lengths = self._prepare_tensors(batch)
+                score_arc, score_rel = self.forward(sentence_tensor, lengths)
                 arc_prediction, relation_prediction = self._obtain_labels_(score_arc, score_rel)
 
-            for sentnce_index, (sentence, sent_tags, sent_arcs) in enumerate(
+            for sentence_index, (sentence, sent_tags, sent_arcs) in enumerate(
                 zip(batch, relation_prediction, arc_prediction)
             ):
 
                 for token_index, (token, tag, head_id) in enumerate(zip(sentence.tokens, sent_tags, sent_arcs)):
-                    token.add_tag(self.tag_type, tag, score_rel[sentnce_index][token_index])
+                    token.add_tag(self.tag_type, tag, score_rel[sentence_index][token_index])
 
                     token.head_id = int(head_id)
 
@@ -274,7 +277,8 @@ def evaluate(
         for batch in data_loader:
             average_over += 1
             with torch.no_grad():
-                score_arc, score_rel = self.forward(batch)
+                sentence_tensor, lengths = self._prepare_tensors(batch)
+                score_arc, score_rel = self.forward(sentence_tensor, lengths)
                 loss_arc, loss_rel = self._calculate_loss(score_arc, score_rel, batch)
                 arc_prediction, relation_prediction = self._obtain_labels_(score_arc, score_rel)
 
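
One plausible benefit of a forward that takes only tensors (an assumption, not stated in the commit) is compatibility with tooling that cannot handle Sentence lists, such as torch.jit.trace. An untested sketch of the idea:

    # Hypothetical: parser is a trained DependencyParser in eval mode,
    # so word dropout is inactive and the traced graph is deterministic.
    parser.eval()
    sentence_tensor, lengths = parser._prepare_tensors(sentences)
    traced = torch.jit.trace(parser, (sentence_tensor, lengths))
    score_arc, score_rel = traced(sentence_tensor, lengths)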
2 changes: 1 addition & 1 deletion flair/models/diagnosis/distance_prediction_model.py
@@ -120,7 +120,7 @@ def __init__(
     def label_type(self):
         return "distance"
 
-    # forward allows only a single sentcence!!
+    # forward allows only a single sentence!!
     def forward(self, sentence: Sentence):
 
         # embed words of sentence
