-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathFusion.py
34 lines (31 loc) · 1.16 KB
/
Fusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import os
from tqdm import tqdm
from time import time
from utils.utils import *
from dataset import TestData, imsave
def run_fusion(dataset_name='MSRS', model=None, save_dir=None):
    """Fuse every infrared/visible image pair of a test set and save the results.

    The model is put into eval mode for inference. When the loop finishes
    (normally or via an exception), train mode is restored if the model was
    training on entry; objects without a ``training`` attribute are always
    switched back with ``train()``.

    Args:
        dataset_name: Test set to process, ``'MSRS'`` or ``'RoadScene'``.
            Inputs are read from ``./dataset/test/<name>/ir`` and
            ``./dataset/test/<name>/vi``.
        model: Fusion network exposing ``eval()``, ``train()`` and
            ``fusion_forward(ir, vi)``. Required.
        save_dir: Directory for the fused images. When ``None`` it defaults
            to ``./dataset/test/<name>/fused``. Created if missing.

    Raises:
        ValueError: If ``dataset_name`` is not a known test set or ``model``
            is ``None``.
    """
    dataset_roots = {
        'MSRS': "./dataset/test/MSRS",
        'RoadScene': "./dataset/test/RoadScene",
    }
    # Validate before touching the filesystem so bad calls have no side effects.
    if dataset_name not in dataset_roots:
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(dataset_roots)}"
        )
    if model is None:
        raise ValueError("run_fusion requires a model")

    img_path = dataset_roots[dataset_name]
    # Respect a caller-supplied directory; only fall back to the default
    # '<dataset>/fused' folder when none was given.
    if save_dir is None:
        save_dir = f"{img_path}/fused"

    ir_path = os.path.join(img_path, 'ir')
    vi_path = os.path.join(img_path, 'vi')
    os.makedirs(save_dir, exist_ok=True)
    test_data = TestData(ir_path, vi_path)

    # Remember the caller's mode; default True keeps the old "always train()"
    # behaviour for model objects that do not expose `.training`.
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        p_bar = tqdm(test_data, total=len(test_data))
        for ir, vi, name in p_bar:
            ir_tensor = ir.cuda()
            vi_tensor = vi.cuda()
            start = time()
            with torch.no_grad():
                fu = model.fusion_forward(ir_tensor, vi_tensor)
            # CUDA kernels launch asynchronously; wait for completion so the
            # measured time covers the forward pass itself, not just its launch.
            torch.cuda.synchronize()
            test_time = time() - start
            imsave(fu, os.path.join(save_dir, name))
            p_bar.set_description(f'fusing {name} | time : {str(test_time)}')
    finally:
        if was_training:
            model.train()