Skip to content

Commit

Permalink
Merge pull request #136 from huawei-noah/zjj_bug_progress_logger_time…
Browse files Browse the repository at this point in the history
…_error

bug: evaluation time error in progress logger and skip update record …
  • Loading branch information
zhangjiajin authored Aug 16, 2021
2 parents 828c211 + 1742a11 commit e3aa1c8
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
2 changes: 1 addition & 1 deletion vega/evaluator/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class HostEvaluatorConfig(ConfigSerializable):

_type_name = ClassType.HOST_EVALUATOR
type = None
evaluate_latency = None
evaluate_latency = True
cuda = True
metric = {'type': 'accuracy'}
report_freq = 10
Expand Down
4 changes: 3 additions & 1 deletion vega/report/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ def load_dict(self, src_dic):
for key, value in src_dic.items():
if key in ["original_rewards", "rewards"]:
continue
if isinstance(value, dict) and isinstance(getattr(self, key), dict):
update_flag = isinstance(value, dict) and isinstance(getattr(self, key), dict)
update_flag = update_flag and key not in ["desc"]
if update_flag:
for value_key, value_value in value.items():
getattr(self, key)[value_key] = value_value
else:
Expand Down
28 changes: 14 additions & 14 deletions vega/trainer/callbacks/progress_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

"""ProgressLogger call defination."""
import logging
import statistics
import time
import numpy as np
from collections.abc import Iterable
Expand Down Expand Up @@ -48,8 +49,8 @@ def before_train(self, logs=None):
self.train_verbose = 0
if self.valid_report_steps is None:
self.valid_verbose = 0
self.total_time_pre_reports = 0
self.total_time = 0
self.total_time_pre_reports = []
self.time_total_reports = []
logging.debug("Start the unified trainer ... ")
self.is_chief = self.params['is_chief']
self.do_validation = self.params['do_validation']
Expand All @@ -67,7 +68,7 @@ def before_epoch(self, epoch, logs=None):

def after_train_step(self, batch_index, logs=None):
"""Be called before each batch training."""
self.total_time_pre_reports += time.perf_counter() - self.step_start_time
self.total_time_pre_reports.append(time.perf_counter() - self.step_start_time)
if self.train_verbose >= 2 and self.is_chief \
and batch_index % self.train_report_steps == 0:
metrics_results = logs.get('train_step_metrics', None)
Expand All @@ -80,32 +81,31 @@ def after_train_step(self, batch_index, logs=None):
loss_avg = 0
logging.warning("Cant't get the loss, maybe the loss doesn't update in the metric evaluator.")

current_time = self.total_time_pre_reports / self.train_report_steps
mean_time = 0
not_perf_batch = 5
if batch_index // self.train_report_steps > not_perf_batch:
self.total_time += self.total_time_pre_reports
mean_time = self.total_time / (batch_index - not_perf_batch * self.train_report_steps)
self.total_time_pre_reports = 0
time_pre_batch = statistics.mean(self.total_time_pre_reports)
self.time_total_reports.append(sum(self.total_time_pre_reports))
time_pre_report = statistics.mean(self.time_total_reports) / self.train_report_steps
self.total_time_pre_reports.clear()

if metrics_results is not None:
log_info = "worker id [{}], epoch [{}/{}], train step {}, loss [{:8.3f}, {:8.3f}], " \
"lr [{:12.7f}, time [{:4.3f}], mean time [{:4.3f}s], train metrics {}"
"lr [{:12.7f}, time pre batch [{:4.3f}], total mean time per batch [{:4.3f}s]," \
" train metrics {}"
log_info = log_info.format(
self.trainer.worker_id,
self.cur_epoch + 1, self.trainer.epochs,
self._format_batch(batch_index, self.train_num_batches),
cur_loss, loss_avg, lr, current_time, mean_time,
cur_loss, loss_avg, lr, time_pre_batch, time_pre_report,
self._format_metrics(metrics_results))
logging.info(log_info)
else:
log_info = "worker id [{}], epoch [{}/{}], train step {}, loss [{:8.3f}, {:8.3f}], lr [{:12.7f}]" \
", time [{:4.3f}s] , mean time [{:4.3f}s]"
", time pre batch [{:4.3f}s] , total mean time per batch [{:4.3f}s]"
log_info = log_info.format(
self.trainer.worker_id,
self.cur_epoch + 1,
self.trainer.epochs,
self._format_batch(batch_index, self.train_num_batches),
cur_loss, loss_avg, lr, current_time, mean_time)
cur_loss, loss_avg, lr, time_pre_batch, time_pre_report)
logging.info(log_info)

def after_valid_step(self, batch_index, logs=None):
Expand Down

0 comments on commit e3aa1c8

Please sign in to comment.