Skip to content

Commit 08152df

Browse files
committed
Revised according to comments
1 parent 1abc883 commit 08152df

File tree

4 files changed

+30
-5
lines changed

4 files changed

+30
-5
lines changed

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ class ConversionContext:
2929
)
3030

3131
def record_weight(self, name: str, weight: torch.Tensor) -> None:
32+
"""
33+
Record the weight and name for refitting and CPU reference.
34+
For the refit map, the key is the weight name that appears in the TRT engine and the value is the weight tensor.
35+
For the CPU reference holder, we need to hold the reference to the weight tensor until the whole compilation process is complete.
36+
37+
Args:
38+
name: Name of the weight
39+
weight: Weight to record
40+
"""
3241
self.weight_refit_map[name] = weight
3342
self.cpu_weights_reference_holder[name + " CPU_REFERENCE"] = weight
3443

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def to_trt_weights(
345345
- Input tensors are made contiguous before conversion
346346
- Data type is preserved from the original tensor/array
347347
"""
348-
if isinstance(value, np.ndarray):
348+
if not isinstance(value, torch.Tensor):
349349
raise AssertionError(
350350
f"to_trt_weights can only be called on torch.Tensor, got an object of type: {type(value)}"
351351
)
@@ -355,10 +355,10 @@ def to_trt_weights(
355355
supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
356356
assert (
357357
layer_type_name in supported_layer_types
358-
), f"Unsupported layer type: {layer_type_name}. Please add the layer type to this function to enable refitting."
358+
), f"Encountered unsupported layer type: {layer_type_name}. Supported types are: {supported_layer_types}. Manually calling to_trt_weights with a custom layer type is not intended for general use."
359359
assert (
360360
weight_type_name in supported_weight_types
361-
), f"Unsupported weight type: {weight_type_name}. Please add the weight type to this function to enable refitting."
361+
), f"Encountered unsupported weight type: {weight_type_name}. Supported types are: {supported_weight_types}. Manually calling to_trt_weights with a custom weight type is not intended for general use."
362362

363363
if weight_type_name == "CONSTANT" and layer_type_name == "CONSTANT":
364364
weight_name = f"{name} CONSTANT"

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@ def convNd(
5555
# Process bias terms
5656
if isinstance(bias, (torch.Tensor, np.ndarray)):
5757
bias = to_torch(bias, dtype=input.dtype)
58-
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
58+
bias = to_trt_weights(
59+
ctx,
60+
bias,
61+
name,
62+
layer_type_name="CONVOLUTION",
63+
weight_type_name="BIAS",
64+
target=target,
65+
source_ir=source_ir,
66+
)
5967

6068
elif isinstance(bias, TRTTensor):
6169
bias = get_trt_tensor(ctx, bias, f"{name}_bias")

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,15 @@ def deconvNd(
5555
if isinstance(bias, (torch.Tensor, np.ndarray)):
5656
# Transform the bias constant into a Numpy array
5757
bias = to_torch(bias, dtype=input.dtype)
58-
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
58+
bias = to_trt_weights(
59+
ctx,
60+
bias,
61+
name,
62+
layer_type_name="CONVOLUTION",
63+
weight_type_name="BIAS",
64+
target=target,
65+
source_ir=source_ir,
66+
)
5967

6068
elif isinstance(bias, TRTTensor):
6169
bias = get_trt_tensor(ctx, bias, f"{name}_bias")

0 commit comments

Comments
 (0)