GH-19: anneal against dev score or train loss
aakbik authored and tabergma committed Jul 31, 2018
1 parent 0827948 commit 98baf77
Showing 2 changed files with 18 additions and 14 deletions.
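
This commit hinges on PyTorch's ReduceLROnPlateau scheduler, which multiplies the learning rate by a factor once a monitored quantity stops improving for a given number of epochs. Its mode argument decides the direction of "improving": 'max' for a dev score that should rise, 'min' for a train loss that should fall. A minimal sketch of the two configurations; the metric values are made up for illustration, and only the scheduler calls mirror this commit:

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(10, 2)                      # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # annealing against a dev score: higher is better, so mode='max'
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
    for dev_score in [0.70, 0.72, 0.72, 0.71, 0.71]:    # score plateaus ...
        scheduler.step(dev_score)
    print(optimizer.param_groups[0]['lr'])              # ... so lr drops to 0.05

    # annealing against a train loss: lower is better, so mode='min'
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    for train_loss in [1.00, 0.90, 0.91, 0.92, 0.93]:   # loss stops falling
        scheduler.step(train_loss)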
flair/models/sequence_tagger_model.py (2 additions, 1 deletion)
@@ -88,6 +88,8 @@ def __init__(self,
                                  dropout=0.5,
                                  bidirectional=True)

+        self.nonlinearity = nn.Tanh()
+
         # final linear map to tag space
         if self.use_rnn:
             self.linear = nn.Linear(hidden_size * 2, len(tag_dictionary))
@@ -101,7 +103,6 @@ def __init__(self,
         self.transitions.data[self.tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
         self.transitions.data[:, self.tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000

-        # auto-spawn on GPU if available
         if torch.cuda.is_available():
             self.cuda()

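The model-side change only registers a Tanh module; the diff does not show where it is applied. A generic, illustrative sketch of the pattern (the class, sizes, and point of application are assumptions, not flair's actual code): keeping the nonlinearity as a module attribute registers it alongside the other layers, here between a bidirectional LSTM and the linear map to tag space.

    import torch
    import torch.nn as nn

    class TinyTagger(nn.Module):
        """Illustrative only, not flair's SequenceTagger."""

        def __init__(self, hidden_size: int = 4, num_tags: int = 3):
            super().__init__()
            self.rnn = nn.LSTM(hidden_size, hidden_size, bidirectional=True)
            self.nonlinearity = nn.Tanh()
            # a bidirectional LSTM doubles the feature dimension
            self.linear = nn.Linear(hidden_size * 2, num_tags)

        def forward(self, x):
            out, _ = self.rnn(x)
            # applying the tanh before the tag-space projection is one
            # plausible spot; the diff does not show flair's choice
            return self.linear(self.nonlinearity(out))

    scores = TinyTagger()(torch.randn(5, 1, 4))
    print(scores.shape)  # (seq_len, batch, num_tags)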
flair/trainers/sequence_tagger_trainer.py (16 additions, 13 deletions)
@@ -26,8 +26,8 @@ def train(self,
               mini_batch_size: int = 32,
               max_epochs: int = 100,
               anneal_factor: float = 0.5,
-              patience: int = 3,
-              checkpoint: bool = False,
+              patience: int = 2,
+              save_model: bool = False,
               embeddings_in_memory: bool = True,
               train_with_dev: bool = False):

@@ -42,8 +42,10 @@ def train(self,
         open(loss_txt, "w", encoding='utf-8').close()

         optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)
-        scheduler: ReduceLROnPlateau = ReduceLROnPlateau(optimizer, verbose=True, factor=anneal_factor,
-                                                         patience=patience)
+
+        anneal_mode = 'min' if train_with_dev else 'max'
+        scheduler: ReduceLROnPlateau = ReduceLROnPlateau(optimizer, factor=anneal_factor, patience=patience,
+                                                         mode=anneal_mode)

         train_data = self.corpus.train

@@ -97,9 +99,6 @@ def train(self,

             current_loss /= len(train_data)

-            # anneal against train loss
-            scheduler.step(current_loss)
-
             # switch to eval mode
             self.model.eval()

@@ -117,13 +116,15 @@ def train(self,
                                                   evaluation_method=evaluation_method,
                                                   embeddings_in_memory=embeddings_in_memory)

-            # IMPORTANT: Switch back to train mode
+            # switch back to train mode
             self.model.train()

-            # print info
+            # anneal against train loss if training with dev, otherwise anneal against dev score
+            scheduler.step(current_loss) if train_with_dev else scheduler.step(dev_score)
+
             summary = '%d' % epoch + '\t({:%H:%M:%S})'.format(datetime.datetime.now()) \
                       + '\t%f\t%d\t%f\tDEV %d\t' % (
-                current_loss, scheduler.num_bad_epochs, learning_rate, dev_fp) + dev_result
+                          current_loss, scheduler.num_bad_epochs, learning_rate, dev_fp) + dev_result
             summary = summary.replace('\n', '')
             summary += '\tTEST \t%d\t' % test_fp + test_result

@@ -132,10 +133,12 @@ def train(self,
                 loss_file.write('%s\n' % summary)
                 loss_file.close()

-                if checkpoint and scheduler.num_bad_epochs == 0:
-                    self.model.save(base_path + "/checkpoint-model.pt")
+                # save if model is current best and we use dev data for model selection
+                if save_model and not train_with_dev and current_loss == scheduler.best:
+                    self.model.save(base_path + "/best-model.pt")

-            self.model.save(base_path + "/final-model.pt")
+            # if we do not use dev data for model selection, save final model
+            if save_model and train_with_dev: self.model.save(base_path + "/final-model.pt")

         except KeyboardInterrupt:
             print('-' * 89)
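Taken together, the two save paths split on train_with_dev. A hedged usage sketch: the trainer constructor, the tagger/corpus setup, and base_path as train()'s first positional argument are assumptions based on flair's surrounding code, not shown in this diff.

    from flair.trainers.sequence_tagger_trainer import SequenceTaggerTrainer

    # assumed setup: `tagger` and `corpus` built elsewhere, as in flair's examples
    trainer = SequenceTaggerTrainer(tagger, corpus)

    # with a held-out dev set: anneal on dev score (mode='max') and write
    # best-model.pt whenever the monitored value equals scheduler.best
    trainer.train('resources/taggers/example', save_model=True, train_with_dev=False)

    # dev data folded into training: anneal on train loss (mode='min');
    # only final-model.pt is written, once training ends
    trainer.train('resources/taggers/example', save_model=True, train_with_dev=True)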
