From 6e96d7885bededd835e643cd5ea490efbd4df453 Mon Sep 17 00:00:00 2001 From: aakbik Date: Mon, 24 Jun 2019 07:47:47 +0200 Subject: [PATCH 1/2] GH-736: memory optimization for FlairEmbeddings --- flair/models/language_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/models/language_model.py b/flair/models/language_model.py index 9084d3f7a0..5f6fdecf84 100644 --- a/flair/models/language_model.py +++ b/flair/models/language_model.py @@ -130,7 +130,7 @@ def get_representation(self, strings: List[str], chars_per_chunk: int = 512): prediction, rnn_output, hidden = self.forward(batch, hidden) rnn_output = rnn_output.detach() - output_parts.append(rnn_output) + output_parts.append(rnn_output.to("cpu")) # concatenate all chunks to make final output output = torch.cat(output_parts) From 1c5fccfbc12600bc70e1dc2da3640cf7aafbde16 Mon Sep 17 00:00:00 2001 From: aakbik Date: Mon, 24 Jun 2019 07:48:38 +0200 Subject: [PATCH 2/2] GH-798: skip sentences without labels in CSV reader --- flair/datasets.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/flair/datasets.py b/flair/datasets.py index a56ec0baaf..cb0cef67d8 100644 --- a/flair/datasets.py +++ b/flair/datasets.py @@ -681,6 +681,22 @@ def __init__( for row in csv_reader: + # test if format is OK + wrong_format = False + for text_column in self.text_columns: + if text_column >= len(row): + wrong_format = True + + # test if at least one label given + has_label = False + for column in self.column_name_map: + if self.column_name_map[column].startswith("label") and row[column]: + has_label = True + break + + if wrong_format or not has_label: + continue + if self.in_memory: text = " || ".join(