# forked from JiahuiLei/EFEM
# this should be an elegant and almost final version
import os
import os.path as osp
import torch
import yaml
import numpy as np
import shutil
import logging
from lib_efem.data_utils import get_dataset
from lib_efem.misc import (
create_log_dir,
cfg_with_default,
setup_seed,
config_logging,
save_scannet_format,
)
from lib_efem.model_utils import load_models_dict
from lib_efem.database import load_database
from lib_efem.solver import Solver
import time
SEED_DEFAULT = 12345
def main(cfg, device, i=0, m=1, SEED=SEED_DEFAULT):
    """Run the EFEM solver over one split of the dataset and save results.

    Args:
        cfg: config dict loaded from YAML; paths inside are resolved
            relative to cfg["working_dir"].
        device: torch device that models and data are moved to.
        i, m: this job handles only scenes with scene_id % m == i, so the
            dataset can be split across m parallel jobs.
        SEED: random seed; a non-default seed gets its own log-dir suffix
            so runs with different seeds do not overwrite each other.

    Side effects: creates log/viz/backup dirs, writes one compressed .npz
    per scene plus ScanNet-format exports, and timing files in log_dir.
    """
    if SEED != SEED_DEFAULT:
        cfg["log_dir"] = cfg["log_dir"] + f"_seed={SEED}"
    log_resume_flag = cfg_with_default(cfg, ["log_resume"], False)
    # create log / viz / backup dirs under the working dir
    log_dir, viz_dir, bck_dir = create_log_dir(
        osp.join(cfg["working_dir"], cfg["log_dir"]), resume=log_resume_flag
    )
    output_dir = osp.join(log_dir, "results")
    os.makedirs(output_dir, exist_ok=True)
    # when resuming, skip scenes that already have a saved .npz result
    if log_resume_flag:
        finished_scene_id = [d[:-4] for d in os.listdir(output_dir) if d.endswith(".npz")]
    else:
        finished_scene_id = []
    config_logging(osp.join(log_dir, "logs"), debug=False, log_fn="init.log")
    # back up this script plus code/config dirs for reproducibility; shutil
    # (instead of os.system("cp ...")) is safe w.r.t. spaces in paths, and
    # failures stay best-effort like the original shell call
    try:
        shutil.copy2(__file__, bck_dir)
        for src in ("lib_efem", "configs"):
            # relative to the process cwd, matching the original `cp -r`
            shutil.copytree(src, osp.join(bck_dir, src), dirs_exist_ok=True)
    except OSError as e:
        logging.warning(f"code backup failed: {e}")
    logging.info(f"Use seed {SEED}")
    # Load model, dataset and solver
    model = load_models_dict(cfg, device)
    dataset = get_dataset(cfg)
    solver = Solver(cfg)
    # build the shape database only if the config asks for it
    if cfg_with_default(cfg, ["use_database"], True):
        database_dict = load_database(cfg)
    else:
        database_dict = None
    # renamed from `iter` (which shadowed the builtin)
    scene_iter = cfg_with_default(cfg, ["iter"], range(len(dataset)))
    time_list = []
    for scene_id in scene_iter:
        if scene_id % m != i:
            continue  # another job of the m-way split handles this scene
        if dataset.scene_id_list[scene_id] in finished_scene_id:
            continue  # already solved in a previous (resumed) run
        config_logging(osp.join(log_dir, "logs"), debug=False, log_fn=f"{scene_id}.log")
        banner = "=" * max(shutil.get_terminal_size()[0] - 100, 30)
        logging.info(banner)
        logging.info(f"scene_id {scene_id}")
        logging.info(banner)
        data_dict = dataset[scene_id]
        data_dict = dataset.to_device(data_dict, device)
        # solve; timing covers the solver only (data loading excluded)
        start_time = time.time()
        ppt_dict, prop_cnt_statistic = solver.solve(
            model_dict=model,
            data_dict=data_dict,
            database_dict=database_dict,
            viz_prefix=f"{scene_id}_s{data_dict['scene']}",
            viz_dir=viz_dir,
            seed=SEED,
        )
        solver_time = time.time() - start_time
        time_list.append(solver_time)
        logging.info(f"Solver takes {solver_time:.3f}s")
        # save solution: one compressed npz per scene, keyed by proposal id
        solution = {k: v.fetch_output() for k, v in ppt_dict.items()}
        np.savez_compressed(
            osp.join(output_dir, f"{data_dict['scene']}.npz"),
            **solution,
            meta_info={"prop_cnt_statistic": prop_cnt_statistic, "solver_time": solver_time},
        )
        # also export in ScanNet evaluation format
        save_scannet_format(
            solution,
            data_dict['scene'],
            dst_dir=output_dir + "_eval",
            scannet_flag=cfg_with_default(cfg, ["scannet_flag"], False),
        )
    time_list = np.array(time_list)
    # guard the empty case (all scenes finished/skipped): mean() of an empty
    # array is nan and emits a RuntimeWarning
    average_time = float(time_list.mean()) if len(time_list) > 0 else 0.0
    np.save(osp.join(log_dir, "time.npy"), time_list)
    np.savetxt(osp.join(log_dir, "ave_time.txt"), [average_time])
    logging.info(f"ave time = {average_time}")
if __name__ == "__main__":
    import argparse

    # CLI: config path, job split (-i of -m), and random seed
    parser = argparse.ArgumentParser(description="Run")
    parser.add_argument("--config", "-c", required=True)
    parser.add_argument("-i", type=int, default=0)
    parser.add_argument("-m", type=int, default=1)
    parser.add_argument("-s", type=int, default=SEED_DEFAULT)
    opts = parser.parse_args()

    chosen_seed = opts.s
    setup_seed(chosen_seed)
    # * Note: all paths are resolved relative to this file's directory
    base_dir = osp.dirname(__file__)
    with open(osp.join(base_dir, opts.config), "r") as fp:
        config = yaml.full_load(fp)
    config["working_dir"] = base_dir
    main(config, device=torch.device("cuda"), i=opts.i, m=opts.m, SEED=chosen_seed)