-
Notifications
You must be signed in to change notification settings - Fork 90
/
Copy pathckpt_pre.py
49 lines (31 loc) · 1.09 KB
/
ckpt_pre.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#coding:utf-8
from __future__ import print_function
import numpy as np
import cv2
from cv2 import dnn
import sys
import tensorflow as tf
from tensorflow.python.framework import graph_util
import os
import os
# Expose no CUDA devices to this process ("-1" matches no GPU id);
# presumably intended to force TensorFlow to run inference on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
#### data
# Load the sample image. cv2.imread does not raise on a missing or
# unreadable file -- it returns None -- so fail early with a clear message
# instead of crashing later on `img.shape`.
img_path = '1593250301105_1f43f6a0e8.png'
img = cv2.imread(img_path)
if img is None:
    raise IOError("cannot read image file: {!r}".format(img_path))
print("img shape: ", img.shape)
# Height and width of the image as loaded, before resizing.
rows = img.shape[0]
cols = img.shape[1]
# Resize to the 224x224 spatial size fed to the network below.
img = cv2.resize(img, (224, 224))
# Swap channel order from BGR to RGB.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Scale pixel values by 1/255 (result is a floating-point array).
img = np.multiply(img, 1.0 / 255.0)
#### model
# Run the session as a context manager so it is closed (and its resources
# released) even if restoring the checkpoint or running the graph fails.
with tf.Session() as sess:
    # Load the model structure (graph definition) from the .meta file.
    saver = tf.train.import_meta_graph('./keras_model.ckpt.meta')
    # Restore all variable values from the latest checkpoint found in the
    # current directory; only the directory needs to be specified.
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # Look up the input placeholder tensor by name.
    input_x = sess.graph.get_tensor_by_name('input_1:0')
    # Look up the output tensor to evaluate (softmax of the dense layer).
    op = sess.graph.get_tensor_by_name('dense_1/Softmax:0')
    # Wrap the single image in a list to add a leading batch axis and feed
    # it as float32.
    ret = sess.run(
        op,
        feed_dict={input_x: np.array([img], dtype=np.float32)},
    )
print("ret: ", ret)