Skip to content

Commit 5bdb28e

Browse files
authored
Default PyTorch Hub to autocast(False) (ultralytics#5926)
1 parent c77a5a8 commit 5bdb28e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

models/common.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ class AutoShape(nn.Module):
443443
multi_label = False # NMS multiple labels per box
444444
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
445445
max_det = 1000 # maximum number of detections per image
446+
amp = False # Automatic Mixed Precision (AMP) inference
446447

447448
def __init__(self, model):
448449
super().__init__()
@@ -476,8 +477,9 @@ def forward(self, imgs, size=640, augment=False, profile=False):
476477

477478
t = [time_sync()]
478479
p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
480+
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
479481
if isinstance(imgs, torch.Tensor): # torch
480-
with amp.autocast(enabled=p.device.type != 'cpu'):
482+
with amp.autocast(enabled=autocast):
481483
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
482484

483485
# Pre-process
@@ -506,7 +508,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
506508
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
507509
t.append(time_sync())
508510

509-
with amp.autocast(enabled=p.device.type != 'cpu'):
511+
with amp.autocast(enabled=autocast):
510512
# Inference
511513
y = self.model(x, augment, profile) # forward
512514
t.append(time_sync())

0 commit comments

Comments
 (0)