main.py
import argparse
import json

import torch
import torch.distributed
import torch.multiprocessing

import process
import utils.pythonic


def main_worker(local_rank, global_rank, task_option, config):
    # Initialize the distributed process group when running as a GPU worker.
    # NOTE: `global_rank` is used as the world size (number of worker processes).
    if local_rank >= 0 and task_option.get('distributed', True):
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='tcp://127.0.0.1:23456',
                                             world_size=global_rank,
                                             rank=local_rank)

    # Resolve the task class from the `process` module by its configured name.
    task_name = task_option.get('name')
    task_object = utils.pythonic.get_attributes(process, task_name)
    if not task_object:
        return

    # Collect the sub-options (model, optimizer, data, logger) this task references,
    # looking each one up in the corresponding top-level config section.
    task_kwargs = {}
    model_name = task_option.get('model')
    if model_name:
        task_kwargs['model_option'] = config['model'][model_name]
        task_kwargs['model_option']['distributed_parallel'] = task_option.get('distributed', True)
        task_option.pop('model')
    optim_name = task_option.get('optim')
    if optim_name:
        task_kwargs['optim_option'] = config['optim'][optim_name]
        task_option.pop('optim')
    datas_name = task_option.get('datas')
    if datas_name:
        task_kwargs['datas_option'] = config['datas'][datas_name]
        task_option.pop('datas')
    logger_name = task_option.get('logger')
    if logger_name:
        task_kwargs['logger_option'] = config['logger'][logger_name]
        task_option.pop('logger')

    # Build and run the task with the remaining task options plus the collected kwargs.
    task = task_object(**task_option, **task_kwargs, local_rank=local_rank)
    task.run()

    # Synchronize workers before exiting; only valid once the process group exists.
    if global_rank > 1 and torch.distributed.is_initialized():
        torch.distributed.barrier()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Siamese Diffusion.')
    parser.add_argument('-c', '--config', type=str, required=True)
    args = parser.parse_args()

    with open(args.config, 'r') as fr:
        config = json.load(fr)

    for task_key, task_option in config['tasks'].items():
        if not task_option['run']:
            continue
        # Cap the requested worker count at the number of available GPUs.
        global_rank = min(task_option['global_rank'], torch.cuda.device_count())
        distributed = task_option.get('distributed', True)
        task_option['global_rank'] = global_rank
        if global_rank > 0 and distributed:
            # Spawn one process per GPU; spawn passes the process index (0..nprocs-1)
            # to main_worker as local_rank.
            torch.multiprocessing.spawn(main_worker, nprocs=global_rank,
                                        args=(global_rank, task_option, config))
        else:
            # Run in the current process; local_rank is -1 when no GPU is available.
            main_worker(global_rank - 1, global_rank, task_option=task_option, config=config)
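
The script expects a JSON config with top-level 'tasks', 'model', 'optim', 'datas', and 'logger' sections, where each task entry names the sub-sections it uses. The sketch below is a minimal, hypothetical config inferred only from the keys main.py reads; the task name 'TrainTask' and all option contents are assumptions, not values from this repository.

# make_config.py -- hypothetical helper that writes a minimal config for main.py.
# Section names and values below are assumptions inferred from the keys main.py reads.
import json

config = {
    'tasks': {
        'train': {
            'run': True,             # tasks with run=False are skipped by main.py
            'name': 'TrainTask',     # class looked up on the `process` module (assumed name)
            'global_rank': 2,        # requested worker count, capped at torch.cuda.device_count()
            'distributed': True,     # spawn one NCCL process per GPU when True
            'model': 'diffusion',    # keys into the sections below
            'optim': 'adam',
            'datas': 'default',
            'logger': 'default',
        },
    },
    'model': {'diffusion': {}},      # passed to the task as model_option
    'optim': {'adam': {}},           # optim_option
    'datas': {'default': {}},        # datas_option
    'logger': {'default': {}},       # logger_option
}

with open('config.json', 'w') as fw:
    json.dump(config, fw, indent=2)

# Launch with: python main.py -c config.json

main.py only assumes that the resolved task class accepts the forwarded options as keyword arguments and exposes a run() method. A minimal stand-in illustrating that interface, with all names hypothetical and not taken from the actual `process` module, might look like this:

# Hypothetical task class showing the constructor/run() contract main.py relies on.
class TrainTask:
    def __init__(self, model_option=None, optim_option=None, datas_option=None,
                 logger_option=None, local_rank=-1, **task_option):
        # main.py forwards the *_option dicts, the worker's local_rank, and any
        # remaining task_option keys (e.g. name, run, global_rank, distributed).
        self.local_rank = local_rank
        self.model_option = model_option

    def run(self):
        # Training or inference loop would go here.
        pass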