Skip to content

If loop support #976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ bool OpSupported(const torch::jit::Node* n) {
return evaluators::shouldEvalAtConversionTime(n) || converters::node_is_convertable(n);
}

bool SpecialCaseSupport(const torch::jit::Node* n) {
return n->kind() == torch::jit::prim::Loop || n->kind() == torch::jit::prim::If;
}

c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::jit::Node* n, int level, int limit) {
// Check to see if you can just go through and eval all of these AOT (saves
// the recursion) Also probably a better way to deal with the two error cases;
Expand Down Expand Up @@ -499,7 +503,7 @@ std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock(cons
auto schema = n->maybeSchema();
// Some ops like torch::jit::prim::Loop, torch::jit::prim::If, torch::jit::prim::DictConstruct don't have a schema
// but they are supported. torch::jit::prim::DictConstruct is supported via fallback only
if (!OpSupported(n)) {
if (!OpSupported(n) && !SpecialCaseSupport(n)) {
if (schema) {
std::stringstream ss;
ss << *schema;
Expand Down
Empty file modified core/conversion/evaluators/eval_util.cpp
100755 → 100644
Empty file.
Empty file modified core/conversion/evaluators/eval_util.h
100755 → 100644
Empty file.
Empty file modified core/conversion/evaluators/prim.cpp
100755 → 100644
Empty file.
28 changes: 15 additions & 13 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# TOP_DIR
TOP_DIR=os.path.dirname(os.path.realpath(__file__)) if not 'TOP_DIR' in os.environ else os.environ["TOP_DIR"]

nox.options.sessions = ["l0_api_tests-3"]
SUPPORTED_PYTHON_VERSIONS=["3.7", "3.8", "3.9", "3.10"]

nox.options.sessions = ["l0_api_tests-3.7"]

def install_deps(session):
print("Installing deps")
Expand Down Expand Up @@ -268,62 +270,62 @@ def run_l2_multi_gpu_tests(session, use_host_env=False):
run_multi_gpu_tests(session, use_host_env)
cleanup(session)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l0_api_tests(session):
"""When a developer needs to check correctness for a PR or something"""
run_l0_api_tests(session, use_host_env=False)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l0_api_tests_host_deps(session):
"""When a developer needs to check basic api functionality using host dependencies"""
run_l0_api_tests(session, use_host_env=True)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l0_dla_tests_host_deps(session):
"""When a developer needs to check basic api functionality using host dependencies"""
run_l0_dla_tests(session, use_host_env=True)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l1_accuracy_tests(session):
"""Checking accuracy performance on various usecases"""
run_l1_accuracy_tests(session, use_host_env=False)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l1_accuracy_tests_host_deps(session):
"""Checking accuracy performance on various usecases using host dependencies"""
run_l1_accuracy_tests(session, use_host_env=True)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l1_int8_accuracy_tests(session):
"""Checking accuracy performance on various usecases"""
run_l1_int8_accuracy_tests(session, use_host_env=False)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l1_int8_accuracy_tests_host_deps(session):
"""Checking accuracy performance on various usecases using host dependencies"""
run_l1_int8_accuracy_tests(session, use_host_env=True)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l2_trt_compatibility_tests(session):
"""Makes sure that TensorRT Python and Torch-TensorRT can work together"""
run_l2_trt_compatibility_tests(session, use_host_env=False)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l2_trt_compatibility_tests_host_deps(session):
"""Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
run_l2_trt_compatibility_tests(session, use_host_env=True)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l2_multi_gpu_tests(session):
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
run_l2_multi_gpu_tests(session, use_host_env=False)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def l2_multi_gpu_tests_host_deps(session):
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
run_l2_multi_gpu_tests(session, use_host_env=True)

@nox.session(python=["3"], reuse_venv=True)
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
def download_test_models(session):
"""Grab all the models needed for testing"""
download_models(session, use_host_env=True)
2 changes: 2 additions & 0 deletions py/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
-f https://download.pytorch.org/whl/torch_stable.html
-f https://download.pytorch.org/whl/torch/
--extra-index-url https://download.pytorch.org/whl/cu113
torch==1.11.0+cu113
pybind11==2.6.2
Empty file modified tests/core/lowering/test_reduce_to_pass.cpp
100755 → 100644
Empty file.