Skip to content

Commit

Permalink
support eval pre epoch (#11003)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored Sep 26, 2023
1 parent e49e491 commit 4ba32bc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
7 changes: 5 additions & 2 deletions tools/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
step_pre_epoch,
log_writer=None,
scaler=None,
amp_level='O2',
Expand All @@ -198,15 +199,17 @@ def train(config,
epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
eval_batch_epoch = config['Global'].get('eval_batch_epoch', None)
profiler_options = config['profiler_options']

global_step = 0
if 'global_step' in pre_best_model_dict:
global_step = pre_best_model_dict['global_step']
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0
eval_batch_step = eval_batch_step[
1] if not eval_batch_epoch else step_pre_epoch * eval_batch_epoch
if len(valid_dataloader) == 0:
logger.info(
'No Images in eval dataset, evaluation during training ' \
Expand Down
13 changes: 8 additions & 5 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ def main(config, device, logger, vdl_writer, seed):
return

if config['Eval']:
valid_dataloader = build_dataloader(config, 'Eval', device, logger, seed)
valid_dataloader = build_dataloader(config, 'Eval', device, logger,
seed)
else:
valid_dataloader = None
step_pre_epoch = len(train_dataloader)

# build post process
post_process_class = build_post_process(config['PostProcess'],
Expand Down Expand Up @@ -93,7 +95,8 @@ def main(config, device, logger, vdl_writer, seed):
'DistillationSARLoss'][
'ignore_index'] = char_num + 1
out_channels_list['SARLabelDecode'] = char_num + 2
elif any('DistillationNRTRLoss' in d for d in config['Loss']['loss_config_list']):
elif any('DistillationNRTRLoss' in d
for d in config['Loss']['loss_config_list']):
out_channels_list['NRTRLabelDecode'] = char_num + 3

config['Architecture']['Models'][key]['Head'][
Expand Down Expand Up @@ -196,9 +199,9 @@ def main(config, device, logger, vdl_writer, seed):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
amp_level, amp_custom_black_list, amp_custom_white_list,
amp_dtype)
eval_class, pre_best_model_dict, logger, step_pre_epoch,
vdl_writer, scaler, amp_level, amp_custom_black_list,
amp_custom_white_list, amp_dtype)


def test_reader(config, device, logger):
Expand Down

0 comments on commit 4ba32bc

Please sign in to comment.