@@ -1647,6 +1647,120 @@ def aten_ops_logical_xor(
16471647 )
16481648
16491649
1650+ def bitwise_type_validator (node : Node ) -> bool :
1651+ targets = [
1652+ torch .ops .aten .bitwise_and .Tensor ,
1653+ torch .ops .aten .bitwise_or .Tensor ,
1654+ torch .ops .aten .bitwise_xor .Tensor ,
1655+ ]
1656+ if node .target not in targets :
1657+ return False
1658+
1659+ lhs_val = node .args [0 ]
1660+ rhs_val = node .args [1 ]
1661+ lhs_meta = lhs_val .meta .get ("tensor_meta" )
1662+ rhs_meta = rhs_val .meta .get ("tensor_meta" )
1663+
1664+ if lhs_meta is None or rhs_meta is None :
1665+ return False
1666+
1667+ supported_type = [torch .bool , bool ]
1668+ return lhs_meta .dtype in supported_type and rhs_meta .dtype in supported_type
1669+
1670+
1671+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1672+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar ) # type: ignore[misc]
1673+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar_Tensor ) # type: ignore[misc]
1674+ def aten_ops_bitwise_and (
1675+ ctx : ConversionContext ,
1676+ target : Target ,
1677+ args : Tuple [Argument , ...],
1678+ kwargs : Dict [str , Argument ],
1679+ name : str ,
1680+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1681+ return impl .elementwise .bitwise_and (
1682+ ctx ,
1683+ target ,
1684+ SourceIR .ATEN ,
1685+ name ,
1686+ args [0 ],
1687+ args [1 ],
1688+ )
1689+
1690+
1691+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1692+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar ) # type: ignore[misc]
1693+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar_Tensor ) # type: ignore[misc]
1694+ def aten_ops_bitwise_or (
1695+ ctx : ConversionContext ,
1696+ target : Target ,
1697+ args : Tuple [Argument , ...],
1698+ kwargs : Dict [str , Argument ],
1699+ name : str ,
1700+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1701+ return impl .elementwise .bitwise_or (
1702+ ctx ,
1703+ target ,
1704+ SourceIR .ATEN ,
1705+ name ,
1706+ args [0 ],
1707+ args [1 ],
1708+ )
1709+
1710+
1711+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1712+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar ) # type: ignore[misc]
1713+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar_Tensor ) # type: ignore[misc]
1714+ def aten_ops_bitwise_xor (
1715+ ctx : ConversionContext ,
1716+ target : Target ,
1717+ args : Tuple [Argument , ...],
1718+ kwargs : Dict [str , Argument ],
1719+ name : str ,
1720+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1721+ return impl .elementwise .bitwise_xor (
1722+ ctx ,
1723+ target ,
1724+ SourceIR .ATEN ,
1725+ name ,
1726+ args [0 ],
1727+ args [1 ],
1728+ )
1729+
1730+
1731+ def bitwise_not_type_validator (node : Node ) -> bool :
1732+ val = node .args [0 ]
1733+ val_meta = val .meta .get ("tensor_meta" )
1734+
1735+ if val_meta is None :
1736+ return False
1737+
1738+ supported_type = [torch .bool , bool ]
1739+ return val_meta .dtype in supported_type
1740+
1741+
1742+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_not .default , capability_validator = bitwise_not_type_validator ) # type: ignore[misc]
1743+ @enforce_tensor_types (
1744+ {
1745+ 0 : (TRTTensor ,),
1746+ }
1747+ ) # type: ignore[misc]
1748+ def aten_ops_bitwise_not (
1749+ ctx : ConversionContext ,
1750+ target : Target ,
1751+ args : Tuple [Argument , ...],
1752+ kwargs : Dict [str , Argument ],
1753+ name : str ,
1754+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1755+ return impl .unary .bitwise_not (
1756+ ctx ,
1757+ target ,
1758+ SourceIR .ATEN ,
1759+ name ,
1760+ args [0 ],
1761+ )
1762+
1763+
16501764@dynamo_tensorrt_converter (torch .ops .aten .eq .Tensor ) # type: ignore[misc]
16511765@dynamo_tensorrt_converter (torch .ops .aten .eq .Scalar ) # type: ignore[misc]
16521766@enforce_tensor_types (
0 commit comments