Skip to content

Commit bf6c929

Browse files
committed
feat: Handle empty schemas for unsupported ops
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 43a53ce commit bf6c929

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

core/conversion/conversion.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "core/conversion/evaluators/evaluators.h"
77
#include "core/conversion/var/Var.h"
88
#include "core/util/prelude.h"
9-
9+
#include <ATen/core/operator_name.h>
1010
#include "c10/util/intrusive_ptr.h"
1111
#include "core/conversion/converters/converter_util.h"
1212
#include "core/conversion/tensorcontainer/TensorContainer.h"
@@ -491,11 +491,20 @@ std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock(cons
491491
auto schema = n->maybeSchema();
492492
// Some ops like torch::jit::prim::Loop, torch::jit::prim::If, torch::jit::prim::DictConstruct don't have a schema but they are supported.
493493
// torch::jit::prim::DictConstruct is supported via fallback only
494-
if (schema && !OpSupported(n)) {
495-
std::stringstream ss;
496-
ss << *schema;
497-
unsupported_ops[schema->operator_name()] = ss.str();
494+
if (!OpSupported(n)) {
495+
if (schema){
496+
std::stringstream ss;
497+
ss << *schema;
498+
unsupported_ops[schema->operator_name()] = ss.str();
499+
} else {
500+
std::stringstream ss;
501+
ss << util::node_info(n);
502+
// operator.overload is a filler name just to call the constructor.
503+
c10::OperatorName op(ss.str(), "operator.overload");
504+
unsupported_ops[op] = ss.str();
505+
}
498506
}
507+
499508
for (const auto sub_b : n->blocks()) {
500509
auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
501510
unsupported_ops.insert(sub_b_unsupported_ops.begin(), sub_b_unsupported_ops.end());
@@ -530,7 +539,7 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
530539

531540
bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
532541
auto unsupported_ops = GetUnsupportedOpsInBlock(b);
533-
542+
LOG_DEBUG("======unsupported_ops size ===========: " << unsupported_ops.size());
534543
if (unsupported_ops.size() != 0) {
535544
std::stringstream unsupported_msg;
536545
unsupported_msg

0 commit comments

Comments
 (0)