Skip to content

Commit

Permalink
Merge pull request #11 from zalandoresearch/GH-10-text-classification
Browse files Browse the repository at this point in the history
GH-10: Added text classification model and trainer
  • Loading branch information
Alan Akbik authored Jul 27, 2018
2 parents 7018c1f + a341bf4 commit 5df34cb
Show file tree
Hide file tree
Showing 29 changed files with 1,386 additions and 365 deletions.
104 changes: 104 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a Python script from a template
# before PyInstaller builds the exe, so as to inject the date and other information into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
337 changes: 0 additions & 337 deletions flair/data.py

Large diffs are not rendered by default.

394 changes: 394 additions & 0 deletions flair/data_fetcher.py

Large diffs are not rendered by default.

37 changes: 24 additions & 13 deletions flair/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
import pickle
import re
import os
from abc import ABC, abstractmethod
from typing import List, Dict, Tuple
from abc import abstractmethod
from typing import List

import gensim
import numpy as np
import torch

from .file_utils import cached_path
from .language_model import RNNModel
from flair.models.language_model import RNNModel
from .data import Dictionary, Token, Sentence, TaggedCorpus
from .file_utils import cached_path


class TextEmbeddings(torch.nn.Module):
Expand Down Expand Up @@ -50,7 +50,7 @@ def embed(self, sentences: List[Sentence]) -> List[Sentence]:
return sentences

@abstractmethod
def _add_embeddings_internal(self, sentences: List[Sentence]):
def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
"""Private method for adding embeddings to all words in a list of sentences."""
pass

Expand Down Expand Up @@ -84,14 +84,14 @@ def embed(self, sentences: List[Sentence], static_embeddings: bool = True):
embedding.embed(sentences)

@property
def embedding_type(self):
def embedding_type(self) -> str:
return self.__embedding_type

@property
def embedding_length(self) -> int:
return self.__embedding_length

def _add_embeddings_internal(self, sentences: List[Sentence]):
def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

for embedding in self.embeddings:
embedding._add_embeddings_internal(sentences)
Expand Down Expand Up @@ -413,6 +413,7 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:


class OnePassStoreEmbeddings(TextEmbeddings):

def __init__(self, embedding_stack: StackedEmbeddings, corpus: TaggedCorpus, detach: bool = True):
super().__init__()

Expand Down Expand Up @@ -564,14 +565,22 @@ def _add_embeddings_internal(self, sentences: List[Sentence]):
class TextLSTMEmbedder(TextEmbeddings):

def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num_layers=1,
reproject_words: bool = True):
"""The constructor takes a list of embeddings to be combined."""
reproject_words: bool = True, bidirectional: bool = True):
"""The constructor takes a list of embeddings to be combined.
:param word_embeddings: a list of word embeddings
:param hidden_states: the number of hidden states in the lstm
:param num_layers: the number of layers for the lstm
:param reproject_words: boolean value, indicating whether to reproject the word embeddings in a separate linear
layer before putting them into the lstm
:param bidirectional: boolean value, indicating whether to use a bidirectional lstm or not
"""
super().__init__()

# self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=word_embeddings)
self.embeddings: List[TextEmbeddings] = word_embeddings

self.reproject_words = reproject_words
self.bidirectional = bidirectional

self.length_of_all_word_embeddings = 0
for word_embedding in self.embeddings:
Expand All @@ -580,14 +589,16 @@ def __init__(self, word_embeddings: List[TextEmbeddings], hidden_states=128, num
self.name = 'text_lstm'
self.static_embeddings = False

# self.__embedding_length: int = hidden_states
self.__embedding_length: int = hidden_states * 2
if self.bidirectional:
self.__embedding_length: int = hidden_states * 2
else:
self.__embedding_length: int = hidden_states

# LSTM (bidirectional by default, controlled by the `bidirectional` argument) on top of embedding layer
self.word_reprojection_map = torch.nn.Linear(self.length_of_all_word_embeddings,
self.length_of_all_word_embeddings)
self.rnn = torch.nn.LSTM(self.length_of_all_word_embeddings, hidden_states, num_layers=num_layers,
bidirectional=True)
bidirectional=self.bidirectional)
self.dropout = torch.nn.Dropout(0.5)

@property
Expand Down
Empty file added flair/models/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion flair/language_model.py → flair/models/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
from torch.autograd import Variable
from typing import Dict, List
from .data import Dictionary
from flair.data import Dictionary


class RNNModel(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions flair/tagging_model.py → flair/models/tagging_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np

from flair.file_utils import cached_path
from .data import Dictionary, Sentence, Token
from .embeddings import TextEmbeddings
from flair.data import Dictionary, Sentence, Token
from flair.embeddings import TextEmbeddings

from typing import List, Tuple, Union

Expand Down
Loading

0 comments on commit 5df34cb

Please sign in to comment.