diff --git a/tests/py/models/test_models.py b/tests/py/models/test_models.py index 6cc9759626..037e9f93d1 100644 --- a/tests/py/models/test_models.py +++ b/tests/py/models/test_models.py @@ -86,7 +86,7 @@ def test_efficientnet_b0(self): def test_bert_base_uncased(self): self.model = cm.BertModule().cuda() - self.input = torch.randint(0, 5, (1, 14), dtype=torch.int32).to("cuda") + self.input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") compile_spec = { "inputs": [