-
Notifications
You must be signed in to change notification settings - Fork 685
AttributeError: 'Parameter' object has no attribute '_trt' #565
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Comments
Hi, have you solved this problem?
Unfortunately, the problem is not solved yet.
Here is a patch (yolor.patch) for the YOLOR `paper` branch that adds torch2trt conversion:
diff --git a/detect.py b/detect.py
index f2d9f36..21a4f84 100644
--- a/detect.py
+++ b/detect.py
@@ -7,6 +7,10 @@ import torch
import torch.backends.cudnn as cudnn
from numpy import random
+import logging
+
+from torch2trt import torch2trt, tensorrt_converter, get_arg, trt, add_missing_trt_tensors
+
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \
@@ -14,6 +18,20 @@ from utils.general import check_img_size, non_max_suppression, apply_classifier,
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized
+# REGISTER NEW CONVERTERS
+
+@tensorrt_converter('torch.nn.functional.silu')
+def convert_silu(ctx):
+ input = get_arg(ctx, 'input', pos=0, default=None)
+ output = ctx.method_return
+ input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+
+ layer = ctx.network.add_activation(input_trt, trt.ActivationType.SIGMOID)
+ layer = ctx.network.add_elementwise(input_trt, layer.get_output(0), trt.ElementWiseOperation.PROD)
+
+ output._trt = layer.get_output(0)
+
+# missing: @tensorrt_converter('torch.nn.parameter.Parameter')
def detect(save_img=False):
source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
@@ -32,6 +50,14 @@ def detect(save_img=False):
# Load model
model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
+
+ x = torch.ones((1, 3, imgsz, imgsz), device=device)
+ try:
+ model_trt = torch2trt(model, [x], fp16_mode=True)
+ except:
+ logging.exception('could not create tensorRT model')
+ exit()
+
if half:
model.half() # to FP16
@@ -59,6 +85,7 @@ def detect(save_img=False):
t0 = time.time()
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
+ _ = model_trt(img) if device.type != 'cpu' else None # run once
for path, img, im0s, vid_cap in dataset:
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
@@ -67,12 +94,17 @@ def detect(save_img=False):
img = img.unsqueeze(0)
# Inference
+ t_trt = time_synchronized()
+ pred = model_trt(img)[0]
+ print('trt pred (%.3fs)' % (time_synchronized() - t_trt))
+
t1 = time_synchronized()
pred = model(img, augment=opt.augment)[0]
+ t2 = time_synchronized()
+ print('torch pred. (%.3fs)' % (t2 - t1))
# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
- t2 = time_synchronized()
# Apply Classifier
if classify:
diff --git a/models/common.py b/models/common.py
index 0a1e7e3..d991fcf 100644
--- a/models/common.py
+++ b/models/common.py
@@ -41,7 +41,7 @@ class ImplicitA(nn.Module):
nn.init.normal_(self.implicit, std=.02)
def forward(self, x):
- return self.implicit.expand_as(x) + x
+ return self.implicit.expand(x.size()) + x
class ImplicitM(nn.Module):
@@ -52,7 +52,7 @@ class ImplicitM(nn.Module):
nn.init.normal_(self.implicit, mean=1., std=.02)
def forward(self, x):
- return self.implicit.expand_as(x) * x
+ return self.implicit.expand(x.size()) * x
class ReOrg(nn.Module):
@@ -236,7 +236,7 @@ class BottleneckCSPSE(nn.Module):
self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
def forward(self, x):
- x = x * self.cvsig(self.cs(self.avg_pool(x))).expand_as(x)
+ x = x * self.cvsig(self.cs(self.avg_pool(x))).expand(x.size())
y1 = self.cv3(self.m(self.cv1(x)))
y2 = self.cv2(x)
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
@@ -259,7 +259,7 @@ class BottleneckCSPSEA(nn.Module):
self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
def forward(self, x):
- x = x + x * self.cvsig(self.cs(self.avg_pool(x))).expand_as(x)
+ x = x + x * self.cvsig(self.cs(self.avg_pool(x))).expand(x.size())
y1 = self.cv3(self.m(self.cv1(x)))
y2 = self.cv2(x)
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
\```
</p>
</details>
When executing
```bash
python3 detect.py --weights model.pt --source /tmp/example_images --conf 0.25 --img-size 800 --device 0
```
I'm getting the same error.
Hi, have you solved this problem?
Environment
Ubuntu 20.04
Python 3.8.8
torch 1.8.1
torch2trt 0.2.0
cuda 11.1
cudnn 8.1.1
TensorRT 7.2.2
Issue
I'm trying to convert YOLOR (https://github.com/WongKinYiu/yolor), which is implemented in PyTorch, to TensorRT.
YOLOR uses two kinds of layers that torch2trt does not currently support:
torch.nn.functional.silu
torch.Tensor.expand_as
silu
Thanks to #527, this layer converts without problems.
expand_as
Thanks to #487, a converter for `torch.Tensor.expand` is provided. Since `torch.Tensor.expand(other.size())` is equivalent to `torch.Tensor.expand_as(other)` (https://pytorch.org/docs/stable/tensors.html), I replaced every `expand_as(other)` call with `expand(other.size())` in `yolor/utils/layers.py`.
I wrote a script for the conversion: `yolor/torch2trt_conversion.py`.
In the end, I got `AttributeError: 'Parameter' object has no attribute '_trt'`.
Here, the 'Parameter' object is a `torch.nn.parameter.Parameter`, which has the attribute `data` (a `torch.Tensor`). So I replaced `input` with `input.data`, but it did not work. What is the problem?
The text was updated successfully, but these errors were encountered: