@@ -145,6 +145,7 @@ def export_saved_model(model, im, file, dynamic,
145
145
inputs = keras .Input (shape = (* imgsz , 3 ), batch_size = None if dynamic else batch_size )
146
146
outputs = tf_model .predict (inputs , tf_nms , agnostic_nms , topk_per_class , topk_all , iou_thres , conf_thres )
147
147
keras_model = keras .Model (inputs = inputs , outputs = outputs )
148
+ keras_model .trainable = False
148
149
keras_model .summary ()
149
150
keras_model .save (f , save_format = 'tf' )
150
151
@@ -183,15 +184,17 @@ def export_tflite(keras_model, im, file, int8, data, ncalib, prefix=colorstr('Te
183
184
184
185
print (f'\n { prefix } starting export with tensorflow { tf .__version__ } ...' )
185
186
batch_size , ch , * imgsz = list (im .shape ) # BCHW
186
- f = file . with_suffix ( ' .tflite' )
187
+ f = str ( file ). replace ( '.pt' , '-fp16 .tflite' )
187
188
188
189
converter = tf .lite .TFLiteConverter .from_keras_model (keras_model )
189
190
converter .target_spec .supported_ops = [tf .lite .OpsSet .TFLITE_BUILTINS ]
191
+ converter .target_spec .supported_types = [tf .float16 ]
190
192
converter .optimizations = [tf .lite .Optimize .DEFAULT ]
191
193
if int8 :
192
194
dataset = LoadImages (check_dataset (data )['train' ], img_size = imgsz , auto = False ) # representative data
193
195
converter .representative_dataset = lambda : representative_dataset_gen (dataset , ncalib )
194
196
converter .target_spec .supported_ops = [tf .lite .OpsSet .TFLITE_BUILTINS_INT8 ]
197
+ converter .target_spec .supported_types = []
195
198
converter .inference_input_type = tf .uint8 # or tf.int8
196
199
converter .inference_output_type = tf .uint8 # or tf.int8
197
200
converter .experimental_new_quantizer = False
@@ -249,7 +252,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
249
252
# Load PyTorch model
250
253
device = select_device (device )
251
254
assert not (device .type == 'cpu' and half ), '--half only compatible with GPU export, i.e. use --device 0'
252
- model = attempt_load (weights , map_location = device , inplace = True , fuse = not any ( tf_exports ) ) # load FP32 model
255
+ model = attempt_load (weights , map_location = device , inplace = True , fuse = True ) # load FP32 model
253
256
nc , names = model .nc , model .names # number of classes, class names
254
257
255
258
# Input
0 commit comments