@@ -754,12 +754,12 @@ def aten_ops_cumsum(
754754 )
755755
756756
757- @dynamo_tensorrt_converter (torch .ops .aten .tile .default ) # type: ignore[misc]
757+ @dynamo_tensorrt_converter (torch .ops .aten .tile .default )
758758@enforce_tensor_types (
759759 {
760760 0 : (TRTTensor ,),
761761 }
762- ) # type: ignore[misc]
762+ )
763763def aten_ops_tile (
764764 ctx : ConversionContext ,
765765 target : Target ,
@@ -777,7 +777,7 @@ def aten_ops_tile(
777777 )
778778
779779
780- @dynamo_tensorrt_converter (torch .ops .aten .permute .default ) # type: ignore[misc]
780+ @dynamo_tensorrt_converter (torch .ops .aten .permute .default )
781781@enforce_tensor_types (
782782 {
783783 0 : (TRTTensor ,),
@@ -1702,29 +1702,63 @@ def aten_ops_logical_xor(
17021702
17031703
def bitwise_type_validator(node: Node) -> bool:
    """Capability validator for aten bitwise_{and,or,xor} conversion.

    TensorRT only supports elementwise bitwise ops on boolean operands, so a
    node is accepted only when every operand is a bool tensor (proven via the
    traced ``tensor_meta`` on the arg node) or a Python ``bool`` scalar.

    Args:
        node: FX node whose target and args are being checked.

    Returns:
        True iff ``node.target`` is a supported bitwise overload and both
        operands are boolean. Returns False (never raises) for unsupported
        targets, malformed arg lists, non-Node args in tensor positions, or
        missing dtype metadata, so such nodes fall back to PyTorch execution.
    """
    supported_type = [torch.bool, bool]

    # Table of supported overloads -> which operand positions hold tensors:
    # (lhs_is_tensor, rhs_is_tensor). Replaces three duplicated branches.
    operand_kinds = {}
    for op in (
        torch.ops.aten.bitwise_and,
        torch.ops.aten.bitwise_or,
        torch.ops.aten.bitwise_xor,
    ):
        operand_kinds[op.Tensor] = (True, True)
        operand_kinds[op.Scalar] = (True, False)
        operand_kinds[op.Scalar_Tensor] = (False, True)

    kinds = operand_kinds.get(node.target)
    # Guard malformed nodes: previously a short args list raised IndexError.
    if kinds is None or len(node.args) < 2:
        return False

    def _operand_ok(arg: object, is_tensor: bool) -> bool:
        # One check for either operand position.
        if is_tensor:
            # getattr guards non-Node args (e.g. folded constants), which
            # previously raised AttributeError instead of rejecting the node.
            meta = getattr(arg, "meta", {}).get("tensor_meta")
            return meta is not None and meta.dtype in supported_type
        return isinstance(arg, bool)

    return all(_operand_ok(a, k) for a, k in zip(node.args[:2], kinds))
17231750
17241751
1725- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1726- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar ) # type: ignore[misc]
1727- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar_Tensor ) # type: ignore[misc]
1752+ @dynamo_tensorrt_converter (
1753+ torch .ops .aten .bitwise_and .Tensor , capability_validator = bitwise_type_validator
1754+ )
1755+ @dynamo_tensorrt_converter (
1756+ torch .ops .aten .bitwise_and .Scalar , capability_validator = bitwise_type_validator
1757+ )
1758+ @dynamo_tensorrt_converter (
1759+ torch .ops .aten .bitwise_and .Scalar_Tensor ,
1760+ capability_validator = bitwise_type_validator ,
1761+ )
17281762def aten_ops_bitwise_and (
17291763 ctx : ConversionContext ,
17301764 target : Target ,
@@ -1742,9 +1776,15 @@ def aten_ops_bitwise_and(
17421776 )
17431777
17441778
1745- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1746- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar ) # type: ignore[misc]
1747- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar_Tensor ) # type: ignore[misc]
1779+ @dynamo_tensorrt_converter (
1780+ torch .ops .aten .bitwise_or .Tensor , capability_validator = bitwise_type_validator
1781+ )
1782+ @dynamo_tensorrt_converter (
1783+ torch .ops .aten .bitwise_or .Scalar , capability_validator = bitwise_type_validator
1784+ )
1785+ @dynamo_tensorrt_converter (
1786+ torch .ops .aten .bitwise_or .Scalar_Tensor , capability_validator = bitwise_type_validator
1787+ )
17481788def aten_ops_bitwise_or (
17491789 ctx : ConversionContext ,
17501790 target : Target ,
@@ -1762,9 +1802,16 @@ def aten_ops_bitwise_or(
17621802 )
17631803
17641804
1765- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1766- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar ) # type: ignore[misc]
1767- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar_Tensor ) # type: ignore[misc]
1805+ @dynamo_tensorrt_converter (
1806+ torch .ops .aten .bitwise_xor .Tensor , capability_validator = bitwise_type_validator
1807+ )
1808+ @dynamo_tensorrt_converter (
1809+ torch .ops .aten .bitwise_xor .Scalar , capability_validator = bitwise_type_validator
1810+ )
1811+ @dynamo_tensorrt_converter (
1812+ torch .ops .aten .bitwise_xor .Scalar_Tensor ,
1813+ capability_validator = bitwise_type_validator ,
1814+ )
17681815def aten_ops_bitwise_xor (
17691816 ctx : ConversionContext ,
17701817 target : Target ,
@@ -1793,12 +1840,14 @@ def bitwise_not_type_validator(node: Node) -> bool:
17931840 return val_meta .dtype in supported_type
17941841
17951842
1796- @dynamo_tensorrt_converter (torch .ops .aten .bitwise_not .default , capability_validator = bitwise_not_type_validator ) # type: ignore[misc]
1843+ @dynamo_tensorrt_converter (
1844+ torch .ops .aten .bitwise_not .default , capability_validator = bitwise_not_type_validator
1845+ )
17971846@enforce_tensor_types (
17981847 {
17991848 0 : (TRTTensor ,),
18001849 }
1801- ) # type: ignore[misc]
1850+ )
18021851def aten_ops_bitwise_not (
18031852 ctx : ConversionContext ,
18041853 target : Target ,
@@ -1815,13 +1864,13 @@ def aten_ops_bitwise_not(
18151864 )
18161865
18171866
1818- @dynamo_tensorrt_converter (torch .ops .aten .eq .Tensor ) # type: ignore[misc]
1819- @dynamo_tensorrt_converter (torch .ops .aten .eq .Scalar ) # type: ignore[misc]
1867+ @dynamo_tensorrt_converter (torch .ops .aten .eq .Tensor )
1868+ @dynamo_tensorrt_converter (torch .ops .aten .eq .Scalar )
18201869@enforce_tensor_types (
18211870 {
18221871 0 : (TRTTensor ,),
18231872 }
1824- ) # type: ignore[misc]
1873+ )
18251874def aten_ops_eq (
18261875 ctx : ConversionContext ,
18271876 target : Target ,
@@ -1839,13 +1888,13 @@ def aten_ops_eq(
18391888 )
18401889
18411890
1842- @dynamo_tensorrt_converter (torch .ops .aten .ne .Tensor ) # type: ignore[misc]
1843- @dynamo_tensorrt_converter (torch .ops .aten .ne .Scalar ) # type: ignore[misc]
1891+ @dynamo_tensorrt_converter (torch .ops .aten .ne .Tensor )
1892+ @dynamo_tensorrt_converter (torch .ops .aten .ne .Scalar )
18441893@enforce_tensor_types (
18451894 {
18461895 0 : (TRTTensor ,),
18471896 }
1848- ) # type: ignore[misc]
1897+ )
18491898def aten_ops_ne (
18501899 ctx : ConversionContext ,
18511900 target : Target ,
@@ -1863,13 +1912,13 @@ def aten_ops_ne(
18631912 )
18641913
18651914
1866- @dynamo_tensorrt_converter (torch .ops .aten .gt .Tensor ) # type: ignore[misc]
1867- @dynamo_tensorrt_converter (torch .ops .aten .gt .Scalar ) # type: ignore[misc]
1915+ @dynamo_tensorrt_converter (torch .ops .aten .gt .Tensor )
1916+ @dynamo_tensorrt_converter (torch .ops .aten .gt .Scalar )
18681917@enforce_tensor_types (
18691918 {
18701919 0 : (TRTTensor ,),
18711920 }
1872- ) # type: ignore[misc]
1921+ )
18731922def aten_ops_gt (
18741923 ctx : ConversionContext ,
18751924 target : Target ,
@@ -1887,13 +1936,13 @@ def aten_ops_gt(
18871936 )
18881937
18891938
1890- @dynamo_tensorrt_converter (torch .ops .aten .ge .Tensor ) # type: ignore[misc]
1891- @dynamo_tensorrt_converter (torch .ops .aten .ge .Scalar ) # type: ignore[misc]
1939+ @dynamo_tensorrt_converter (torch .ops .aten .ge .Tensor )
1940+ @dynamo_tensorrt_converter (torch .ops .aten .ge .Scalar )
18921941@enforce_tensor_types (
18931942 {
18941943 0 : (TRTTensor ,),
18951944 }
1896- ) # type: ignore[misc]
1945+ )
18971946def aten_ops_ge (
18981947 ctx : ConversionContext ,
18991948 target : Target ,
@@ -1911,13 +1960,13 @@ def aten_ops_ge(
19111960 )
19121961
19131962
1914- @dynamo_tensorrt_converter (torch .ops .aten .lt .Tensor ) # type: ignore[misc]
1915- @dynamo_tensorrt_converter (torch .ops .aten .lt .Scalar ) # type: ignore[misc]
1963+ @dynamo_tensorrt_converter (torch .ops .aten .lt .Tensor )
1964+ @dynamo_tensorrt_converter (torch .ops .aten .lt .Scalar )
19161965@enforce_tensor_types (
19171966 {
19181967 0 : (TRTTensor ,),
19191968 }
1920- ) # type: ignore[misc]
1969+ )
19211970def aten_ops_lt (
19221971 ctx : ConversionContext ,
19231972 target : Target ,
@@ -1935,13 +1984,13 @@ def aten_ops_lt(
19351984 )
19361985
19371986
1938- @dynamo_tensorrt_converter (torch .ops .aten .le .Tensor ) # type: ignore[misc]
1939- @dynamo_tensorrt_converter (torch .ops .aten .le .Scalar ) # type: ignore[misc]
1987+ @dynamo_tensorrt_converter (torch .ops .aten .le .Tensor )
1988+ @dynamo_tensorrt_converter (torch .ops .aten .le .Scalar )
19401989@enforce_tensor_types (
19411990 {
19421991 0 : (TRTTensor ,),
19431992 }
1944- ) # type: ignore[misc]
1993+ )
19451994def aten_ops_le (
19461995 ctx : ConversionContext ,
19471996 target : Target ,
@@ -2191,14 +2240,14 @@ def aten_ops_argmax(
21912240 )
21922241
21932242
2194- @dynamo_tensorrt_converter (torch .ops .aten .addmm .default ) # type: ignore[misc]
2243+ @dynamo_tensorrt_converter (torch .ops .aten .addmm .default )
21952244@enforce_tensor_types (
21962245 {
21972246 0 : (TRTTensor ,),
21982247 1 : (np .ndarray , torch .Tensor , TRTTensor ),
21992248 2 : (np .ndarray , torch .Tensor , TRTTensor ),
22002249 }
2201- ) # type: ignore[misc]
2250+ )
22022251def aten_ops_addmm (
22032252 ctx : ConversionContext ,
22042253 target : Target ,
0 commit comments