Skip to content

Commit

Permalink
GH-407: fix training on cuda without CRF
Browse files Browse the repository at this point in the history
  • Loading branch information
aakbik committed Feb 2, 2019
1 parent 1876f1f commit abc05ef
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,11 @@ def _calculate_loss(self, features, lengths, tags) -> float:
for sentence_feats, sentence_tags, sentence_length in zip(features, tags, lengths):
sentence_feats = sentence_feats[:sentence_length]

tag_tensor = torch.LongTensor(sentence_tags)
tag_tensor = tag_tensor.to(flair.device)
# print(sentence_tags)
# tag_tensor = torch.LongTensor(sentence_tags)
# tag_tensor = tag_tensor.to(flair.device)

score += torch.nn.functional.cross_entropy(sentence_feats, tag_tensor)
score += torch.nn.functional.cross_entropy(sentence_feats, sentence_tags)

return score

Expand Down

0 comments on commit abc05ef

Please sign in to comment.