Skip to content

Commit 1abc883

Browse files
committed
Revised the naming mechanism
1 parent 2520a68 commit 1abc883

File tree

4 files changed

+60
-63
lines changed

4 files changed

+60
-63
lines changed

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from dataclasses import dataclass, field
22
from typing import Union
33

4-
import numpy as np
54
import torch
65
from torch_tensorrt.dynamo._settings import CompilationSettings
76
from torch_tensorrt.dynamo.types import TRTNetwork
@@ -24,10 +23,14 @@ class ConversionContext:
2423
default_factory=CompilationSettings
2524
)
2625
requires_output_allocator: bool = False
27-
weight_refit_map: dict[str, np.array] = field(default_factory=dict)
26+
weight_refit_map: dict[str, torch.Tensor] = field(default_factory=dict)
2827
cpu_weights_reference_holder: dict[str, Union[torch.Tensor]] = field(
2928
default_factory=dict
3029
)
3130

31+
def record_weight(self, name: str, weight: torch.Tensor) -> None:
32+
self.weight_refit_map[name] = weight
33+
self.cpu_weights_reference_holder[name + " CPU_REFERENCE"] = weight
34+
3235
def clear_cpu_weights_reference_holder(self) -> None:
3336
self.cpu_weights_reference_holder.clear()

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,16 @@ def cast_int_or_float_to_bool(
321321

322322

323323
def to_trt_weights(
324-
value: Any,
325-
record_weight: bool = False,
326-
name: Optional[str] = None,
327-
ctx: Optional[ConversionContext] = None,
324+
ctx: ConversionContext,
325+
value: torch.Tensor,
326+
name: str,
327+
layer_type_name: str,
328+
weight_type_name: str,
328329
target: Optional[Union[Target, str]] = None,
329-
layer_type_name: Optional[str] = None,
330-
weight_type_name: Optional[str] = None,
331330
source_ir: Optional[SourceIR] = None,
332331
target_quantized_type: Optional[trt.DataType] = None,
332+
dtype: Optional[trt.DataType] = None,
333+
count: Optional[int] = None,
333334
) -> trt.Weights:
334335
"""
335336
Convert a PyTorch tensor or NumPy array to TensorRT weights.
@@ -344,57 +345,50 @@ def to_trt_weights(
344345
- Input tensors are made contiguous before conversion
345346
- Data type is preserved from the original tensor/array
346347
"""
347-
if record_weight:
348-
assert name is not None, "name must be provided if record_weight is True"
349-
assert ctx is not None, "ctx must be provided if record_weight is True"
350-
assert target is not None, "target must be provided if record_weight is True"
351-
assert (
352-
layer_type_name is not None
353-
), "layer_type_name must be provided if record_weight is True"
354-
assert (
355-
weight_type_name is not None
356-
), "weight_type_name must be provided if record_weight is True"
348+
if isinstance(value, np.ndarray):
349+
raise AssertionError(
350+
f"to_trt_weights can only be called on torch.Tensor, got an object of type: {type(value)}"
351+
)
357352

358-
supported_layer_types = ["CONVOLUTION", "DECONVOLUTION"]
359-
supported_weight_types = ["KERNEL"]
360-
assert (
361-
layer_type_name in supported_layer_types
362-
), f"Unsupported layer type: {layer_type_name}. Please add the layer type to this function to enable refitting."
353+
# Weight Recording
354+
supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT"]
355+
supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
356+
assert (
357+
layer_type_name in supported_layer_types
358+
), f"Unsupported layer type: {layer_type_name}. Please add the layer type to this function to enable refitting."
359+
assert (
360+
weight_type_name in supported_weight_types
361+
), f"Unsupported weight type: {weight_type_name}. Please add the weight type to this function to enable refitting."
362+
363+
if weight_type_name == "CONSTANT" and layer_type_name == "CONSTANT":
364+
weight_name = f"{name} CONSTANT"
365+
ctx.record_weight(weight_name, value)
366+
367+
else:
363368
assert (
364-
weight_type_name in supported_weight_types
365-
), f"Unsupported weight type: {weight_type_name}. Please add the weight type to this function to enable refitting."
369+
target is not None
370+
), "target must be provided if the weight type and layer type is not CONSTANT"
366371
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
367372
target_name = (
368373
f"{source_ir}_ops.{target}"
369374
if isinstance(target, str)
370375
else f"{source_ir}_ops.{target.__name__}"
371376
)
372377

373-
name = f"[{layer_type_name}]-[{target_name}]-[{name}] {weight_type_name}"
374-
record_weight_in_ctx(ctx, name, value)
375-
376-
if isinstance(value, torch.Tensor):
377-
# Tensor must be contiguous before conversion
378-
value = value.contiguous()
379-
value_trt_dtype = _enums.dtype._from(value.dtype).to(trt.DataType)
380-
return trt.Weights(value_trt_dtype, value.data_ptr(), value.nelement())
381-
elif isinstance(value, np.ndarray):
382-
value = np.ascontiguousarray(value)
383-
value_np_dtype = _enums.dtype._from(value.dtype).to(np.dtype, use_default=True)
384-
return trt.Weights(value_np_dtype, value.data, value.size)
385-
else:
386-
raise AssertionError(
387-
f"to_trt_weights can only be called on torch.Tensor or np.ndarray, got an object of type: {type(value)}"
388-
)
378+
weight_name = f"[{layer_type_name}]-[{target_name}]-[{name}] {weight_type_name}"
379+
ctx.record_weight(weight_name, value)
389380

381+
# TRT Weights Creation
390382

391-
def record_weight_in_ctx(
392-
ctx: ConversionContext,
393-
name: str,
394-
value: torch.Tensor,
395-
) -> None:
396-
ctx.weight_refit_map[name] = value
397-
ctx.cpu_weights_reference_holder[name] = value
383+
# Tensor must be contiguous before conversion
384+
value = value.contiguous()
385+
if dtype is None:
386+
dtype = _enums.dtype._from(value.dtype).to(trt.DataType)
387+
388+
if count is None:
389+
count = value.nelement()
390+
391+
return trt.Weights(dtype, value.data_ptr(), count)
398392

399393

400394
def create_constant(
@@ -451,24 +445,26 @@ def create_constant(
451445
"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
452446
)
453447
shape[-1] = shape[-1] * 2
454-
weights = trt.Weights(
455-
type=trt.DataType.FP4,
456-
ptr=torch_value.data_ptr(),
448+
weights = to_trt_weights(
449+
ctx,
450+
torch_value,
451+
name,
452+
"CONSTANT",
453+
"CONSTANT",
454+
dtype=trt.DataType.FP4,
457455
count=torch_value.numel() * 2,
458456
)
459457
constant = ctx.net.add_constant(
460458
shape,
461459
weights,
462460
)
463461
constant.name = name
464-
record_weight_in_ctx(ctx, name + " FP4_CONSTANT", torch_value)
465462
return constant.get_output(0)
466463

467464
# Record the weight in ctx for refit and cpu memory reference
468-
record_weight_in_ctx(ctx, name + " CONSTANT", torch_value)
469465

470466
# Convert the torch.Tensor to a trt.Weights object
471-
trt_weights = to_trt_weights(torch_value, record_weight=False)
467+
trt_weights = to_trt_weights(ctx, torch_value, name, "CONSTANT", "CONSTANT")
472468
constant = ctx.net.add_constant(
473469
shape,
474470
trt_weights,

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,12 @@ def convNd(
8686
num_output_maps = weight.shape[0]
8787
kernel_shape = weight.shape[2:]
8888
weight = to_trt_weights(
89+
ctx,
8990
weight,
90-
record_weight=True,
91-
name=name,
92-
ctx=ctx,
93-
target=target,
91+
name,
9492
layer_type_name="CONVOLUTION",
9593
weight_type_name="KERNEL",
94+
target=target,
9695
source_ir=source_ir,
9796
)
9897

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,12 @@ def deconvNd(
8686
num_output_maps = weight.shape[1]
8787
kernel_shape = weight.shape[2:]
8888
weight = to_trt_weights(
89+
ctx,
8990
weight,
90-
record_weight=True,
91-
name=name,
92-
ctx=ctx,
93-
target=target,
94-
layer_type_name="DECONVOLUTION",
91+
name,
92+
layer_type_name="CONVOLUTION",
9593
weight_type_name="KERNEL",
94+
target=target,
9695
source_ir=source_ir,
9796
)
9897

0 commit comments

Comments
 (0)