-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathimageshow.py
46 lines (39 loc) · 1.16 KB
/
imageshow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import argparse
import logging
import os
import random
import numpy as np
from PIL import Image
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader,TensorDataset
import torch
import torch.nn as nn
from UHRNet import UHRNet
import matplotlib.pyplot as plt
# Pick the compute device first so the checkpoint can be mapped onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Network input loaded from disk. Cast to float32 so it matches the dtype of
# the freshly constructed model's parameters (assumes UHRNet keeps PyTorch's
# default float32 parameters; a float64 .npy would otherwise make the first
# layer raise a dtype mismatch). .float() is a no-op on float32 data.
fringepattern = torch.from_numpy(np.load('fringe_pattern.npy')).float()
# Ground-truth map; used for comparison and for the background mask in trans().
gt = np.load('height_map.npy')
path = 'UHRNet_weight.pth'
# map_location lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; without it torch.load raises when CUDA is unavailable.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(path, map_location=device)
def trans(input, gt, threshold=-100):
    """Copy ground-truth background values into a predicted map.

    Pixels where ``gt`` is at or below ``threshold`` are treated as invalid
    background. At those positions ``input`` is overwritten, in place, with
    the corresponding ``gt`` value, so both maps share the same background.

    Args:
        input: Predicted map, modified in place. Expected to be a NumPy array
            with the same shape as ``gt``. The parameter name shadows the
            builtin but is kept for backward compatibility with callers.
        gt: Ground-truth map; supplies both the mask and the fill values.
        threshold: Values of ``gt`` that are ``<= threshold`` count as
            background. Defaults to -100, the original hard-coded cut-off.

    Returns:
        ``input`` itself, after the in-place update.
    """
    gt = np.asarray(gt)
    # A vectorised boolean mask replaces the former Python double loop over
    # a hard-coded 352x640 grid: it works for any shape and runs in C.
    background = gt <= threshold
    input[background] = gt[background]
    return input
############################
# Build the network, load the trained weights, and switch to inference mode.
net = UHRNet().to(device)
net.load_state_dict(checkpoint['state_dict'])
net.eval()
with torch.no_grad():
    fringepattern = fringepattern.to(device)
    # Prepend two singleton axes before calling the network (presumably
    # (H, W) -> (1, 1, H, W), i.e. batch and channel -- TODO confirm).
    out = net(fringepattern.unsqueeze(0).unsqueeze(0))
# The forward pass ran under no_grad(), so the output carries no autograd
# history and needs no detach(); keep the first batch item and channel.
out = out.cpu().numpy()[0, 0]
out = trans(out, gt)
# Share one colour scale between the panels so that equal colours denote
# equal values in both the prediction and the ground truth.
vmin = min(np.nanmin(out), np.nanmin(gt))
vmax = max(np.nanmax(out), np.nanmax(gt))
plt.subplot(121)
plt.title('Prediction')
plt.imshow(out, cmap='jet', vmin=vmin, vmax=vmax)
plt.subplot(122)
plt.title('Ground truth')
plt.imshow(gt, cmap='jet', vmin=vmin, vmax=vmax)
plt.show()