diff --git a/core/conversion/converters/BUILD b/core/conversion/converters/BUILD
index f9948eea7d..ff28a4a892 100755
--- a/core/conversion/converters/BUILD
+++ b/core/conversion/converters/BUILD
@@ -54,6 +54,7 @@ cc_library(
         "NodeConverterRegistry.cpp",
         "impl/activation.cpp",
         "impl/batch_norm.cpp",
+        "impl/bitwise.cpp",
         "impl/cast.cpp",
         "impl/concat.cpp",
         "impl/constant.cpp",
diff --git a/core/conversion/converters/impl/bitwise.cpp b/core/conversion/converters/impl/bitwise.cpp
new file mode 100644
index 0000000000..992c11bdf7
--- /dev/null
+++ b/core/conversion/converters/impl/bitwise.cpp
@@ -0,0 +1,55 @@
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+#include <torch/torch.h>
+
+namespace torch_tensorrt {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+
+auto bitwise_not_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::bitwise_not(Tensor self) -> Tensor",
+     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       auto in = args[0].ITensorOrFreeze(ctx);
+       nvinfer1::ILayer* out;
+
+       if (in->getType() == nvinfer1::DataType::kINT32) {
+         // Integer case, using ~x = -x - 1
+         auto neg_one = torch::tensor({-1}, util::TRTDataTypeToScalarType(in->getType()));
+         auto neg_one_const = tensor_to_const(ctx, neg_one);
+         auto neg = add_elementwise(
+             ctx,
+             nvinfer1::ElementWiseOperation::kPROD,
+             in,
+             neg_one_const,
+             util::node_info(n) + std::string("_Negation"));
+         TORCHTRT_CHECK(neg, "Unable to create prod layer from node: " << *n);
+         out = add_elementwise(
+             ctx,
+             nvinfer1::ElementWiseOperation::kSUM,
+             neg->getOutput(0),
+             neg_one_const,
+             util::node_info(n) + std::string("_SubOne"));
+         TORCHTRT_CHECK(out, "Unable to create sum layer from node: " << *n);
+       } else if (in->getType() == nvinfer1::DataType::kBOOL) {
+         // Boolean case
+         out = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
+         TORCHTRT_CHECK(out, "Unable to create logical not layer from node: " << *n);
+       } else {
+         LOG_ERROR("Input tensor must be 32 bit integer or boolean");
+         return false;
+       }
+
+       out->setName(util::node_info(n).c_str());
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out->getOutput(0));
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+       return true;
+     }});
+
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/core/lowering/register_trt_placeholder_ops.cpp b/core/lowering/register_trt_placeholder_ops.cpp
index 5ba8171208..17d7d3f47a 100644
--- a/core/lowering/register_trt_placeholder_ops.cpp
+++ b/core/lowering/register_trt_placeholder_ops.cpp
@@ -10,7 +10,10 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
 RegisterOperators trt_placeholder_ops_reg({
     /// Op marks a Tensor to be conveted from an Torch Tensor
     /// to a TRT constant Tensor
-    Operator("trt::const(Tensor val) -> Tensor", [](Stack& stack) { /*noop*/ }, aliasAnalysisFromSchema()),
+    Operator(
+        "trt::const(Tensor val) -> Tensor",
+        [](Stack& stack) { /*noop*/ },
+        aliasAnalysisFromSchema()),
 });
 
 } // namespace jit
diff --git a/tests/core/conversion/converters/BUILD b/tests/core/conversion/converters/BUILD
index 3dc7865b9e..5843acae75 100644
--- a/tests/core/conversion/converters/BUILD
+++ b/tests/core/conversion/converters/BUILD
@@ -15,6 +15,10 @@ converter_test(
     name = "test_batch_norm",
 )
 
+converter_test(
+    name = "test_bitwise",
+)
+
 converter_test(
     name = "test_instance_norm",
 )
@@ -136,6 +140,7 @@ test_suite(
     tests = [
         ":test_activation",
         ":test_batch_norm",
+        ":test_bitwise",
         ":test_instance_norm",
         ":test_cast",
         ":test_clone",
diff --git a/tests/core/conversion/converters/test_bitwise.cpp b/tests/core/conversion/converters/test_bitwise.cpp
new file mode 100644
index 0000000000..7826b51c44
--- /dev/null
+++ b/tests/core/conversion/converters/test_bitwise.cpp
@@ -0,0 +1,42 @@
+#include <string>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+std::string gen_test_graph() {
+  return R"IR(
+      graph(%0: Tensor):
+        %3 : Tensor = aten::bitwise_not(%0)
+        return (%3))IR";
+}
+
+#define test_bitwise_not(dtype) \
+  TEST(Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) { \
+    const auto graph = gen_test_graph(); \
+ \
+    auto g = std::make_shared<torch::jit::Graph>(); \
+    torch::jit::parseIR(graph, g.get()); \
+ \
+    at::Tensor in; \
+    if (strcmp(#dtype, "Integer") == 0) \
+      in = at::randint(-128, 128, {10}, {at::kCUDA}).toType(at::kInt); \
+    if (strcmp(#dtype, "Boolean") == 0) \
+      in = at::randint(0, 2, {10}, {at::kCUDA}).toType(at::kBool); \
+    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); \
+    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); \
+ \
+    in = at::clone(in); \
+    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); \
+    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); \
+ \
+    auto jit_int = jit_results[0].toType(at::kInt); \
+    auto trt_int = trt_results[0].toType(at::kInt); \
+ \
+    ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_int, trt_int)); \
+  }
+
+test_bitwise_not(Integer);
+test_bitwise_not(Boolean);
+
+#undef test_bitwise_not
diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp
index df52b54b26..86d999ab62 100644
--- a/tests/util/run_graph_engine.cpp
+++ b/tests/util/run_graph_engine.cpp
@@ -4,6 +4,7 @@
 #include "core/ir/ir.h"
 #include "core/runtime/runtime.h"
 #include "core/util/prelude.h"
+#include "core/util/trt_util.h"
 #include "cuda_runtime_api.h"
 #include "torch/csrc/jit/ir/ir.h"
 #include "torch/csrc/jit/ir/irparser.h"
@@ -19,7 +20,7 @@ namespace util {
 std::vector<core::ir::Input> toInputs(std::vector<at::Tensor> ten) {
   std::vector<core::ir::Input> a;
   for (auto i : ten) {
-    a.push_back(core::ir::Input(core::util::toVec(i.sizes())));
+    a.push_back(core::ir::Input(core::util::toVec(i.sizes()), core::util::ScalarTypeToTRTDataType(i.scalar_type())));
   }
   return std::move(a);
 }
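
Note on the integer path in impl/bitwise.cpp: TensorRT's kNOT unary operation
only accepts boolean tensors, so for kINT32 inputs the converter instead relies
on the two's-complement identity ~x = -x - 1. A minimal standalone check of
that identity (an illustrative sketch, independent of this patch and of
TensorRT):

    #include <cassert>
    #include <cstdint>

    int main() {
      // On two's-complement hardware (mandatory since C++20, and true on
      // every platform TensorRT supports), flipping all bits of a signed
      // 32-bit integer equals negating it and subtracting one.
      for (int32_t x : {-128, -1, 0, 1, 127}) {
        assert(~x == -x - 1);
      }
      return 0;
    }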
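
The converter builds exactly that identity out of TensorRT layers: an
elementwise kPROD against a -1 constant followed by an elementwise kSUM against
the same constant. A hedged libtorch sketch of the same decomposition, checked
against at::bitwise_not on the CPU (illustration only, not part of the patch):

    #include <torch/torch.h>

    int main() {
      // ~x == (x * -1) + (-1) == -x - 1 for 32-bit integer tensors,
      // mirroring the kPROD-then-kSUM pair the converter emits.
      auto x = torch::randint(-128, 128, {10}, torch::kInt);
      auto neg_one = torch::tensor({-1}, torch::kInt);
      auto decomposed = x * neg_one + neg_one;
      auto reference = torch::bitwise_not(x);
      TORCH_CHECK(decomposed.equal(reference), "decomposition mismatch");
      return 0;
    }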