@@ -443,6 +443,7 @@ class AutoShape(nn.Module):
443
443
multi_label = False # NMS multiple labels per box
444
444
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
445
445
max_det = 1000 # maximum number of detections per image
446
+ amp = False # Automatic Mixed Precision (AMP) inference
446
447
447
448
def __init__ (self , model ):
448
449
super ().__init__ ()
@@ -476,8 +477,9 @@ def forward(self, imgs, size=640, augment=False, profile=False):
476
477
477
478
t = [time_sync ()]
478
479
p = next (self .model .parameters ()) if self .pt else torch .zeros (1 ) # for device and type
480
+ autocast = self .amp and (p .device .type != 'cpu' ) # Automatic Mixed Precision (AMP) inference
479
481
if isinstance (imgs , torch .Tensor ): # torch
480
- with amp .autocast (enabled = p . device . type != 'cpu' ):
482
+ with amp .autocast (enabled = autocast ):
481
483
return self .model (imgs .to (p .device ).type_as (p ), augment , profile ) # inference
482
484
483
485
# Pre-process
@@ -506,7 +508,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
506
508
x = torch .from_numpy (x ).to (p .device ).type_as (p ) / 255 # uint8 to fp16/32
507
509
t .append (time_sync ())
508
510
509
- with amp .autocast (enabled = p . device . type != 'cpu' ):
511
+ with amp .autocast (enabled = autocast ):
510
512
# Inference
511
513
y = self .model (x , augment , profile ) # forward
512
514
t .append (time_sync ())
0 commit comments