"""Command-line entry point that builds a trainer from a YAML config and runs it."""
import argparse

import yaml

from src.trainer import ScatSimCLRTrainer, PretextTaskTrainer


def main(args):
    """Load the YAML config and run the trainer selected by ``args.mode``.

    Args:
        args: Parsed command-line namespace. Reads ``args.mode`` (either
            ``'unsupervised'`` or ``'pretext'``) and ``args.config`` (path to
            the YAML config file).

    Raises:
        ValueError: If ``args.mode`` is not one of the supported modes. The
            check runs before the config file is opened.
    """
    mode = args.mode
    if mode not in ['unsupervised', 'pretext']:
        raise ValueError('Unsupported mode')

    # Use a context manager so the config file handle is closed even when
    # YAML parsing fails; the original left the handle open.
    with open(args.config, 'r') as config_file:
        # NOTE(review): FullLoader is kept to preserve how existing configs
        # are parsed; consider yaml.safe_load if configs may be untrusted.
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    if mode == 'unsupervised':
        trainer = ScatSimCLRTrainer(config)
    elif mode == 'pretext':
        trainer = PretextTaskTrainer(config)

    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Both options are required: without them main() can only fail (ValueError
    # for a missing mode, TypeError from open(None) for a missing config), so
    # let argparse report a usage error instead of a traceback.
    parser.add_argument('--mode', '-m',
                        help='Training mode. `unsupervised` - run training only with contrastive loss, '
                             '`pretext` - run training with contrastive loss and pretext task',
                        choices=['unsupervised', 'pretext'],
                        required=True)
    parser.add_argument('--config', '-c',
                        help='Path to config file',
                        required=True)
    args = parser.parse_args()
    main(args)