Skip to content

Commit

Permalink
Fix tracing of generator in BERT model on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Apr 23, 2024
1 parent d14833f commit a39801f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/py/dynamo/models/test_models_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import pytest
import timm
import torch
import torch_tensorrt as torchtrt
import torchvision.models as models
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
from transformers import BertModel
from transformers.utils.fx import symbolic_trace as transformers_trace

import torch_tensorrt as torchtrt

assertions = unittest.TestCase()


Expand Down Expand Up @@ -108,7 +109,9 @@ def test_efficientnet_b0(ir):

@pytest.mark.unit
def test_bert_base_uncased(ir):
model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
model = (
BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval()
)
input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")

Expand Down Expand Up @@ -139,8 +142,8 @@ def test_bert_base_uncased(ir):
msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
)

for key, _ in model_outputs.items():
out, trt_out = model_outputs[key], trt_model_outputs[key]
for index in range(len(model_outputs)):
out, trt_out = model_outputs[index], trt_model_outputs[index]
cos_sim = cosine_similarity(out, trt_out)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
Expand Down

0 comments on commit a39801f

Please sign in to comment.