Skip to content

Commit

Permalink
fix rife trt
Browse files Browse the repository at this point in the history
Former-commit-id: 712631c
Former-commit-id: e27f23d
  • Loading branch information
TNTwise committed Nov 10, 2024
1 parent 7a0e80a commit d20fb33
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions backend/src/TensorRTHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def export_dynamo_model(
):
"""Exports a model using TensorRT Dynamo."""
model.to(device=device, dtype=dtype)
example_inputs = [input.to(device=device,dtype=dtype) for input in example_inputs]
exported_program = torch.export.export(
model, tuple(example_inputs), dynamic_shapes=None
)
Expand Down Expand Up @@ -108,8 +107,7 @@ def export_torchscript_model(

# maybe try to load it onto CUDA, and clear pytorch cache after.
model.to(device=device,dtype=dtype)
example_inputs = [input.to(device=device,dtype=dtype) for input in example_inputs]
module = torch.jit.trace(model, example_inputs) # have to put both on same device or sum
module = torch.jit.trace(model, example_inputs)
torch.cuda.empty_cache()
model = None

Expand Down

0 comments on commit d20fb33

Please sign in to comment.