-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestnew.py
53 lines (48 loc) · 2.11 KB
/
testnew.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import modelnew as model
import torch
import functions
import numpy
import os
import argparse
import scipy.io as sio
import re
def test_matRead(data, device='cuda:0', max_value=2047.):
    """Turn an image array loaded from a .mat file into a normalised 4-D tensor.

    A leading batch axis is added, the channel axis is moved in front of the
    spatial axes ((1, H, W, C) -> (1, C, H, W)), values are divided by
    ``max_value`` and the result is returned as float32 on ``device``.

    Args:
        data: array shaped (H, W, C). A 2-D (H, W) array is treated as a
            single-channel image.
        device: torch device for the returned tensor. Defaults to 'cuda:0',
            the device this helper always used before it was parameterised.
        max_value: normalisation divisor. 2047 (= 2**11 - 1) presumably is the
            peak value of 11-bit imagery -- confirm against your dataset.

    Returns:
        torch.Tensor of shape (1, C, H, W), dtype float32, on ``device``.
    """
    data = numpy.asarray(data)
    if data.ndim == 2:
        # Single-band image (e.g. PAN): give it an explicit channel axis.
        data = data[:, :, None]
    data = data[None, :, :, :].transpose(0, 3, 1, 2) / max_value
    return torch.from_numpy(data).to(device=device, dtype=torch.float32)


# The helper's name starts with "test", so pytest would try to collect it as a
# test case (and fail on the missing ``data`` fixture) if this module is ever
# imported into a test file. Opt it out explicitly.
test_matRead.__test__ = False


def main():
    """Run ``model.Net`` on every LRMS ``.mat`` file in ``--mspath``.

    For each ``.mat`` file in ``--mspath`` (e.g. ``LRMS12.mat``) the text
    between the first ``S``/``.`` and the next one is taken as its number
    (``'12'``). The file's ``'LRMS'`` variable and the ``'PAN'`` variable of
    ``PAN<num>.mat`` in ``--panpath`` are fed to the network. The output,
    converted by ``functions.convert_image_np`` and cast to uint16, is saved
    as ``<num>.mat`` (key ``'result'``) in ``--saveimgpath``.
    """
    # Typical values: --modelpath model/best.pth --saveimgpath result/
    parser = argparse.ArgumentParser()
    parser.add_argument('--mspath', help='directory containing the test LRMS .mat files',
                        default='')
    parser.add_argument('--panpath', help='directory containing the test PAN .mat files',
                        default='')
    parser.add_argument('--modelpath', help='path of the trained model state_dict',
                        default='')
    parser.add_argument('--saveimgpath', help='output directory for the result .mat files',
                        default='')
    # type=torch.device makes a user-supplied value the same type as the default.
    parser.add_argument('--device', type=torch.device, default=torch.device('cuda:0'))
    opt = parser.parse_args()

    net = model.Net().to(opt.device)
    # map_location lets a checkpoint saved on another device (e.g. a GPU) load
    # onto the selected one.
    net.load_state_dict(torch.load(opt.modelpath, map_location=opt.device))
    # Evaluation mode, so any dropout / batch-norm layers in Net behave as at
    # inference time rather than as during training.
    net.eval()
    num_params = sum(param.numel() for param in net.parameters())
    print("Number of parameter: %.2fM" % (num_params / 1e6))

    if opt.saveimgpath:
        os.makedirs(opt.saveimgpath, exist_ok=True)

    with torch.no_grad():
        for msfilename in sorted(os.listdir(opt.mspath)):
            if not msfilename.lower().endswith('.mat'):
                continue  # ignore stray non-.mat files in the input folder
            # 'LRMS12.mat' -> ['LRM', '12', 'mat'] -> '12'
            num = re.split(r'[S.]', msfilename)[1]
            ms_file = os.path.join(opt.mspath, msfilename)
            print(ms_file)
            ms_val = test_matRead(sio.loadmat(ms_file)['LRMS'], device=opt.device)

            pan_file = os.path.join(opt.panpath, 'PAN' + num + '.mat')
            # test_matRead adds the channel axis for a 2-D PAN image itself.
            pan_val = test_matRead(sio.loadmat(pan_file)['PAN'], device=opt.device)

            in_s = net(ms_val, pan_val)
            # NOTE(review): astype(uint16) truncates; assumes convert_image_np
            # already returns values in the uint16 range -- confirm.
            output = functions.convert_image_np(in_s.detach(), opt).astype(numpy.uint16)
            outname = os.path.join(opt.saveimgpath, num + '.mat')
            sio.savemat(outname, {'result': output})
if __name__ == '__main__':
    # Run inference only when executed as a script, not when imported.
    main()