Skip to content

Commit

Permalink
GH-19: learning rate scheduler and additional metrics for sequence la…
Browse files Browse the repository at this point in the history
…beler
  • Loading branch information
aakbik authored and tabergma committed Jul 31, 2018
1 parent db1b4c2 commit 8afe853
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 127 deletions.
5 changes: 4 additions & 1 deletion flair/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def embedding_type(self) -> str:


class DocumentEmbeddings(Embeddings):
"""Abstract base class for all document-level embeddings. Ever new type of document embedding must implement these methods."""
"""Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods."""

@property
@abstractmethod
Expand Down Expand Up @@ -208,6 +208,9 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
else:
word_embedding = np.zeros(self.embedding_length, dtype='float')

# if torch.cuda.is_available():
# word_embedding = torch.cuda.FloatTensor(word_embedding)
# else:
word_embedding = torch.FloatTensor(word_embedding)

token.set_embedding(self.name, word_embedding)
Expand Down
4 changes: 4 additions & 0 deletions flair/models/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def __init__(self,

self.init_weights()

# auto-spawn on GPU if available
if torch.cuda.is_available():
self.cuda()

def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
Expand Down
7 changes: 5 additions & 2 deletions flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def __init__(self,
dropout=0.5,
bidirectional=True)

self.nonlinearity = nn.Tanh()

# final linear map to tag space
if self.use_rnn:
self.linear = nn.Linear(hidden_size * 2, len(tag_dictionary))
Expand All @@ -103,6 +101,11 @@ def __init__(self,
self.transitions.data[self.tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
self.transitions.data[:, self.tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000

# auto-spawn on GPU if available
if torch.cuda.is_available():
self.cuda()


def save(self, model_file: str):
model_state = {
'state_dict': self.state_dict(),
Expand Down
4 changes: 4 additions & 0 deletions flair/models/text_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(self,

self._init_weights()

# auto-spawn on GPU if available
if torch.cuda.is_available():
self.cuda()

def _init_weights(self):
nn.init.xavier_uniform_(self.decoder.weight)

Expand Down
127 changes: 49 additions & 78 deletions flair/trainers/sequence_tagger_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import re
import sys
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from flair.models.sequence_tagger_model import SequenceTagger
from flair.data import Sentence, Token, TaggedCorpus
from flair.training_utils import Metric


class SequenceTaggerTrainer:
Expand All @@ -23,23 +25,25 @@ def train(self,
learning_rate: float = 0.1,
mini_batch_size: int = 32,
max_epochs: int = 100,
save_model: bool = True,
anneal_factor: float = 0.5,
patience: int = 3,
checkpoint: bool = False,
embeddings_in_memory: bool = True,
train_with_dev: bool = False,
anneal_mode: bool = False):
train_with_dev: bool = False):

checkpoint: bool = False
evaluation_method = 'F1'
if self.model.tag_type in ['ner', 'np', 'srl']: evaluation_method = 'span-F1'
if self.model.tag_type in ['pos', 'upos']: evaluation_method = 'accuracy'
print(evaluation_method)

evaluate_with_fscore: bool = True
if self.model.tag_type not in ['ner', 'np', 'srl']: evaluate_with_fscore = False
os.makedirs(base_path, exist_ok=True)

self.base_path = base_path
os.makedirs(self.base_path, exist_ok=True)

loss_txt = os.path.join(self.base_path, "loss.txt")
loss_txt = os.path.join(base_path, "loss.txt")
open(loss_txt, "w", encoding='utf-8').close()

optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)
scheduler: ReduceLROnPlateau = ReduceLROnPlateau(optimizer, verbose=True, factor=anneal_factor,
patience=patience)

train_data = self.corpus.train

Expand All @@ -50,19 +54,13 @@ def train(self,
# At any point you can hit Ctrl + C to break out of training early.
try:

# record overall best dev scores and best loss
best_score = 0
if train_with_dev: best_score = 10000
# best_dev_score = 0
# best_loss: float = 10000

# this variable is used for annealing schemes
epochs_without_improvement: int = 0

for epoch in range(0, max_epochs):

current_loss: int = 0

for group in optimizer.param_groups:
learning_rate = group['lr']

if not self.test_mode: random.shuffle(train_data)

batches = [train_data[x:x + mini_batch_size] for x in range(0, len(train_data), mini_batch_size)]
Expand Down Expand Up @@ -99,69 +97,33 @@ def train(self,

current_loss /= len(train_data)

# IMPORTANT: Switch to eval mode
# anneal against train loss
scheduler.step(current_loss)

# switch to eval mode
self.model.eval()

if not train_with_dev:
print('.. evaluating... dev... ')
dev_score, dev_fp, dev_result = self.evaluate(self.corpus.dev,
evaluate_with_fscore=evaluate_with_fscore,
dev_score, dev_fp, dev_result = self.evaluate(self.corpus.dev, base_path,
evaluation_method=evaluation_method,
embeddings_in_memory=embeddings_in_memory)
else:
dev_fp = 0
dev_result = '_'

print('test... ')
test_score, test_fp, test_result = self.evaluate(self.corpus.test,
evaluate_with_fscore=evaluate_with_fscore,
test_score, test_fp, test_result = self.evaluate(self.corpus.test, base_path,
evaluation_method=evaluation_method,
embeddings_in_memory=embeddings_in_memory)

# IMPORTANT: Switch back to train mode
self.model.train()

# checkpoint model
self.model.trained_epochs = epoch

# is this the best model so far?
is_best_model_so_far: bool = False

# if dev data is used for model selection, use dev F1 score to determine best model
if not train_with_dev and dev_score > best_score:
best_score = dev_score
is_best_model_so_far = True

# if dev data is used for training, use training loss to determine best model
if train_with_dev and current_loss < best_score:
best_score = current_loss
is_best_model_so_far = True

if is_best_model_so_far:

print('after %d - new best score: %f' % (epochs_without_improvement, best_score))

epochs_without_improvement = 0

# save model
if save_model or (anneal_mode and checkpoint):
self.model.save(base_path + "/model.pt")
print('.. model saved ... ')

else:
epochs_without_improvement += 1

# anneal after 3 epochs of no improvement if anneal mode
if epochs_without_improvement == 3 and anneal_mode:
best_score = current_loss
learning_rate /= 2

if checkpoint:
self.model = SequenceTagger.load_from_file(base_path + '/model.pt')

optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)

# print info
summary = '%d' % epoch + '\t({:%H:%M:%S})'.format(datetime.datetime.now()) \
+ '\t%f\t%d\t%f\tDEV %d\t' % (current_loss, epochs_without_improvement, learning_rate, dev_fp) + dev_result
+ '\t%f\t%d\t%f\tDEV %d\t' % (
current_loss, scheduler.num_bad_epochs, learning_rate, dev_fp) + dev_result
summary = summary.replace('\n', '')
summary += '\tTEST \t%d\t' % test_fp + test_result

Expand All @@ -170,19 +132,21 @@ def train(self,
loss_file.write('%s\n' % summary)
loss_file.close()

if checkpoint and scheduler.num_bad_epochs == 0:
self.model.save(base_path + "/checkpoint-model.pt")

self.model.save(base_path + "/final-model.pt")

except KeyboardInterrupt:
print('-' * 89)
print('Exiting from training early')
print('saving model')
with open(base_path + "/final-model.pt", 'wb') as model_save_file:
torch.save(self.model, model_save_file, pickle_protocol=4)
model_save_file.close()
self.model.save(base_path + "/final-model.pt")
print('done')

def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True,
def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method: str = 'F1',
embeddings_in_memory: bool = True):

tp: int = 0
fp: int = 0

Expand All @@ -191,6 +155,8 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
batches = [evaluation[x:x + mini_batch_size] for x in
range(0, len(evaluation), mini_batch_size)]

metric = Metric('')

lines: List[str] = []

for batch in batches:
Expand All @@ -209,7 +175,6 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
predicted_id = tag_seq
for (token, pred_id) in zip(sentence.tokens, predicted_id):
token: Token = token
# print(token)
# get the predicted tag
predicted_tag = self.model.tag_dictionary.get_item_for_index(pred_id)
token.add_tag('predicted', predicted_tag)
Expand All @@ -231,17 +196,17 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
if not embeddings_in_memory:
self.clear_embeddings_in_batch(batch)

test_tsv = os.path.join(self.base_path, "test.tsv")
with open(test_tsv, "w", encoding='utf-8') as outfile:
outfile.write(''.join(lines))
if out_path != None:
test_tsv = os.path.join(out_path, "test.tsv")
with open(test_tsv, "w", encoding='utf-8') as outfile:
outfile.write(''.join(lines))

if evaluate_with_fscore:
if evaluation_method == 'span-F1':
eval_script = 'resources/tasks/eval_script'

eval_data = ''.join(lines)

p = run(eval_script, stdout=PIPE, input=eval_data, encoding='utf-8')
print(p.returncode)
main_result = p.stdout
print(main_result)

Expand All @@ -254,12 +219,18 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
main_result = re.sub('accuracy', 'acc', main_result)

f_score = float(re.findall(r'\d+\.\d+$', main_result)[0])

return f_score, fp, main_result

precision: float = tp / (tp + fp)
if evaluation_method == 'accuracy':
score = metric.accuracy()
accuracy: float = tp / (tp + fp)
print(accuracy)
return score, fp, str(score)

return precision, fp, str(precision)
if evaluation_method == 'F1':
print(metric.accuracy())
score = metric.f_score()
return score, fp, str(metric)

def clear_embeddings_in_batch(self, batch: List[Sentence]):
for sentence in batch:
Expand Down
Loading

0 comments on commit 8afe853

Please sign in to comment.