# encoder.py -- forked from 1zb/pytorch-image-comp-rnn
# -*- coding: utf-8 -*-
"""
Encode an image into binary codes with a recurrent compression network.

The script loads encoder, binarizer and decoder weights, then runs
``--iterations`` passes. Each pass encodes the current residual, binarizes
it into a code, decodes that code and subtracts the decoder output from the
residual. The per-pass codes are bit-packed and written with
``numpy.savez_compressed`` under the keys ``shape`` and ``codes``.

The input image's height and width must both be multiples of 32.

Created on Tue Apr 27 08:51:44 2021
@author: Moaz Edmont
"""
import argparse
import numpy as np
from imageio import imread, imsave
#from scipy.misc import imread, imresize, imsave
#from skimage import io, color, data
#from PIL import Image
import torch
from torch.autograd import Variable
# Command-line interface. The caller supplies:
#   * the encoder checkpoint path (binarizer/decoder paths are derived from it);
#   * the source image;
#   * the destination for the packed codes;
#   * an optional CUDA switch;
#   * the number of recurrent passes.
parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', type=str, required=True,
                    help='path to model')
parser.add_argument('--input', '-i', type=str, required=True,
                    help='input image')
parser.add_argument('--output', '-o', type=str, required=True,
                    help='output codes')
parser.add_argument('--cuda', '-g', action='store_true',
                    help='enables cuda')
parser.add_argument('--iterations', type=int, default=16,
                    help='unroll iterations')
args = parser.parse_args()
# Read the image given by --input. pilmode='RGB' asks imageio to convert
# it to a 3-channel RGB array.
image = imread(args.input, pilmode='RGB')
# Scale pixel values by 1/255. Reorder (height, width, channels) into
# (channels, height, width), then add a leading batch axis of size one.
image = torch.from_numpy(
    np.expand_dims(
        np.transpose(image.astype(np.float32) / 255.0, (2, 0, 1)), 0))
batch_size, input_channels, height, width = image.size()
# Validate explicitly rather than with `assert`, which is stripped under
# `python -O`.
if height % 32 != 0 or width % 32 != 0:
    raise ValueError(
        'image height and width must be multiples of 32, got {}x{}'.format(
            height, width))
image = Variable(image)
import network
# Build the three sub-networks and switch them to inference mode.
encoder = network.EncoderCell()
binarizer = network.Binarizer()
decoder = network.DecoderCell()
encoder.eval()
binarizer.eval()
decoder.eval()
# Each sub-network has its own state-dict file. The binarizer and decoder
# paths are derived from the encoder path by substituting the word 'encoder'.
# NOTE(review): str.replace rewrites *every* 'encoder' in the path,
# directory names included -- name checkpoint paths accordingly.
#
# map_location='cpu' lets checkpoints saved from a GPU run load on a host
# without CUDA. The modules are moved to the GPU afterwards when --cuda is
# given, so loading to CPU is correct in both cases.
# torch.load unpickles the file: only load checkpoints from trusted sources.
encoder.load_state_dict(torch.load(args.model, map_location='cpu'))
binarizer.load_state_dict(
    torch.load(args.model.replace('encoder', 'binarizer'),
               map_location='cpu'))
decoder.load_state_dict(
    torch.load(args.model.replace('encoder', 'decoder'),
               map_location='cpu'))
def _zero_lstm_state(batch, channels, rows, cols):
    """Return a new pair of all-zero tensors used as one layer's recurrent state.

    Both tensors have shape ``(batch, channels, rows, cols)``. They are
    allocated separately, so the two halves never alias each other.
    (Presumably the hidden and cell states of an LSTM-style cell -- confirm
    against ``network.py``.)
    """
    return (Variable(torch.zeros(batch, channels, rows, cols)),
            Variable(torch.zeros(batch, channels, rows, cols)))


def _state_to_cuda(state):
    """Return a copy of a two-tensor recurrent state moved to the current GPU."""
    h, c = state
    return (h.cuda(), c.cuda())


# Initial (all-zero) recurrent state for each encoder and decoder layer.
# Channel counts and downsampling factors must match the layers in `network`.
encoder_h_1 = _zero_lstm_state(batch_size, 256, height // 4, width // 4)
encoder_h_2 = _zero_lstm_state(batch_size, 512, height // 8, width // 8)
encoder_h_3 = _zero_lstm_state(batch_size, 512, height // 16, width // 16)

decoder_h_1 = _zero_lstm_state(batch_size, 512, height // 16, width // 16)
decoder_h_2 = _zero_lstm_state(batch_size, 512, height // 8, width // 8)
decoder_h_3 = _zero_lstm_state(batch_size, 256, height // 4, width // 4)
decoder_h_4 = _zero_lstm_state(batch_size, 128, height // 2, width // 2)

if args.cuda:
    # Modules move in place (and return self). Tensors return new copies,
    # so every tensor name is rebound.
    encoder = encoder.cuda()
    binarizer = binarizer.cuda()
    decoder = decoder.cuda()
    image = image.cuda()
    encoder_h_1 = _state_to_cuda(encoder_h_1)
    encoder_h_2 = _state_to_cuda(encoder_h_2)
    encoder_h_3 = _state_to_cuda(encoder_h_3)
    decoder_h_1 = _state_to_cuda(decoder_h_1)
    decoder_h_2 = _state_to_cuda(decoder_h_2)
    decoder_h_3 = _state_to_cuda(decoder_h_3)
    decoder_h_4 = _state_to_cuda(decoder_h_4)
# Progressive encoding. Each pass does the following:
#   1. encode the current residual;
#   2. binarize the encoding into a code;
#   3. decode that code;
#   4. subtract the decoder output from the residual.
# Later passes therefore work on whatever earlier passes did not reconstruct.
if args.iterations < 1:
    # np.stack below needs at least one code array.
    raise ValueError(
        '--iterations must be at least 1, got {}'.format(args.iterations))

codes = []
# Inference only. Without no_grad, autograd would record a graph chained
# through `res` and the hidden states, keeping every pass's activations
# alive until the loop ends.
with torch.no_grad():
    # Centre the 1/255-scaled image around zero before the first pass.
    res = image - 0.5
    for iters in range(args.iterations):
        encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
            res, encoder_h_1, encoder_h_2, encoder_h_3)
        code = binarizer(encoded)
        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            code, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
        res = res - output
        codes.append(code.data.cpu().numpy())
        # Mean absolute residual after this pass.
        print('Iter: {:02d}; Loss: {:.06f}'.format(
            iters, res.abs().mean().item()))

# Convert code values to 0/1 bits. The (x + 1) // 2 mapping assumes the
# binarizer emits -1/+1.
codes = (np.stack(codes).astype(np.int8) + 1) // 2
# Flatten and pack 8 bits per byte. The original array shape is saved
# alongside, so the bitstream can be unpacked and reshaped later. NumPy
# appends '.npz' to the output name if it is missing.
export = np.packbits(codes.reshape(-1))
np.savez_compressed(args.output, shape=codes.shape, codes=export)