diff --git a/noxfile.py b/noxfile.py index b17b506b7a..72483786eb 100644 --- a/noxfile.py +++ b/noxfile.py @@ -202,6 +202,7 @@ def run_base_tests(session): else: session.run_always("pytest", test) + def run_fx_core_tests(session): print("Running FX core tests") session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) @@ -214,6 +215,7 @@ def run_fx_core_tests(session): else: session.run_always("pytest", test) + def run_fx_converter_tests(session): print("Running FX converter tests") session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) @@ -229,6 +231,7 @@ def run_fx_converter_tests(session): else: session.run_always("pytest", test, skip_tests) + def run_fx_lower_tests(session): print("Running FX passes and trt_lower tests") session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) @@ -237,7 +240,7 @@ def run_fx_lower_tests(session): # "passes/test_fuse_permute_linear_trt.py", "passes/test_remove_duplicate_output_args.py", "passes/test_fuse_permute_matmul_trt.py", - #"passes/test_graph_opts.py" + # "passes/test_graph_opts.py" "trt_lower", ] for test in tests: @@ -246,6 +249,7 @@ def run_fx_lower_tests(session): else: session.run_always("pytest", test) + def run_fx_quant_tests(session): print("Running FX Quant tests") session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) @@ -261,6 +265,7 @@ def run_fx_quant_tests(session): else: session.run_always("pytest", test, skip_tests) + def run_fx_tracer_tests(session): print("Running FX Tracer tests") session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) @@ -269,7 +274,7 @@ def run_fx_tracer_tests(session): tests = [ "tracer/test_acc_shape_prop.py", "tracer/test_acc_tracer.py", - #"tracer/test_dispatch_tracer.py" + # "tracer/test_dispatch_tracer.py" ] for test in tests: if USE_HOST_DEPS: @@ -277,6 +282,7 @@ def run_fx_tracer_tests(session): else: session.run_always("pytest", test) + def run_fx_tools_tests(session): print("Running FX tools tests") session.chdir(os.path.join(TOP_DIR, "py/torch_tensorrt/fx/test")) @@ -396,6 +402,7 @@ def run_l0_api_tests(session): run_base_tests(session) cleanup(session) + def run_l0_fx_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -405,6 +412,7 @@ def run_l0_fx_tests(session): run_fx_lower_tests(session) cleanup(session) + def run_l0_fx_core_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -412,6 +420,7 @@ def run_l0_fx_core_tests(session): run_fx_core_tests(session) cleanup(session) + def run_l0_fx_converter_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -419,6 +428,7 @@ def run_l0_fx_converter_tests(session): run_fx_converter_tests(session) cleanup(session) + def run_l0_fx_lower_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -426,6 +436,7 @@ def run_l0_fx_lower_tests(session): run_fx_lower_tests(session) cleanup(session) + def run_l0_dla_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -443,6 +454,7 @@ def run_l1_model_tests(session): run_model_tests(session) cleanup(session) + def run_l1_int8_accuracy_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -452,6 +464,7 @@ def run_l1_int8_accuracy_tests(session): run_int8_accuracy_tests(session) cleanup(session) + def run_l1_fx_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -461,6 +474,7 @@ def run_l1_fx_tests(session): run_fx_tools_tests(session) cleanup(session) + def run_l2_trt_compatibility_tests(session): if not USE_HOST_DEPS: install_deps(session) @@ -483,26 +497,31 @@ def l0_api_tests(session): """When a developer needs to check correctness for a PR or something""" run_l0_api_tests(session) + @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l0_fx_tests(session): """When a developer needs to check correctness for a PR or something""" run_l0_fx_tests(session) + @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l0_fx_core_tests(session): """When a developer needs to check correctness for a PR or something""" run_l0_fx_core_tests(session) + @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l0_fx_converter_tests(session): """When a developer needs to check correctness for a PR or something""" run_l0_fx_converter_tests(session) + @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l0_fx_lower_tests(session): """When a developer needs to check correctness for a PR or something""" run_l0_fx_lower_tests(session) + @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l0_dla_tests(session): """When a developer needs to check basic api functionality using host dependencies""" @@ -514,11 +533,13 @@ def l1_model_tests(session): """When a user needs to test the functionality of standard models compilation and results""" run_l1_model_tests(session) + @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l1_fx_tests(session): """When a user needs to test the functionality of standard models compilation and results""" run_l1_fx_tests(session) + @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l1_int8_accuracy_tests(session): """Checking accuracy performance on various usecases""" @@ -534,4 +555,4 @@ def l2_trt_compatibility_tests(session): @nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) def l2_multi_gpu_tests(session): """Makes sure that Torch-TensorRT can operate on multi-gpu systems""" - run_l2_multi_gpu_tests(session) \ No newline at end of file + run_l2_multi_gpu_tests(session)