diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index e5e2e780df..311f4277b0 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -1,6 +1,7 @@ #include "torch/csrc/jit/passes/common_subexpression_elimination.h" #include "torch/csrc/jit/passes/create_functional_graphs.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/passes/erase_number_types.h" #include "torch/csrc/jit/passes/freeze_module.h" #include "torch/csrc/jit/passes/fuse_linear.h" #include "torch/csrc/jit/passes/guard_elimination.h" @@ -63,6 +64,8 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { passes::RemoveNOPs(g); passes::AliasOperators(g); passes::SiluToSigmoidMultipication(g); + passes::RemoveSingleUse0DTensors(g); + passes::RemoveUnnecessaryCasts(g); LOG_GRAPH(*g); } diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD index fde517c428..de0e488376 100644 --- a/core/lowering/passes/BUILD +++ b/core/lowering/passes/BUILD @@ -23,6 +23,7 @@ cc_library( "view_to_reshape.cpp", "remove_dropout.cpp", "remove_nops.cpp", + "remove_unnecessary_casts.cpp", "silu_to_sigmoid_multiplication.cpp", "unpack_addmm.cpp", "unpack_batch_norm.cpp", diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h index a79d6c1cdc..348b56997f 100644 --- a/core/lowering/passes/passes.h +++ b/core/lowering/passes/passes.h @@ -27,6 +27,8 @@ void RemoveContiguous(std::shared_ptr& graph); void ViewToReshape(std::shared_ptr& graph); void RemoveDropout(std::shared_ptr& graph); void RemoveNOPs(std::shared_ptr graph); +void RemoveSingleUse0DTensors(std::shared_ptr& g); +void RemoveUnnecessaryCasts(std::shared_ptr& graph); void UnpackAddMM(std::shared_ptr& graph); void UnpackBatchNorm(std::shared_ptr& graph); void UnpackLogSoftmax(std::shared_ptr& graph); diff --git a/core/lowering/passes/remove_set_attrs.cpp b/core/lowering/passes/remove_set_attrs.cpp new file mode 100644 index 0000000000..6645707f49 --- /dev/null 
+++ b/core/lowering/passes/remove_set_attrs.cpp
@@ -0,0 +1,37 @@
+#include <memory>
+#include <string>
+
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/lowering/passes/passes.h"
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+// Removes `prim::SetAttr[name="_has_warned"]` nodes from the graph of method
+// `method_name` of `mod` by rewriting each match to an empty subgraph.
+void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
+  auto g = mod.get_method(method_name).graph();
+
+  std::string set_attr_pattern = R"IR(
+        graph(%self, %0):
+            None = prim::SetAttr[name="_has_warned"](%self, %0)
+            return ())IR";
+  std::string no_set_attr_pattern = R"IR(
+        graph(%self, %0):
+            return ())IR";
+
+  // Replace every occurrence of the SetAttr pattern with nothing
+  torch::jit::SubgraphRewriter remove_set_attr;
+  remove_set_attr.RegisterRewritePattern(set_attr_pattern, no_set_attr_pattern);
+  remove_set_attr.runOnGraph(g);
+  LOG_GRAPH("Post remove set attrs: " << *g);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/core/lowering/passes/remove_unnecessary_casts.cpp b/core/lowering/passes/remove_unnecessary_casts.cpp
new file mode 100644
index 0000000000..d7c9c77d71
--- /dev/null
+++ b/core/lowering/passes/remove_unnecessary_casts.cpp
@@ -0,0 +1,200 @@
+#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+namespace {
+
+// Rewrites `num_to_tensor_op(x)` immediately followed by `cast_op` back to plain `x`.
+// `scalar_type` is the IR type annotation ("int", "float", "bool") given to `x` in the pattern.
+void RemoveCastRoundTrip(
+    std::shared_ptr<torch::jit::Graph>& graph,
+    const std::string& scalar_type,
+    const std::string& num_to_tensor_op,
+    const std::string& cast_op) {
+  const std::string cast_pattern = "\n    graph(%1: " + scalar_type + "):" +
+      "\n      %2: Tensor = " + num_to_tensor_op + "(%1)" +
+      "\n      %3: " + scalar_type + " = " + cast_op + "(%2)" +
+      "\n      return (%3)";
+  const std::string clean_pattern = "\n    graph(%1: " + scalar_type + "):" +
+      "\n      return (%1)";
+
+  torch::jit::SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(cast_pattern, clean_pattern);
+  rewriter.runOnGraph(graph);
+}
+
+// Scalar IR type produced by the cast `cast_kind`; anything other than aten::Float maps to int.
+c10::TypePtr ScalarTypeOfCast(c10::Symbol cast_kind) {
+  if (cast_kind == c10::aten::Float) {
+    return c10::FloatType::get();
+  }
+  return c10::IntType::get();
+}
+
+} // namespace
+
+// Presumably this is safe since torch::jit::EraseNumberTypesOnBlock exists which just
+// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
+void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
+  // RemoveSingleUse0DTensors below looks for this op under its prim:: name, so match that
+  // spelling too; the aten:: spelling this pass originally used is kept for compatibility.
+  for (const char* num_to_tensor_op : {"aten::NumToTensor", "prim::NumToTensor"}) {
+    RemoveCastRoundTrip(graph, "int", num_to_tensor_op, "aten::Int");
+    RemoveCastRoundTrip(graph, "float", num_to_tensor_op, "aten::Float");
+    RemoveCastRoundTrip(graph, "bool", num_to_tensor_op, "aten::Bool");
+  }
+
+  LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
+}
+
+// Folds a 0-dim Tensor constant that is combined with a prim::NumToTensor value and then
+// cast straight back to a scalar by aten::Int / aten::Float into scalar arithmetic, e.g.
+//
+//   %c : Tensor = prim::Constant[value={8}]()
+//   %t : Tensor = prim::NumToTensor(%x)
+//   %r : Tensor = aten::add(%c, %t, %alpha)
+//   %s : int = aten::Int(%r)
+//
+// becomes
+//
+//   %c : int = prim::Constant[value=8]()
+//   %s : int = aten::add(%c, %x)
+//
+// NOTE(review): aten::add's third input is dropped without checking that it is 1, and every
+// use of the NumToTensor output (not only this one) is redirected to the scalar - confirm
+// both hold for the graphs this pass is expected to see.
+void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
+  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
+    // Only Tensor-typed constants are candidates
+    if (it->kind() != torch::jit::prim::Constant ||
+        !it->output()->type()->isSubtypeOf(c10::TensorType::get())) {
+      continue;
+    }
+
+    // Get the tensor stored in the constant; only 0D tensors can be folded to a scalar
+    auto maybe_tensor = torch::jit::constant_as<at::Tensor>(it->output());
+    if (!maybe_tensor) {
+      continue;
+    }
+    at::Tensor t = *maybe_tensor;
+    if (!(t.sizes() == std::vector<int64_t>({}))) {
+      continue;
+    }
+    LOG_GRAPH("Found a 0D Tensor: " << it->output()->debugName());
+    LOG_GRAPH("Number of uses: " << it->output()->uses().size());
+
+    // The constant must be used exactly once
+    if (it->output()->uses().size() != 1) {
+      continue;
+    }
+    const auto use = it->output()->uses()[0];
+    auto user = use.user;
+
+    // Its consumer must yield a single Tensor that itself has a single consumer...
+    if (user->outputs().size() != 1 || !user->outputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) ||
+        user->output()->uses().size() != 1) {
+      continue;
+    }
+
+    // ...and that consumer must cast the result back to a scalar
+    auto potential_cast = user->output()->uses()[0].user;
+    const auto cast_kind = potential_cast->kind();
+    if (!(cast_kind == c10::aten::Int || cast_kind == c10::aten::Float)) {
+      continue;
+    }
+    LOG_GRAPH("Downstream user is aten::Int/aten::Float");
+
+    // One of the user's other Tensor inputs must come from prim::NumToTensor
+    torch::jit::Node* num_to_tensor = nullptr;
+    for (size_t k = 0; k < user->inputs().size(); ++k) {
+      if (k == use.offset) {
+        continue;
+      }
+      auto input = user->inputs()[k];
+      if (!input->type()->isSubtypeOf(c10::TensorType::get())) {
+        continue;
+      }
+      LOG_GRAPH("Input " << k << " is a Tensor");
+      if (input->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
+        num_to_tensor = input->node();
+        break;
+      }
+    }
+    if (num_to_tensor == nullptr) {
+      continue;
+    }
+
+    LOG_GRAPH(
+        "Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n    "
+        << *(*it) << *num_to_tensor << *user << *potential_cast);
+    // Logged up front, while `user` still has its original Tensor operands
+    LOG_GRAPH(user->schema());
+
+    const auto scalar_type = ScalarTypeOfCast(cast_kind);
+
+    // Replace the Tensor Constant with a scalar constant
+    LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
+    torch::jit::Value* new_const_val = nullptr;
+    {
+      // Scoped so the insert point is restored before the old constant is destroyed
+      torch::jit::WithInsertPoint guard(*it);
+      new_const_val = g->insertConstant(t.item(), c10::nullopt, it->scope());
+    }
+    new_const_val->copyMetadata(it->output());
+    // How to determine the internal scalar type instead of assuming?
+    new_const_val->setType(scalar_type);
+    it->output()->replaceAllUsesWith(new_const_val);
+    // Steps `it` back one node, so the loop's ++it resumes right after the removed constant
+    it.destroyCurrent();
+    LOG_GRAPH("New constant: " << *new_const_val->node());
+
+    // Delete NumToTensor
+    LOG_GRAPH("Deleting NumToTensor: " << *num_to_tensor);
+    num_to_tensor->output()->replaceAllUsesWith(num_to_tensor->inputs()[0]);
+    num_to_tensor->destroy();
+
+    // Rebuild the intermediate op on the scalar operands
+    std::vector<torch::jit::Value*> scalar_inputs(user->inputs().begin(), user->inputs().end());
+    switch (user->kind()) {
+      // Use this to handle special cases where the scalar version of the intermediate operator
+      // has a different schema than the original
+      case c10::aten::add:
+        // Keep only the two addends
+        scalar_inputs.resize(2);
+        break;
+      default:
+        break;
+    }
+    torch::jit::Node* new_node = g->create(user->kind(), scalar_inputs, 1);
+    new_node->insertAfter(user);
+    // The result has the type of the cast being removed
+    new_node->outputs()[0]->setType(scalar_type);
+    user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+    user->destroy();
+
+    LOG_GRAPH("New intermediate operation: " << *new_node);
+    LOG_GRAPH(new_node->schema());
+
+    // Delete aten::Int / aten::Float. `user`, `num_to_tensor` and the cast are all destroyed
+    // at this point, so this constant is done and the scan moves on to the next node.
+    LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
+    potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
+    potential_cast->destroy();
+  }
+  LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
diff --git a/noxfile.py b/noxfile.py index 08bcc348fb..ebc9a6048a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -18,7 +18,7 @@
def install_deps(session): def download_models(session, use_host_env=False): print("Downloading test models") - session.install('timm') + session.install("-r", os.path.join(TOP_DIR, "tests", "modules", "requirements.txt")) print(TOP_DIR) session.chdir(os.path.join(TOP_DIR, "tests", "modules")) if use_host_env: diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD index b97f9ba451..6eebb79d3f 100644 --- a/tests/core/lowering/BUILD +++ b/tests/core/lowering/BUILD @@ -50,6 +50,10 @@ lowering_test( name = "test_remove_detach_pass", ) +lowering_test( + name = "test_remove_unnecessary_casts", +) + lowering_test( name = "test_view_to_reshape_pass", ) @@ -81,6 +85,7 @@ test_suite( ":test_remove_detach_pass", ":test_view_to_reshape_pass", ":test_remove_dropout_pass", + ":test_remove_unnecessary_casts", ":test_reduce_to_pass", ":test_reduce_gelu", ":test_unpack_hardswish",
diff --git a/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tests/core/lowering/test_remove_unnecessary_casts.cpp
new file mode 100644
index 0000000000..62f913e49a
--- /dev/null
+++ b/tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -0,0 +1,158 @@
+#include <memory>
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+namespace {
+
+// Turns on graph-level logging so the passes under test print the graphs they produce
+void EnableGraphLogging() {
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+}
+
+// Parses `ir` into a freshly allocated graph
+std::shared_ptr<torch::jit::Graph> ParseGraph(const std::string& ir) {
+  auto graph = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(ir, graph.get());
+  return graph;
+}
+
+// Replaces the first node of `graph` with a Tensor constant holding `value` as a 0-dim tensor
+// (built with c10::scalar_to_tensor), rewiring all uses of the old node's output to it
+void ReplaceFirstNodeWith0DTensor(std::shared_ptr<torch::jit::Graph>& graph, const c10::Scalar& value) {
+  auto first_op = *(graph->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = graph->insertConstant(c10::scalar_to_tensor(value), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+}
+
+} // namespace
+
+TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%1: int):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: int = aten::Int(%2)
+      %4: int = aten::add(%3, %3, %3)
+      return (%4))IR";
+  std::string target_graph = R"IR(
+    graph(%1: int):
+      %4: int = aten::add(%1, %1, %1)
+      return (%4))IR";
+
+  EnableGraphLogging();
+  auto sg = ParseGraph(source_graph);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = ParseGraph(target_graph);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveUnnecessaryCastFloatCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%1: float):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: float = aten::Float(%2)
+      %4: float = aten::add(%3, %3, %3)
+      return (%4))IR";
+  std::string target_graph = R"IR(
+    graph(%1: float):
+      %4: float = aten::add(%1, %1, %1)
+      return (%4))IR";
+
+  EnableGraphLogging();
+  auto sg = ParseGraph(source_graph);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = ParseGraph(target_graph);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveUnnecessaryCastBoolCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%1: bool):
+      %2: Tensor = aten::NumToTensor(%1)
+      %3: bool = aten::Bool(%2)
+      %4: bool = aten::__and__(%3, %3)
+      return (%4))IR";
+  std::string target_graph = R"IR(
+    graph(%1: bool):
+      %4: bool = aten::__and__(%1, %1)
+      return (%4))IR";
+
+  EnableGraphLogging();
+  auto sg = ParseGraph(source_graph);
+  torch_tensorrt::core::lowering::passes::RemoveUnnecessaryCasts(sg);
+
+  auto tg = ParseGraph(target_graph);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsIntCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %2: int = prim::Constant[value=1]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::add(%1, %3, %2)
+      %5: int = aten::Int(%4)
+      %6: int = aten::add(%5, %5)
+      return (%6))IR";
+  std::string target_graph = R"IR(
+    graph(%0: int):
+      %1: int = prim::Constant[value=8]()
+      %4: int = aten::add(%1, %0)
+      %6: int = aten::add(%4, %4)
+      return (%6))IR";
+
+  EnableGraphLogging();
+  auto sg = ParseGraph(source_graph);
+  // Swap the parsed `%1` constant for a true 0-dim tensor before running the pass
+  ReplaceFirstNodeWith0DTensor(sg, 8);
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = ParseGraph(target_graph);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: float):
+      %1: Tensor = prim::Constant[value=[8.]]()
+      %2: float = prim::Constant[value=1.]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::add(%1, %3, %2)
+      %5: float = aten::Float(%4)
+      %6: float = aten::add(%5, %5)
+      return (%6))IR";
+  std::string target_graph = R"IR(
+    graph(%0: float):
+      %1: float = prim::Constant[value=8.]()
+      %4: float = aten::add(%1, %0)
+      %6: float = aten::add(%4, %4)
+      return (%6))IR";
+
+  EnableGraphLogging();
+  auto sg = ParseGraph(source_graph);
+  // Swap the parsed `%1` constant for a true 0-dim tensor before running the pass
+  ReplaceFirstNodeWith0DTensor(sg, 8.0);
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = ParseGraph(target_graph);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
diff --git a/tests/cpp/cpp_api_test.h b/tests/cpp/cpp_api_test.h index 2291f814cd..3addfbc2ed 100644 --- a/tests/cpp/cpp_api_test.h +++ b/tests/cpp/cpp_api_test.h @@ -6,12 +6,12 @@ #include "torch/script.h" #include "torch_tensorrt/torch_tensorrt.h" -using PathAndInSize = std::tuple>, float>; +using PathAndInput = std::tuple>, std::vector, float>; -class CppAPITests : public testing::TestWithParam { +class CppAPITests : public testing::TestWithParam { public: void SetUp() override { - PathAndInSize params = GetParam(); + PathAndInput params = GetParam(); std::string path = std::get<0>(params); try { // Deserialize the ScriptModule from a file using torch::jit::load().
@@ -21,7 +21,8 @@ class CppAPITests : public testing::TestWithParam { ASSERT_TRUE(false); } input_shapes = std::get<1>(params); - threshold = std::get<2>(params); + input_types = std::get<2>(params); + threshold = std::get<3>(params); } void TearDown() { @@ -32,5 +33,6 @@ class CppAPITests : public testing::TestWithParam { protected: torch::jit::script::Module mod; std::vector> input_shapes; + std::vector input_types; float threshold; }; diff --git a/tests/cpp/test_compiled_modules.cpp b/tests/cpp/test_compiled_modules.cpp index c61a8f76f1..595dd7044f 100644 --- a/tests/cpp/test_compiled_modules.cpp +++ b/tests/cpp/test_compiled_modules.cpp @@ -3,20 +3,42 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { std::vector jit_inputs_ivalues; std::vector trt_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + std::vector shapes; + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); jit_inputs_ivalues.push_back(in.clone()); trt_inputs_ivalues.push_back(in.clone()); + auto in_spec = torch_tensorrt::Input(input_shapes[i]); + in_spec.dtype = input_types[i]; + shapes.push_back(in_spec); + std::cout << in_spec << std::endl; } torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(mod, jit_inputs_ivalues); std::vector jit_results; - jit_results.push_back(jit_results_ivalues.toTensor()); + if (jit_results_ivalues.isTuple()) { + auto tuple = jit_results_ivalues.toTuple(); + for (auto t : tuple->elements()) { + jit_results.push_back(t.toTensor()); + } + } else { + jit_results.push_back(jit_results_ivalues.toTensor()); + } + + auto spec = torch_tensorrt::ts::CompileSpec(shapes); + spec.truncate_long_and_double = true; - auto trt_mod = torch_tensorrt::ts::compile(mod, input_shapes); + auto trt_mod = torch_tensorrt::ts::compile(mod, spec); torch::jit::IValue trt_results_ivalues = 
torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues); std::vector trt_results; - trt_results.push_back(trt_results_ivalues.toTensor()); + if (trt_results_ivalues.isTuple()) { + auto tuple = trt_results_ivalues.toTuple(); + for (auto t : tuple->elements()) { + trt_results.push_back(t.toTensor()); + } + } else { + trt_results.push_back(trt_results_ivalues.toTensor()); + } for (size_t i = 0; i < trt_results.size(); i++) { ASSERT_TRUE( @@ -30,13 +52,14 @@ INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-3}), - PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, 8e-2}))); + PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-3}), + 
PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}, 8e-2}), + PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2}))); #endif diff --git a/tests/cpp/test_default_input_types.cpp b/tests/cpp/test_default_input_types.cpp index 63904c7416..a79ddafe0c 100644 --- a/tests/cpp/test_default_input_types.cpp +++ b/tests/cpp/test_default_input_types.cpp @@ -78,7 +78,7 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP16WeightsFP32In) { } auto in = torch_tensorrt::Input(input_shapes[0]); - in.dtype = torch::kF32; + in.dtype = torch::kFloat; auto spec = torch_tensorrt::ts::CompileSpec({in}); spec.enabled_precisions.insert(torch_tensorrt::DataType::kHalf); @@ -116,4 +116,5 @@ TEST_P(CppAPITests, InputsRespectUserSettingFP32WeightsFP16In) { INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, - testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}))); + testing::Values( + PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat} /*unused*/, 2e-5}))); diff --git a/tests/cpp/test_example_tensors.cpp b/tests/cpp/test_example_tensors.cpp index fc77d9e4d4..6561cd16a0 100644 --- a/tests/cpp/test_example_tensors.cpp +++ b/tests/cpp/test_example_tensors.cpp @@ -3,8 +3,8 @@ TEST_P(CppAPITests, InputsFromTensors) { std::vector jit_inputs_ivalues; std::vector trt_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randn(in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]); jit_inputs_ivalues.push_back(in.clone()); trt_inputs_ivalues.push_back(in.clone()); } @@ -20,4 +20,4 @@ TEST_P(CppAPITests, InputsFromTensors) { INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, - testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}))); + 
testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}))); diff --git a/tests/cpp/test_modules_as_engines.cpp b/tests/cpp/test_modules_as_engines.cpp index ab4ccc1ae7..c77919c8b9 100644 --- a/tests/cpp/test_modules_as_engines.cpp +++ b/tests/cpp/test_modules_as_engines.cpp @@ -4,8 +4,8 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) { std::vector inputs; std::vector inputs_ivalues; - for (auto in_shape : input_shapes) { - inputs.push_back(at::randint(5, in_shape, {at::kCUDA})); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i])); inputs_ivalues.push_back(inputs[inputs.size() - 1].clone()); } @@ -21,8 +21,8 @@ TEST_P(CppAPITests, ModuleAsEngineIsClose) { TEST_P(CppAPITests, ModuleToEngineToModuleIsClose) { std::vector inputs; std::vector inputs_ivalues; - for (auto in_shape : input_shapes) { - inputs.push_back(at::randint(5, in_shape, {at::kCUDA})); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + inputs.push_back(at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i])); inputs_ivalues.push_back(inputs[inputs.size() - 1].clone()); } @@ -57,13 +57,13 @@ INSTANTIATE_TEST_SUITE_P( ModuleAsEngineForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 
224}}, 8e-2}))); + PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 8e-2}))); #endif diff --git a/tests/cpp/test_multi_gpu_serde.cpp b/tests/cpp/test_multi_gpu_serde.cpp index 2356583fa3..366c287c32 100644 --- a/tests/cpp/test_multi_gpu_serde.cpp +++ b/tests/cpp/test_multi_gpu_serde.cpp @@ -4,8 +4,8 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { std::vector jit_inputs_ivalues; std::vector trt_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); jit_inputs_ivalues.push_back(in.clone()); trt_inputs_ivalues.push_back(in.clone()); } @@ -31,4 +31,4 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, - testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}))); \ No newline at end of file + testing::Values(PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}))); \ No newline at end of file diff --git a/tests/cpp/test_serialization.cpp b/tests/cpp/test_serialization.cpp index 877e42e6ab..0086500be5 100644 --- 
a/tests/cpp/test_serialization.cpp +++ b/tests/cpp/test_serialization.cpp @@ -21,8 +21,8 @@ std::vector toInputRangesDynamic(std::vector post_serialized_inputs_ivalues; std::vector pre_serialized_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); post_serialized_inputs_ivalues.push_back(in.clone()); pre_serialized_inputs_ivalues.push_back(in.clone()); } @@ -50,8 +50,8 @@ TEST_P(CppAPITests, SerializedModuleIsStillCorrect) { TEST_P(CppAPITests, SerializedDynamicModuleIsStillCorrect) { std::vector post_serialized_inputs_ivalues; std::vector pre_serialized_inputs_ivalues; - for (auto in_shape : input_shapes) { - auto in = at::randint(5, in_shape, {at::kCUDA}); + for (uint64_t i = 0; i < input_shapes.size(); i++) { + auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); post_serialized_inputs_ivalues.push_back(in.clone()); pre_serialized_inputs_ivalues.push_back(in.clone()); } @@ -81,5 +81,5 @@ INSTANTIATE_TEST_SUITE_P( CompiledModuleForwardIsCloseSuite, CppAPITests, testing::Values( - PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5}), - PathAndInSize({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, 2e-5}))); + PathAndInput({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 2e-5}), + PathAndInput({"tests/modules/pooling_traced.jit.pt", {{1, 3, 10, 10}}, {at::kFloat}, 2e-5}))); diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 5a02774a3c..7b707f5785 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import torchvision.models as models import timm +from transformers import BertModel, BertTokenizer, BertConfig torch.hub._validate_not_a_forked_repo = lambda a, b, c: True @@ -189,3 +190,30 @@ def forward(self, x): conditional_model = 
FallbackIf().eval().cuda() conditional_script_model = torch.jit.script(conditional_model) torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
+
+# Trace bert-base-uncased on one tokenized sentence pair. The two example inputs are the
+# token ids and the segment ids, each wrapped in a batch dimension of 1.
+enc = BertTokenizer.from_pretrained("bert-base-uncased")
+text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+tokenized_text = enc.tokenize(text)
+
+# Mask one of the tokens
+masked_index = 8
+tokenized_text[masked_index] = "[MASK]"
+indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
+
+# One segment id per token (14 in total)
+segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
+tokens_tensor = torch.tensor([indexed_tokens])
+segments_tensors = torch.tensor([segments_ids])
+dummy_input = [tokens_tensor, segments_tensors]
+
+# from_pretrained() already returns the fully configured model, so no separate BertConfig
+# is needed. torchscript=True is passed per the Hugging Face TorchScript export docs.
+model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
+model.eval()
+
+# Trace with the example inputs and save the module for the C++ API tests
+traced_model = torch.jit.trace(model, dummy_input)
+torch.jit.save(traced_model, "bert_base_uncased_traced.jit.pt")
diff --git a/tests/modules/requirements.txt b/tests/modules/requirements.txt index c7eb11119d..9e2d9a26c4 100644 --- a/tests/modules/requirements.txt +++ b/tests/modules/requirements.txt @@ -1,3 +1,4 @@ -f https://download.pytorch.org/whl/torch_stable.html +#torch==1.10.0+cu113 timm==v0.4.12 -torch==1.10.0+cu113 +transformers==4.17.0 \ No newline at end of file