Skip to content

Commit

Permalink
Merge pull request #1238 from zalandoresearch/GH-1115-warning-validation-data
Browse files (browse the repository at this point in the history)

GH-1115: Add warning if validation data too small
  • Loading branch information
Alan Akbik authored Oct 22, 2019
2 parents e223601 + 05e1ca3 commit 9cf0951
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
1 change: 0 additions & 1 deletion flair/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,6 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
for i, sentence in enumerate(sentences):

language_code = sentence.get_language_code()
print(language_code)
supported = [
"en",
"de",
Expand Down
2 changes: 0 additions & 2 deletions flair/models/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,6 @@ def generate_text(
if not self.is_forward_lm:
text = text[::-1]

text = text.encode("utf-8")

return text, log_prob

def calculate_perplexity(self, text: str) -> float:
Expand Down
7 changes: 7 additions & 0 deletions flair/trainers/language_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@ def train(

val_data = self._batchify(self.corpus.valid, mini_batch_size)

# error message if the validation dataset is too small
if val_data.size(0) == 1:
raise RuntimeError(
f"ERROR: Your validation dataset is too small. For your mini_batch_size, the data needs to "
f"consist of at least {mini_batch_size * 2} characters!"
)

base_path.mkdir(parents=True, exist_ok=True)
loss_txt = base_path / "loss.txt"
savefile = base_path / "best-lm.pt"
Expand Down

0 comments on commit 9cf0951

Please sign in to comment.