Merge pull request #22 from zalandoresearch/GH-19-simplify-sequence-tagger

GH-19: simplify sequence tagger
tabergma authored Jul 31, 2018
2 parents f44f9e8 + 9852667 commit fda07a7
Showing 9 changed files with 99 additions and 137 deletions.
2 changes: 1 addition & 1 deletion flair/data_fetcher.py
@@ -288,7 +288,7 @@ def read_conll_ud(path_to_conll_file: str) -> List[Sentence]:
                 if not "=" in morph: continue;
                 token.add_tag(morph.split('=')[0].lower(), morph.split('=')[1])
 
-            if str(fields[10]) == 'Y':
+            if len(fields) > 10 and str(fields[10]) == 'Y':
                 token.add_tag('frame', str(fields[11]))
 
             sentence.add_token(token)
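Note: the new length check matters because a plain CoNLL-U row carries only ten tab-separated columns; the frame columns 11 and 12 are optional extras, so unconditionally indexing fields[10] raised an IndexError on unannotated rows. A standalone sketch of the guard (the row below is illustrative, not from a real corpus):

```python
# a plain CoNLL-U row has exactly ten tab-separated columns;
# frame annotations (columns 11 and 12) are optional extras
fields = '1\tThe\tthe\tDET\tDT\t_\t2\tdet\t_\t_'.split('\t')

# indexing fields[10] unconditionally raises IndexError on such rows;
# the added length check skips them safely
if len(fields) > 10 and str(fields[10]) == 'Y':
    print('frame:', fields[11])
else:
    print('no frame annotation on this token')
```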
4 changes: 4 additions & 0 deletions flair/models/language_model.py
@@ -49,6 +49,10 @@ def __init__(self,
 
         self.init_weights()
 
+        # auto-spawn on GPU if available
+        if torch.cuda.is_available():
+            self.cuda()
+
     def init_weights(self):
         initrange = 0.1
         self.encoder.weight.data.uniform_(-initrange, initrange)
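Note: this is the first of three models in this commit that now place themselves on the GPU at construction time, so callers no longer need an explicit .cuda() after instantiation. A minimal sketch of the pattern, with TinyModel as a made-up stand-in for a flair module:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Made-up module illustrating construction-time device placement."""

    def __init__(self, input_dim: int = 8, output_dim: int = 2):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

        # same pattern as the diff: move to GPU at construction if one exists
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
# inputs must live on the same device as the (possibly GPU-resident) weights
device = next(model.parameters()).device
x = torch.randn(4, 8, device=device)
print(model(x).shape)  # torch.Size([4, 2])
```

One consequence of this design choice: input tensors must be created on, or moved to, the model's device, as the last lines show.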
4 changes: 4 additions & 0 deletions flair/models/sequence_tagger_model.py
@@ -103,6 +103,10 @@ def __init__(self,
         self.transitions.data[self.tag_dictionary.get_idx_for_item(START_TAG), :] = -10000
         self.transitions.data[:, self.tag_dictionary.get_idx_for_item(STOP_TAG)] = -10000
 
+        if torch.cuda.is_available():
+            self.cuda()
+
+
     def save(self, model_file: str):
         model_state = {
             'state_dict': self.state_dict(),
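Note: the context lines above encode the CRF's structural constraints: any transition into the START tag or out of the STOP tag is scored -10000, so Viterbi decoding never selects it. A toy illustration, assuming (as the indexing in the diff suggests) that the transitions matrix is laid out [to_tag, from_tag]:

```python
import torch

tags = ['O', 'B-PER', 'I-PER', '<START>', '<STOP>']
START, STOP = tags.index('<START>'), tags.index('<STOP>')

# assumed convention, matching the diff: transitions[to_tag, from_tag]
transitions = torch.randn(len(tags), len(tags))

transitions[START, :] = -10000  # nothing may transition INTO the start tag
transitions[:, STOP] = -10000   # nothing may transition OUT OF the stop tag
```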
4 changes: 4 additions & 0 deletions flair/models/text_classification_model.py
@@ -41,6 +41,10 @@ def __init__(self,
 
         self._init_weights()
 
+        # auto-spawn on GPU if available
+        if torch.cuda.is_available():
+            self.cuda()
+
     def _init_weights(self):
         nn.init.xavier_uniform_(self.decoder.weight)
 
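Note: _init_weights above relies on Xavier (Glorot) uniform initialization, which bounds the initial weights by fan-in and fan-out so activation variance stays stable across layers. A minimal sketch with an illustrative decoder shape:

```python
import torch.nn as nn

decoder = nn.Linear(in_features=256, out_features=5)  # illustrative shape
nn.init.xavier_uniform_(decoder.weight)  # bounds ~ sqrt(6 / (fan_in + fan_out))
nn.init.zeros_(decoder.bias)             # biases are commonly zeroed alongside
```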
155 changes: 70 additions & 85 deletions flair/trainers/sequence_tagger_trainer.py
@@ -7,9 +7,11 @@
 import re
 import sys
 import torch
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from flair.models.sequence_tagger_model import SequenceTagger
 from flair.data import Sentence, Token, TaggedCorpus
+from flair.training_utils import Metric
 
 
 class SequenceTaggerTrainer:
@@ -23,24 +25,28 @@ def train(self,
               learning_rate: float = 0.1,
               mini_batch_size: int = 32,
               max_epochs: int = 100,
-              save_model: bool = True,
+              anneal_factor: float = 0.5,
+              patience: int = 2,
+              save_model: bool = False,
               embeddings_in_memory: bool = True,
-              train_with_dev: bool = False,
-              anneal_mode: bool = False):
+              train_with_dev: bool = False):
 
-        checkpoint: bool = False
+        evaluation_method = 'F1'
+        if self.model.tag_type in ['ner', 'np', 'srl']: evaluation_method = 'span-F1'
+        if self.model.tag_type in ['pos', 'upos']: evaluation_method = 'accuracy'
+        print(evaluation_method)
 
-        evaluate_with_fscore: bool = True
-        if self.model.tag_type not in ['ner', 'np', 'srl']: evaluate_with_fscore = False
+        os.makedirs(base_path, exist_ok=True)
 
-        self.base_path = base_path
-        os.makedirs(self.base_path, exist_ok=True)
-
-        loss_txt = os.path.join(self.base_path, "loss.txt")
+        loss_txt = os.path.join(base_path, "loss.txt")
         open(loss_txt, "w", encoding='utf-8').close()
 
         optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)
 
+        anneal_mode = 'min' if train_with_dev else 'max'
+        scheduler: ReduceLROnPlateau = ReduceLROnPlateau(optimizer, factor=anneal_factor, patience=patience,
+                                                         mode=anneal_mode)
+
        train_data = self.corpus.train
 
        # if training also uses dev data, include in training set
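Note: the hand-rolled annealing is gone; torch's ReduceLROnPlateau now owns the schedule. With dev data folded into training it watches the training loss (mode='min'); otherwise it watches the dev score (mode='max'). A self-contained sketch of the mechanics on a made-up score sequence:

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)

# mode='max': an epoch whose metric fails to beat the best so far is "bad";
# after more than `patience` bad epochs, lr is multiplied by `factor`
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

for dev_score in [0.70, 0.75, 0.75, 0.74, 0.73]:  # plateaus after epoch 2
    scheduler.step(dev_score)
    print(scheduler.num_bad_epochs, optimizer.param_groups[0]['lr'])
# the third bad epoch in a row halves the lr: 0.1 -> 0.05
```

With patience=2, the third consecutive epoch without improvement triggers the halving, mirroring the anneal_factor=0.5 default added above.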
@@ -50,19 +56,13 @@ def train(self,
         # At any point you can hit Ctrl + C to break out of training early.
         try:
 
-            # record overall best dev scores and best loss
-            best_score = 0
-            if train_with_dev: best_score = 10000
-            # best_dev_score = 0
-            # best_loss: float = 10000
-
-            # this variable is used for annealing schemes
-            epochs_without_improvement: int = 0
-
             for epoch in range(0, max_epochs):
 
                 current_loss: int = 0
 
+                for group in optimizer.param_groups:
+                    learning_rate = group['lr']
+
                 if not self.test_mode: random.shuffle(train_data)
 
                 batches = [train_data[x:x + mini_batch_size] for x in range(0, len(train_data), mini_batch_size)]
@@ -99,69 +99,32 @@ def train(self,
 
                 current_loss /= len(train_data)
 
-                # IMPORTANT: Switch to eval mode
+                # switch to eval mode
                 self.model.eval()
 
                 if not train_with_dev:
                     print('.. evaluating... dev... ')
-                    dev_score, dev_fp, dev_result = self.evaluate(self.corpus.dev,
-                                                                  evaluate_with_fscore=evaluate_with_fscore,
+                    dev_score, dev_fp, dev_result = self.evaluate(self.corpus.dev, base_path,
+                                                                  evaluation_method=evaluation_method,
                                                                   embeddings_in_memory=embeddings_in_memory)
                 else:
                     dev_fp = 0
                     dev_result = '_'
 
                 print('test... ')
-                test_score, test_fp, test_result = self.evaluate(self.corpus.test,
-                                                                 evaluate_with_fscore=evaluate_with_fscore,
+                test_score, test_fp, test_result = self.evaluate(self.corpus.test, base_path,
+                                                                 evaluation_method=evaluation_method,
                                                                  embeddings_in_memory=embeddings_in_memory)
 
-                # IMPORTANT: Switch back to train mode
+                # switch back to train mode
                 self.model.train()
 
-                # checkpoint model
-                self.model.trained_epochs = epoch
-
-                # is this the best model so far?
-                is_best_model_so_far: bool = False
-
-                # if dev data is used for model selection, use dev F1 score to determine best model
-                if not train_with_dev and dev_score > best_score:
-                    best_score = dev_score
-                    is_best_model_so_far = True
-
-                # if dev data is used for training, use training loss to determine best model
-                if train_with_dev and current_loss < best_score:
-                    best_score = current_loss
-                    is_best_model_so_far = True
-
-                if is_best_model_so_far:
-
-                    print('after %d - new best score: %f' % (epochs_without_improvement, best_score))
-
-                    epochs_without_improvement = 0
-
-                    # save model
-                    if save_model or (anneal_mode and checkpoint):
-                        self.model.save(base_path + "/model.pt")
-                        print('.. model saved ... ')
+                # anneal against train loss if training with dev, otherwise anneal against dev score
+                scheduler.step(current_loss) if train_with_dev else scheduler.step(dev_score)
 
-                else:
-                    epochs_without_improvement += 1
-
-                    # anneal after 3 epochs of no improvement if anneal mode
-                    if epochs_without_improvement == 3 and anneal_mode:
-                        best_score = current_loss
-                        learning_rate /= 2
-
-                        if checkpoint:
-                            self.model = SequenceTagger.load_from_file(base_path + '/model.pt')
-
-                        optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)
-
                 # print info
                 summary = '%d' % epoch + '\t({:%H:%M:%S})'.format(datetime.datetime.now()) \
-                          + '\t%f\t%d\t%f\tDEV %d\t' % (current_loss, epochs_without_improvement, learning_rate, dev_fp) + dev_result
+                          + '\t%f\t%d\t%f\tDEV %d\t' % (
+                          current_loss, scheduler.num_bad_epochs, learning_rate, dev_fp) + dev_result
                 summary = summary.replace('\n', '')
                 summary += '\tTEST \t%d\t' % test_fp + test_result

@@ -170,19 +133,23 @@ def train(self,
                     loss_file.write('%s\n' % summary)
                     loss_file.close()
 
-            self.model.save(base_path + "/final-model.pt")
+                # save if model is current best and we use dev data for model selection
+                if save_model and not train_with_dev and current_loss == scheduler.best:
+                    self.model.save(base_path + "/best-model.pt")
+
+            # if we do not use dev data for model selection, save final model
+            if save_model and train_with_dev: self.model.save(base_path + "/final-model.pt")
 
         except KeyboardInterrupt:
             print('-' * 89)
             print('Exiting from training early')
             print('saving model')
-            with open(base_path + "/final-model.pt", 'wb') as model_save_file:
-                torch.save(self.model, model_save_file, pickle_protocol=4)
-                model_save_file.close()
+            self.model.save(base_path + "/final-model.pt")
             print('done')
 
-    def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True,
+    def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method: str = 'F1',
                  embeddings_in_memory: bool = True):
 
         tp: int = 0
         fp: int = 0
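Note: with the reworked flags, a run using save_model=True and train_with_dev=False keeps best-model.pt up to date whenever the tracked metric matches the scheduler's best. A hedged usage sketch for restoring that checkpoint afterwards (the path is a placeholder, not a flair default):

```python
from flair.models.sequence_tagger_model import SequenceTagger

# hypothetical path: whatever base_path the trainer was given
tagger = SequenceTagger.load_from_file('resources/taggers/example-ner/best-model.pt')
```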

@@ -191,6 +158,8 @@ def evaluate(self, evaluation: List[Sentence], out_path=None, evaluation_method: str = 'F1',
 
         batches = [evaluation[x:x + mini_batch_size] for x in
                    range(0, len(evaluation), mini_batch_size)]
 
+        metric = Metric('')
+
         lines: List[str] = []
 
         for batch in batches:
@@ -209,7 +178,6 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
                 predicted_id = tag_seq
                 for (token, pred_id) in zip(sentence.tokens, predicted_id):
                     token: Token = token
-                    # print(token)
                     # get the predicted tag
                     predicted_tag = self.model.tag_dictionary.get_item_for_index(pred_id)
                     token.add_tag('predicted', predicted_tag)
@@ -219,10 +187,24 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
 
                     # append both to file for evaluation
                     eval_line = token.text + ' ' + gold_tag + ' ' + predicted_tag + "\n"
-                    if gold_tag == predicted_tag:
-                        tp += 1
-                    else:
-                        fp += 1
+
+                    # positives
+                    if predicted_tag != '':
+                        # true positives
+                        if predicted_tag == gold_tag:
+                            metric.tp()
+                        # false positive
+                        if predicted_tag != gold_tag:
+                            metric.fp()
+
+                    # negatives
+                    if predicted_tag == '':
+                        # true negative
+                        if predicted_tag == gold_tag:
+                            metric.tn()
+                        # false negative
+                        if predicted_tag != gold_tag:
+                            metric.fn()
 
                     lines.append(eval_line)
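Note: the per-token counting now routes through the shared Metric helper from flair.training_utils. Its implementation is not part of this diff; a minimal sketch consistent with the calls above (tp()/fp()/tn()/fn() as counters, f_score()/accuracy() as derived scores, plus a __str__ for the str(metric) return further down) might look like:

```python
class Metric:
    """Minimal sketch of a tp/fp/tn/fn accumulator (not flair's actual code)."""

    def __init__(self, name: str):
        self.name = name
        self._tp = self._fp = self._tn = self._fn = 0

    def tp(self): self._tp += 1
    def fp(self): self._fp += 1
    def tn(self): self._tn += 1
    def fn(self): self._fn += 1

    def precision(self) -> float:
        return self._tp / (self._tp + self._fp) if self._tp + self._fp > 0 else 0.0

    def recall(self) -> float:
        return self._tp / (self._tp + self._fn) if self._tp + self._fn > 0 else 0.0

    def f_score(self) -> float:
        p, r = self.precision(), self.recall()
        return 2 * p * r / (p + r) if p + r > 0 else 0.0

    def accuracy(self) -> float:
        total = self._tp + self._fp + self._tn + self._fn
        return (self._tp + self._tn) / total if total > 0 else 0.0

    def __str__(self) -> str:
        return 'tp: %d - fp: %d - tn: %d - fn: %d - f-score: %.4f' % (
            self._tp, self._fp, self._tn, self._fn, self.f_score())
```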

@@ -231,17 +213,17 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
 
             if not embeddings_in_memory:
                 self.clear_embeddings_in_batch(batch)
 
-        test_tsv = os.path.join(self.base_path, "test.tsv")
-        with open(test_tsv, "w", encoding='utf-8') as outfile:
-            outfile.write(''.join(lines))
+        if out_path is not None:
+            test_tsv = os.path.join(out_path, "test.tsv")
+            with open(test_tsv, "w", encoding='utf-8') as outfile:
+                outfile.write(''.join(lines))
 
-        if evaluate_with_fscore:
+        if evaluation_method == 'span-F1':
             eval_script = 'resources/tasks/eval_script'
 
             eval_data = ''.join(lines)
 
             p = run(eval_script, stdout=PIPE, input=eval_data, encoding='utf-8')
-            print(p.returncode)
             main_result = p.stdout
             print(main_result)

@@ -254,12 +236,15 @@ def evaluate(self, evaluation: List[Sentence], evaluate_with_fscore: bool = True
             main_result = re.sub('accuracy', 'acc', main_result)
 
             f_score = float(re.findall(r'\d+\.\d+$', main_result)[0])
+            return f_score, metric._fp, main_result
 
-            return f_score, fp, main_result
-
-        precision: float = tp / (tp + fp)
+        if evaluation_method == 'accuracy':
+            score = metric.accuracy()
+            return score, metric._fp, str(score)
 
-        return precision, fp, str(precision)
+        if evaluation_method == 'F1':
+            score = metric.f_score()
+            return score, metric._fp, str(metric)
 
     def clear_embeddings_in_batch(self, batch: List[Sentence]):
         for sentence in batch:
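Note: the span-F1 branch shells out to a conlleval-style eval_script and pulls the headline number off the end of its output with a regex. A tiny illustration of that extraction step, on a made-up result line:

```python
import re

# made-up final line of a conlleval-style script's output
main_result = 'acc:  97.20%; precision:  88.10%; recall:  86.40%; FB1:  87.25'
f_score = float(re.findall(r'\d+\.\d+$', main_result)[0])
print(f_score)  # 87.25
```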
[diffs for the remaining four changed files did not load]