diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3c6d1824369..935312136d02 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -153,9 +153,12 @@ candidates = ( "onnxruntime", "onnxruntime-gpu", + "ort_nightly_gpu", "onnxruntime-directml", "onnxruntime-openvino", "ort_nightly_directml", + "onnxruntime-rocm", + "onnxruntime-training", ) _onnxruntime_version = None # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu