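"""Interactive logging utilities for avalanche-rl.

`TqdmWriteInteractiveLogger` routes metric output through `tqdm.write`
so stats can be printed to the console without corrupting the training
progress bar.
"""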
from typing import List

from tqdm import tqdm

from avalanche.evaluation.metric_results import MetricValue
from avalanche.evaluation.metric_utils import stream_type
from avalanche.logging.interactive_logging import InteractiveLogger
from avalanche.training.templates.base import BaseTemplate

from avalanche_rl.logging.strategy_logger import RLStrategyLogger


class TqdmWriteInteractiveLogger(InteractiveLogger, RLStrategyLogger):
    """
    Prints stats to the console while updating the progress bar,
    without breaking it.
    """

    def __init__(self, log_every: int = 1):
        super().__init__()
        # Flush accumulated metrics to the console every `log_every` steps.
        self.log_every = log_every
        self.step_counter: int = 0

    def print_current_metrics(self):
        # Sort by metric name so the printed order is stable across steps.
        sorted_vals = sorted(self.metric_vals.values(),
                             key=lambda x: x[0])
        for name, _, val in sorted_vals:
            val = self._val_to_str(val)
            tqdm.write(f'\t{name} = {val}', file=self.file)

    def before_training_exp(self, strategy: 'BaseTemplate',
                            metric_values: List['MetricValue'], **kwargs):
        super().before_training_exp(strategy, metric_values, **kwargs)
        # Size the progress bar on the number of steps in this experience.
        self._progress.total = strategy.current_experience_steps.value

    def after_training_exp(self, strategy: 'BaseTemplate',
                           metric_values: List['MetricValue'], **kwargs):
        self._end_progress()
        return super().after_training_exp(strategy, metric_values, **kwargs)

    def after_training_iteration(self, strategy: 'BaseTemplate',
                                 metric_values: List['MetricValue'], **kwargs):
        # Advance and redraw the progress bar, then delegate metric
        # bookkeeping to the RL-specific `after_update` callback.
        self._progress.update()
        self._progress.refresh()
        super().after_update(strategy, metric_values, **kwargs)
        if self.step_counter % self.log_every == 0:
            self.print_current_metrics()
            self.metric_vals = {}
        self.step_counter += 1

    def before_eval(self, strategy: 'BaseTemplate',
                    metric_values: List['MetricValue'], **kwargs):
        self.metric_vals = {}
        tqdm.write('\n-- >> Start of eval phase << --', file=self.file)

    def before_eval_exp(self, strategy: 'BaseTemplate',
                        metric_values: List['MetricValue'], **kwargs):
        action_name = 'training' if strategy.is_training else 'eval'
        exp_id = strategy.experience.current_experience
        task_id = strategy.experience.task_label
        stream = stream_type(strategy.experience)
        tqdm.write(
            '-- Starting {} on experience {} (Task {}) from {} stream --'
            .format(action_name, exp_id, task_id, stream), file=self.file)

    def after_eval_exp(self, strategy: 'BaseTemplate',
                       metric_values: List['MetricValue'], **kwargs):
        self.log_metrics(metric_values)
        self.print_current_metrics()
        exp_id = strategy.experience.current_experience
        tqdm.write(f'> Eval on experience {exp_id} (Task '
                   f'{strategy.experience.task_label}) '
                   f'from {stream_type(strategy.experience)} stream ended.',
                   file=self.file)

    def after_eval(self, strategy: 'BaseTemplate',
                   metric_values: List['MetricValue'], **kwargs):
        tqdm.write('-- >> End of eval phase << --\n', file=self.file)
        self.metric_vals = {}

    def before_training(self, strategy: 'BaseTemplate',
                        metric_values: List['MetricValue'], **kwargs):
        tqdm.write('-- >> Start of training phase << --', file=self.file)

    def after_training(self, strategy: 'BaseTemplate',
                       metric_values: List['MetricValue'], **kwargs):
        tqdm.write('-- >> End of training phase << --', file=self.file)
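

# ---------------------------------------------------------------------------
# Usage sketch (hedged): this logger is meant to be passed wherever Avalanche
# accepts loggers, e.g. via EvaluationPlugin(..., loggers=[...]) on a
# strategy's `evaluator` argument. The strategy name below is a hypothetical
# placeholder, mirroring core Avalanche's API:
#
#     logger = TqdmWriteInteractiveLogger(log_every=10)
#     evaluator = EvaluationPlugin(..., loggers=[logger])
#     strategy = SomeRLStrategy(..., evaluator=evaluator)  # hypothetical name
#
# Below is a minimal, self-contained demo of the tqdm.write() trick the class
# relies on: text is printed above a live progress bar without breaking it.
# It uses only plain tqdm, so it can be run directly to see the effect.
if __name__ == '__main__':
    import time

    bar = tqdm(total=20, desc='demo experience')
    for step in range(20):
        time.sleep(0.05)
        bar.update()
        if step % 5 == 0:
            # tqdm.write prints above the bar and redraws it, keeping
            # both the log lines and the bar readable.
            tqdm.write(f'\tdemo_metric = {step * 0.1:.2f}')
    bar.close()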