diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 257ac54606..4c5e13df2b 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -88,16 +88,16 @@ def find_learning_rate(self,
                 f.write(f'{itr}\t{datetime.datetime.now():%H:%M:%S}\t{learning_rate}\t{loss_item}\n')
 
             if stop_early and loss_item > 4 * best_loss:
-                log_line()
+                log_line(log)
                 log.info('loss diverged - stopping early!')
                 break
 
         self.model.load_state_dict(model_state)
         self.model.to(model_device)
 
-        log_line()
+        log_line(log)
         log.info(f'learning rate finder finished - plot {learning_rate_tsv}')
-        log_line()
+        log_line(log)
 
         return Path(learning_rate_tsv)
 
diff --git a/flair/training_utils.py b/flair/training_utils.py
index be071570b5..3be4d8e43e 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -9,9 +9,6 @@ from functools import reduce
 
 
-log = logging.getLogger(__name__)
-
-
 class Metric(object):
 
     def __init__(self, name):
@@ -232,8 +229,5 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar
     return [[1 if l in labels else 0 for l in label_dict.get_items()] for labels in label_list]
 
 
-def log_line(logger = None):
-    if logger is not None:
-        logger.info('-' * 100)
-    else:
-        log.info('-' * 100)
+def log_line(logger):
+    logger.info('-' * 100)
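
For context on the API change: after this diff, log_line no longer falls back to a
module-level logger in flair.training_utils; every call site must pass its own logger
explicitly, as trainer.py now does. A minimal usage sketch of the new signature follows;
the logging.basicConfig call and the 'flair' logger name are assumptions for the sketch,
not part of the diff.

    import logging

    from flair.training_utils import log_line

    # Make INFO-level output visible on stderr (assumption: no handler configured yet).
    logging.basicConfig(level=logging.INFO)

    # Each module now creates and owns its logger, mirroring trainer.py after this diff.
    log = logging.getLogger('flair')

    log_line(log)                       # emits a 100-character '-' separator via this logger
    log.info('learning rate finder finished')
    log_line(log)

Passing the logger explicitly lets each caller's log records carry the correct module name
and level configuration, instead of all separator lines being attributed to
flair.training_utils.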