Cherry-pick Jetson enablement from 2.8 release branch to main #3765

Merged

21 changes: 10 additions & 11 deletions .github/workflows/build-test-linux-aarch64-jetpack.yml
@@ -1,17 +1,16 @@
name: Build and test Linux aarch64 wheels for Jetpack

on:
# TODO: Uncomment this when we have a stable release
# pull_request:
# push:
# branches:
# - main
# - nightly
# - release/*
# tags:
# # NOTE: Binary build pipelines should only get triggered on release candidate builds
# # Release candidate tags look like: v1.11.0-rc1
# - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
pull_request:
push:
branches:
- main
- nightly
- release/*
tags:
# NOTE: Binary build pipelines should only get triggered on release candidate builds
# Release candidate tags look like: v1.11.0-rc1
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
workflow_dispatch:

jobs:
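
The re-enabled triggers mean the JetPack wheel job now runs on pull requests, pushes to main/nightly/release branches, and release-candidate tags. As the NOTE in the workflow says, only tags like v1.11.0-rc1 should start the binary build pipeline; a quick sketch of which refs the tag filter accepts, approximated here as a plain regular expression (GitHub's own filter syntax is glob-like, but it behaves the same for these examples):

```python
import re

# Approximation of the workflow tag filter v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
tag_filter = re.compile(r"^v[0-9]+\.[0-9]+\.[0-9]+-rc[0-9]+$")

for ref in ["v2.8.0-rc1", "v2.8.0-rc12", "v2.8.0", "nightly"]:
    # Only the release-candidate tags should trigger the binary build pipeline.
    print(ref, bool(tag_filter.match(ref)))
```
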
6 changes: 3 additions & 3 deletions .github/workflows/build_wheels_linux_aarch64.yml
@@ -264,7 +264,7 @@ jobs:
if [[ ${{ inputs.is-jetpack }} == false ]]; then
${CONDA_RUN} python setup.py bdist_wheel
else
${CONDA_RUN} python setup.py bdist_wheel --jetpack --plat-name=linux_tegra_aarch64
${CONDA_RUN} python setup.py bdist_wheel --jetpack
fi
- name: Repair Manylinux_2_28 Wheel
shell: bash -l {0}
@@ -337,8 +337,8 @@ jobs:
needs: build
name: upload-wheel-${{ matrix.python_version }}-${{ matrix.desired_cuda }}-${{ matrix.gpu_arch_type }}-${{ inputs.is-jetpack }}
uses: pytorch/test-infra/.github/workflows/_binary_upload.yml@main
# for jetpack builds, only upload to pytorch index for nightly builds
if: ${{ inputs.is-jetpack == false || (github.event_name == 'push' && startsWith(github.event.ref, 'refs/heads/nightly')) }}
# for jetpack builds, do not upload to pytorch nightly index, only upload to https://pypi.jetson-ai-lab.io/ manually for each release
if: ${{ inputs.is-jetpack == false }}
with:
repository: ${{ inputs.repository }}
ref: ${{ inputs.ref }}
3 changes: 1 addition & 2 deletions MODULE.bazel
@@ -90,10 +90,9 @@ http_archive(
http_archive(
name = "torch_l4t",
build_file = "@//third_party/libtorch:BUILD",
sha256 = "6eff643c0a7acda92734cc798338f733ff35c7df1a4434576f5ff7c66fc97319",
strip_prefix = "torch",
type = "zip",
urls = ["https://pypi.jetson-ai-lab.dev/jp6/cu126/+f/6ef/f643c0a7acda9/torch-2.7.0-cp310-cp310-linux_aarch64.whl"],
urls = ["https://pypi.jetson-ai-lab.io/jp6/cu126/+f/62a/1beee9f2f1470/torch-2.8.0-cp310-cp310-linux_aarch64.whl"],
)

# Download these tarballs manually from the NVIDIA website
4 changes: 2 additions & 2 deletions docsrc/getting_started/jetpack.rst
@@ -90,14 +90,14 @@ Build Environment Setup
.. code-block:: sh

# Can only install the torch and torchvision wheel from the JPL repo which is built specifically for JetPack 6.2
python -m pip install torch==2.7.0 torchvision==0.22.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
python -m pip install torch==2.8.0 torchvision==0.23.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126


Building the Wheel
==================

.. code-block:: sh
python setup.py bdist_wheel
python setup.py bdist_wheel --jetpack

Installation
============
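
The documentation now points at torch 2.8.0 / torchvision 0.23.0 from the pypi.jetson-ai-lab.io JetPack 6.2 index and adds the `--jetpack` flag to the wheel build. A minimal sanity check you might run before building, assuming the wheel above is already installed (this check is illustrative and not part of the PR):

```python
import torch

# The Jetson build expects the JetPack torch 2.8.0 wheel installed above.
assert torch.__version__.startswith("2.8.0"), (
    f"expected the JetPack torch 2.8.0 wheel, found {torch.__version__}"
)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```
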
4 changes: 2 additions & 2 deletions packaging/pre_build_script.sh
@@ -42,8 +42,8 @@ curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.26.0/bazelis
pip uninstall -y torch torchvision

if [[ ${IS_JETPACK} == true ]]; then
# install torch 2.7 for jp6.2
pip install torch==2.7.0 --index-url=https://pypi.jetson-ai-lab.dev/jp6/cu126/
# install torch 2.8 for jp6.2
pip install torch==2.8.0 --index-url=https://pypi.jetson-ai-lab.io/jp6/cu126/
else
TORCH=$(grep "^torch>" py/requirements.txt)
INDEX_URL=https://download.pytorch.org/whl/${CHANNEL}/${CU_VERSION}
9 changes: 5 additions & 4 deletions py/torch_tensorrt/_enums.py
@@ -8,6 +8,7 @@
import tensorrt as trt
import torch
from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime
from torch_tensorrt._utils import is_tensorrt_version_supported


class dtype(Enum):
@@ -199,8 +200,6 @@ def _from(
return dtype.i8
elif t == trt.DataType.FP8:
return dtype.f8
elif t == trt.DataType.FP4:
return dtype.fp4
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.DataType.INT64:
@@ -214,6 +213,8 @@
elif t == trt.DataType.BF16:
return dtype.bf16
else:
if is_tensorrt_version_supported("10.8.0") and t == trt.DataType.FP4:
return dtype.fp4
raise TypeError(
f"Provided an unsupported data type as a data type for translation (support: bool, int, half, float, bfloat16), got: {t}"
)
@@ -409,11 +410,11 @@ def to(
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif self == dtype.f4:
return trt.DataType.FP4
elif use_default:
return trt.DataType.FLOAT
else:
if is_tensorrt_version_supported("10.8.0") and self == dtype.f4:
return trt.DataType.FP4
raise TypeError("Unsupported tensorrt dtype")

elif t == np.dtype:
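
The reordering above keeps every reference to `trt.DataType.FP4` behind the `is_tensorrt_version_supported("10.8.0")` check, presumably so the enum member is never touched on TensorRT builds that predate FP4 support, where the attribute lookup would fail. A minimal sketch of the pattern, with a hypothetical helper name:

```python
import tensorrt as trt

from torch_tensorrt._utils import is_tensorrt_version_supported


def fp4_trt_dtype_or_none():
    # Check the TensorRT version first so trt.DataType.FP4 is only referenced on
    # releases (>= 10.8.0) that define it; older builds would raise AttributeError.
    if is_tensorrt_version_supported("10.8.0"):
        return trt.DataType.FP4
    return None


print(fp4_trt_dtype_or_none())
```
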
23 changes: 23 additions & 0 deletions py/torch_tensorrt/_utils.py
@@ -24,3 +24,26 @@ def check_cross_compile_trt_win_lib() -> bool:
target_lib = ".*libnvinfer_builder_resource_win.so.*"
return any(re.match(target_lib, lib) for lib in loaded_libs)
return False


def is_tensorrt_version_supported(min_version: str = "10.8.0") -> bool:
"""
Check if the installed TensorRT version supports the specified minimum version.
Args:
min_version (str): Minimum required TensorRT version (default: "10.8.0" for FP4 support)
Returns:
bool: True if TensorRT version is >= min_version, False otherwise
Example:
>>> if is_tensorrt_version_supported("10.8.0"):
... # Use FP4 features
... pass
"""
try:
from importlib import metadata

from packaging.version import Version

return bool(Version(metadata.version("tensorrt")) >= Version(min_version))
except (ImportError, ValueError):
# If tensorrt is not installed or version cannot be determined
return False
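
For reference, `packaging.version.Version` compares release segments numerically, so the gate behaves as expected even when the installed tensorrt wheel carries extra patch segments; a small illustration (the version strings are made up):

```python
from packaging.version import Version

assert Version("10.12.0") >= Version("10.8.0")      # numeric, not lexicographic
assert Version("10.8.0.43") >= Version("10.8.0")    # extra build segment still passes
assert not Version("10.7.1") >= Version("10.8.0")   # older releases fail the gate
```
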
64 changes: 33 additions & 31 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -7,6 +7,7 @@
import numpy as np
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt._utils import is_tensorrt_version_supported
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
@@ -620,40 +621,41 @@ def aten_ops_quantize_op(
)


try:
import modelopt.torch.quantization as mtq # noqa: F401
if is_tensorrt_version_supported("10.8.0"):
try:
import modelopt.torch.quantization as mtq # noqa: F401

assert torch.ops.tensorrt.dynamic_block_quantize_op.default
except Exception as e:
_LOGGER.warning(
"Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
)
else:
assert torch.ops.tensorrt.dynamic_block_quantize_op.default
except Exception as e:
_LOGGER.warning(
"Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
)
else:

@dynamo_tensorrt_converter(
torch.ops.tensorrt.dynamic_block_quantize_op.default,
supports_dynamic_shapes=True,
)
def aten_ops_dynamic_block_quantize_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.dynamic_block_quantize.quantize(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
@dynamo_tensorrt_converter(
torch.ops.tensorrt.dynamic_block_quantize_op.default,
supports_dynamic_shapes=True,
)
def aten_ops_dynamic_block_quantize_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.dynamic_block_quantize.quantize(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
)


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
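
Wrapping the whole `try`/`except`/`else` in the version gate matters because the `@dynamo_tensorrt_converter` decorator registers the converter as a side effect at import time; on TensorRT older than 10.8.0 the FP4 dynamic-block-quantize converter should simply never be registered. A toy registry, purely illustrative, showing that import-time behaviour:

```python
CONVERTERS = {}


def register(op_name):
    # Stand-in for @dynamo_tensorrt_converter: registration happens the moment
    # the decorated function is defined, i.e. when the module is imported.
    def decorator(fn):
        CONVERTERS[op_name] = fn
        return fn
    return decorator


fp4_supported = False  # pretend this is is_tensorrt_version_supported("10.8.0")

if fp4_supported:
    @register("tensorrt.dynamic_block_quantize_op")
    def quantize_converter(*args):
        return "lowered"


print(CONVERTERS)  # stays empty when the gate is off, so no FP4 converter is ever claimed
```
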
52 changes: 28 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -32,7 +32,7 @@
ConverterRegistry,
DynamoConverterImplSignature,
)

from torch_tensorrt._utils import is_tensorrt_version_supported
from ..types import Shape, TRTDataType, TRTLayer, TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -448,31 +448,35 @@ def create_constant(
if torch_value is not None:

if torch_value.dtype == torch.uint8:
if (
target_quantized_type is None
or target_quantized_type != trt.DataType.FP4
):
# Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
if is_tensorrt_version_supported("10.8.0"):
if (
target_quantized_type is None
or target_quantized_type != trt.DataType.FP4
):
# Iconstant layer does not support Uint8, it only support that FP4 data packed in uint8
raise ValueError(
"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
)
shape[-1] = shape[-1] * 2
weights = to_trt_weights(
ctx,
torch_value,
name,
"CONSTANT",
"CONSTANT",
dtype=trt.DataType.FP4,
count=torch_value.numel() * 2,
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
return constant.get_output(0)
else:
raise ValueError(
"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
"Currently FP4 is only supported in TensorRT 10.8.0 and above"
)
shape[-1] = shape[-1] * 2
weights = to_trt_weights(
ctx,
torch_value,
name,
"CONSTANT",
"CONSTANT",
dtype=trt.DataType.FP4,
count=torch_value.numel() * 2,
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
return constant.get_output(0)

# Record the weight in ctx for refit and cpu memory reference

# Convert the torch.Tensor to a trt.Weights object
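
The uint8 branch above exists because FP4 weights arrive packed two 4-bit values per byte; that is why the constant's logical shape doubles its last dimension and the weight count passed to `to_trt_weights` is `numel() * 2`. A small numeric illustration (the shape is made up):

```python
import numpy as np

# Hypothetical packed FP4 weight buffer: each uint8 byte holds two 4-bit values.
packed = np.zeros((128, 32), dtype=np.uint8)

logical_shape = list(packed.shape)
logical_shape[-1] *= 2           # (128, 64) logical FP4 elements, as in the diff
fp4_count = packed.size * 2      # 4096 bytes -> 8192 FP4 values

print(logical_shape, fp4_count)
```
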
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -12,14 +12,14 @@
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)

if load_tensorrt_llm():
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(