From d70ff44bbdb41373b6c762516580097d2c6e5a5a Mon Sep 17 00:00:00 2001 From: fatcat-z Date: Thu, 27 Apr 2023 11:17:20 +0800 Subject: [PATCH 1/4] Add an environment variable to disable ort session cache for windows tests. --- .github/workflows/main.yaml | 1 + onnxscript/evaluator.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index db19cba05d..36570576d5 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -66,6 +66,7 @@ jobs: - name: Run tests run: nox -t ${{ matrix.nox-tag }} --forcecolor -- -v --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto env: + CACHE_ORT_SESSION: "${{ matrix.os == 'windows-latest' && '0' || '1' }}" CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}" - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index d3890be2c6..8cfc99c45e 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -6,6 +6,7 @@ import abc import contextlib +import os import pprint import typing from typing import ( @@ -346,7 +347,7 @@ def _compute_num_outputs(schema: onnx.defs.OpSchema, *args: Any, **kwargs: Any): _cache_models: dict[Any, ort.InferenceSession] = {} - +cache_ort_session = os.environ.get("CACHE_ORT_SESSION", "1") def _cache_(model, providers): # Delay import onnxruntime so that onnxscript can be used without @@ -354,11 +355,15 @@ def _cache_(model, providers): import onnxruntime as ort # pylint: disable=import-outside-toplevel serialized = model.SerializeToString() + if cache_ort_session == "0": + return ort.InferenceSession(serialized, providers=providers) + key = serialized, tuple(providers) if key in _cache_models: return _cache_models[key] session = ort.InferenceSession(serialized, providers=providers) _cache_models[key] = session + return session From 1fadd936016691491a3648cfe256dec15f1ecc04 Mon Sep 17 00:00:00 2001 From: fatcat-z Date: Thu, 27 Apr 2023 12:16:55 +0800 Subject: [PATCH 2/4] Fix a lint issue. --- onnxscript/evaluator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index 8cfc99c45e..eabb8ad5a8 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -349,6 +349,7 @@ def _compute_num_outputs(schema: onnx.defs.OpSchema, *args: Any, **kwargs: Any): _cache_models: dict[Any, ort.InferenceSession] = {} cache_ort_session = os.environ.get("CACHE_ORT_SESSION", "1") + def _cache_(model, providers): # Delay import onnxruntime so that onnxscript can be used without # installing onnxruntime. From 16244e7141d57ce19406ed238ac37cdc51f78a67 Mon Sep 17 00:00:00 2001 From: fatcat-z Date: Fri, 28 Apr 2023 23:57:17 +0800 Subject: [PATCH 3/4] Add a centralized place for possible feature switches. --- onnxscript/_internal/feature_switch.py | 10 ++++++++++ onnxscript/evaluator.py | 7 ++----- 2 files changed, 12 insertions(+), 5 deletions(-) create mode 100644 onnxscript/_internal/feature_switch.py diff --git a/onnxscript/_internal/feature_switch.py b/onnxscript/_internal/feature_switch.py new file mode 100644 index 0000000000..3303bb7ec4 --- /dev/null +++ b/onnxscript/_internal/feature_switch.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Switches to determine if the corresponding feature of onnxscript is enabled or not.""" + +import os + +# By default: Enable +cache_ort_session = os.environ.get("CACHE_ORT_SESSION", "1") diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index eabb8ad5a8..3af12724c5 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -6,7 +6,6 @@ import abc import contextlib -import os import pprint import typing from typing import ( @@ -28,7 +27,7 @@ from typing_extensions import TypeAlias from onnxscript import autocast, irbuilder, onnx_opset, tensor, utils, values -from onnxscript._internal import param_manipulation +from onnxscript._internal import feature_switch, param_manipulation if typing.TYPE_CHECKING: import onnxruntime as ort @@ -347,8 +346,6 @@ def _compute_num_outputs(schema: onnx.defs.OpSchema, *args: Any, **kwargs: Any): _cache_models: dict[Any, ort.InferenceSession] = {} -cache_ort_session = os.environ.get("CACHE_ORT_SESSION", "1") - def _cache_(model, providers): # Delay import onnxruntime so that onnxscript can be used without @@ -356,7 +353,7 @@ def _cache_(model, providers): import onnxruntime as ort # pylint: disable=import-outside-toplevel serialized = model.SerializeToString() - if cache_ort_session == "0": + if feature_switch.cache_ort_session == "0": return ort.InferenceSession(serialized, providers=providers) key = serialized, tuple(providers) From 4083b0791a81809538e43005567357182d6807b9 Mon Sep 17 00:00:00 2001 From: fatcat-z Date: Sat, 29 Apr 2023 00:05:22 +0800 Subject: [PATCH 4/4] Address comments. --- .github/workflows/main.yaml | 2 +- onnxscript/_internal/feature_switch.py | 2 +- onnxscript/evaluator.py | 17 +++++++++-------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 36570576d5..f77791225c 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -66,7 +66,7 @@ jobs: - name: Run tests run: nox -t ${{ matrix.nox-tag }} --forcecolor -- -v --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto env: - CACHE_ORT_SESSION: "${{ matrix.os == 'windows-latest' && '0' || '1' }}" + CACHE_ORT_SESSIONS: "${{ matrix.os == 'windows-latest' && '0' || '1' }}" CATCH_ORT_SEGFAULT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}" - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/onnxscript/_internal/feature_switch.py b/onnxscript/_internal/feature_switch.py index 3303bb7ec4..4b92d04a33 100644 --- a/onnxscript/_internal/feature_switch.py +++ b/onnxscript/_internal/feature_switch.py @@ -7,4 +7,4 @@ import os # By default: Enable -cache_ort_session = os.environ.get("CACHE_ORT_SESSION", "1") +CACHE_ORT_SESSIONS: bool = os.getenv("CACHE_ORT_SESSIONS", "1") != "0" diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index 3af12724c5..8a5f81d1ce 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -347,22 +347,23 @@ def _compute_num_outputs(schema: onnx.defs.OpSchema, *args: Any, **kwargs: Any): _cache_models: dict[Any, ort.InferenceSession] = {} + def _cache_(model, providers): # Delay import onnxruntime so that onnxscript can be used without # installing onnxruntime. import onnxruntime as ort # pylint: disable=import-outside-toplevel serialized = model.SerializeToString() - if feature_switch.cache_ort_session == "0": - return ort.InferenceSession(serialized, providers=providers) + if feature_switch.CACHE_ORT_SESSIONS: + key = serialized, tuple(providers) + if key in _cache_models: + return _cache_models[key] + session = ort.InferenceSession(serialized, providers=providers) + _cache_models[key] = session - key = serialized, tuple(providers) - if key in _cache_models: - return _cache_models[key] - session = ort.InferenceSession(serialized, providers=providers) - _cache_models[key] = session + return session - return session + return ort.InferenceSession(serialized, providers=providers) def _os_to_ort_value(v):