diff --git a/tools/program.py b/tools/program.py index 544d31d8b2..c543b7d85f 100755 --- a/tools/program.py +++ b/tools/program.py @@ -185,6 +185,7 @@ def train(config, eval_class, pre_best_model_dict, logger, + step_pre_epoch, log_writer=None, scaler=None, amp_level='O2', @@ -198,6 +199,7 @@ def train(config, epoch_num = config['Global']['epoch_num'] print_batch_step = config['Global']['print_batch_step'] eval_batch_step = config['Global']['eval_batch_step'] + eval_batch_epoch = config['Global'].get('eval_batch_epoch', None) profiler_options = config['profiler_options'] global_step = 0 @@ -205,8 +207,9 @@ def train(config, global_step = pre_best_model_dict['global_step'] start_eval_step = 0 if type(eval_batch_step) == list and len(eval_batch_step) >= 2: - start_eval_step = eval_batch_step[0] - eval_batch_step = eval_batch_step[1] + start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0 + eval_batch_step = eval_batch_step[ + 1] if not eval_batch_epoch else step_pre_epoch * eval_batch_epoch if len(valid_dataloader) == 0: logger.info( 'No Images in eval dataset, evaluation during training ' \ diff --git a/tools/train.py b/tools/train.py index cd6dd8c06b..faed388ec1 100755 --- a/tools/train.py +++ b/tools/train.py @@ -61,9 +61,11 @@ def main(config, device, logger, vdl_writer, seed): return if config['Eval']: - valid_dataloader = build_dataloader(config, 'Eval', device, logger, seed) + valid_dataloader = build_dataloader(config, 'Eval', device, logger, + seed) else: valid_dataloader = None + step_pre_epoch = len(train_dataloader) # build post process post_process_class = build_post_process(config['PostProcess'], @@ -93,7 +95,8 @@ def main(config, device, logger, vdl_writer, seed): 'DistillationSARLoss'][ 'ignore_index'] = char_num + 1 out_channels_list['SARLabelDecode'] = char_num + 2 - elif any('DistillationNRTRLoss' in d for d in config['Loss']['loss_config_list']): + elif any('DistillationNRTRLoss' in d + for d in config['Loss']['loss_config_list']): out_channels_list['NRTRLabelDecode'] = char_num + 3 config['Architecture']['Models'][key]['Head'][ @@ -196,9 +199,9 @@ def main(config, device, logger, vdl_writer, seed): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, scaler, - amp_level, amp_custom_black_list, amp_custom_white_list, - amp_dtype) + eval_class, pre_best_model_dict, logger, step_pre_epoch, + vdl_writer, scaler, amp_level, amp_custom_black_list, + amp_custom_white_list, amp_dtype) def test_reader(config, device, logger):