Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,18 +277,18 @@ def _populate_trt_builder_config(
trt.MemoryPoolType.DLA_GLOBAL_DRAM,
self.compilation_settings.dla_global_dram_size,
)
if not self.compilation_settings.use_explicit_typing:
if dtype.float16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.FP16)

if dtype.float16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.FP16)
if dtype.int8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.INT8)

if dtype.int8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.INT8)
if dtype.fp8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.FP8)

if dtype.fp8 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.FP8)

if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.BF16)
if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
builder_config.set_flag(trt.BuilderFlag.BF16)

if self.compilation_settings.sparse_weights:
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
Expand Down
10 changes: 5 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,28 @@ def batch_norm(
):
# We name the weight here according to the state_dict name
weight = (
get_trt_tensor(ctx, 1.0, f"{name}_weight")
get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
if weight is None
else get_trt_tensor(ctx, weight, f"{name}_weight")
)
bias = (
get_trt_tensor(ctx, 0.0, f"{name}_bias")
get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
if bias is None
else get_trt_tensor(ctx, bias, f"{name}_bias")
)
running_mean = (
get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
if running_mean is None
else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
)
running_var = (
get_trt_tensor(ctx, 1.0, f"{name}_running_var")
get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
if running_var is None
else get_trt_tensor(ctx, running_var, f"{name}_running_var")
)

# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)

# adjusted_var = running_var + eps
adjusted_var = impl.elementwise.add(
Expand Down
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,9 @@ def scaled_dot_product_attention_decomposition(
attn_weight = query @ key.transpose(-2, -1)

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)).to(
query.dtype
)
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale
Expand Down
22 changes: 13 additions & 9 deletions tests/py/dynamo/models/test_dtype_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def forward(self, x):
use_python_runtime=False,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)

torch_model_results = mod(in_tensor)
Expand Down Expand Up @@ -82,12 +83,13 @@ def forward(self, x):
use_python_runtime=True,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)

torch_model_results = mod(in_tensor)
with torch_tensorrt.logging.debug():
optimized_model_results = trt_mod(in_tensor)

assert torch_model_results.dtype == optimized_model_results.dtype
max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
Expand Down Expand Up @@ -128,11 +130,12 @@ def forward(self, x):
use_python_runtime=False,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

assert torch_model_results.dtype == optimized_model_results.dtype
max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
Expand Down Expand Up @@ -169,11 +172,12 @@ def forward(self, x):
use_python_runtime=True,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

assert torch_model_results.dtype == optimized_model_results.dtype
max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
Expand Down Expand Up @@ -218,16 +222,16 @@ def forward(self, x):
exp_mod,
inputs=[in_tensor],
pass_through_build_failures=True,
enabled_precisions={torch.float, torch.bfloat16, torch.half},
min_block_size=1,
use_python_runtime=False,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

assert torch_model_results.dtype == optimized_model_results.dtype
max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
Expand Down Expand Up @@ -258,16 +262,16 @@ def forward(self, x):
exp_mod,
inputs=[in_tensor],
pass_through_build_failures=True,
enabled_precisions={torch.float, torch.bfloat16, torch.half},
min_block_size=1,
use_python_runtime=True,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

assert torch_model_results.dtype == optimized_model_results.dtype
max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
Expand Down Expand Up @@ -296,16 +300,16 @@ def forward(self, x):
mod,
ir="torch_compile",
inputs=inputs,
enabled_precisions={torch.bfloat16},
min_block_size=1,
device=device,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)

torch_model_results = mod(*inputs)
optimized_model_results = trt_mod(*inputs)

assert torch_model_results.dtype == optimized_model_results.dtype
max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
Expand Down
19 changes: 12 additions & 7 deletions tests/py/dynamo/models/test_dyn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,26 +178,27 @@ def forward(self, x):
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
@pytest.mark.unit
def test_resnet_dynamic(ir):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_resnet_dynamic(ir, dtype):
"""
Tests the Resnet18 model (which is fully convertible) with dynamic shapes
"""
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
model = models.resnet18(pretrained=True).eval().to("cuda").to(dtype)

compile_spec = {
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
"min_block_size": 1,
"cache_built_engines": False,
"reuse_cached_engines": False,
"use_explicit_typing": True,
}

if ir == "torch_compile":
input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda")
input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda").to(dtype)
torch._dynamo.mark_dynamic(input_bs2, 0, min=1, max=8)
# Compile the model
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
Expand All @@ -208,14 +209,18 @@ def test_resnet_dynamic(ir):
min_shape=(1, 3, 224, 224),
opt_shape=(4, 3, 224, 224),
max_shape=(8, 3, 224, 224),
dtype=torch.float32,
dtype=dtype,
name="x",
)
]
trt_model = torchtrt.compile(model, **compile_spec)

input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda").to(dtype)
pyt_output = model(input_bs6)
trt_output = trt_model(input_bs6)
assert pyt_output.dtype == trt_output.dtype
assert trt_output.dtype == dtype
cos_sim = cosine_similarity(pyt_output, trt_output)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
Expand Down
52 changes: 33 additions & 19 deletions tests/py/dynamo/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,28 +136,31 @@ def test_resnet18_torch_exec_ops(ir):
not importlib.util.find_spec("torchvision"),
"torchvision is not installed",
)
def test_mobilenet_v2(ir):
model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_mobilenet_v2(ir, dtype):
model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
],
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"min_block_size": 10,
"cache_built_engines": False,
"reuse_cached_engines": False,
"use_explicit_typing": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
pyt_output = model(input)
trt_output = trt_mod(input)
assert pyt_output.dtype == trt_output.dtype
assert pyt_output.dtype == dtype
cos_sim = cosine_similarity(pyt_output, trt_output)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
Expand All @@ -172,28 +175,36 @@ def test_mobilenet_v2(ir):
not importlib.util.find_spec("timm") or not importlib.util.find_spec("torchvision"),
"timm or torchvision not installed",
)
def test_efficientnet_b0(ir):
model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_efficientnet_b0(ir, dtype):
model = (
timm.create_model("efficientnet_b0", pretrained=True)
.eval()
.to("cuda")
.to(dtype)
)
input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)

compile_spec = {
"inputs": [
torchtrt.Input(
input.shape, dtype=torch.float, format=torch.contiguous_format
)
torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
],
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"min_block_size": 10,
"cache_built_engines": False,
"reuse_cached_engines": False,
"use_explicit_typing": True,
}

trt_mod = torchtrt.compile(model, **compile_spec)
cos_sim = cosine_similarity(model(input), trt_mod(input))
pyt_output = model(input)
trt_output = trt_mod(input)
assert pyt_output.dtype == trt_output.dtype
assert pyt_output.dtype == dtype
cos_sim = cosine_similarity(pyt_output, trt_output)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
Expand All @@ -208,10 +219,11 @@ def test_efficientnet_b0(ir):
not importlib.util.find_spec("transformers"),
"transformers is required to run this test",
)
def test_bert_base_uncased(ir):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_bert_base_uncased(ir, dtype):
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)
input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")

Expand All @@ -229,21 +241,23 @@ def test_bert_base_uncased(ir):
),
],
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"truncate_double": True,
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"min_block_size": 15,
"cache_built_engines": False,
"reuse_cached_engines": False,
"use_explicit_typing": True,
}
trt_mod = torchtrt.compile(model, **compile_spec)

model_outputs = model(input, input2)
trt_model_outputs = trt_mod(input, input2)
for key in model_outputs.keys():
out, trt_out = model_outputs[key], trt_model_outputs[key]
assert out.dtype == trt_out.dtype
assert out.dtype == dtype
cos_sim = cosine_similarity(out, trt_out)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
Expand Down
Loading
Loading