From 80dedbdf9851edd9d0d47d4c555d1a5549de00e6 Mon Sep 17 00:00:00 2001 From: Andrey Pavlenko Date: Wed, 29 May 2024 18:22:18 +0200 Subject: [PATCH] FIX-#111: oneDnn JsonParser: Convert auto_broadcast attribute to boolean Fixes #111 --- src/dnnl/JsonParser.cpp | 19 ++++++++++++++++++- test/dnnl/TestJsonParser.cpp | 4 ++-- test/dnnl/resources/add_relu.json | 2 +- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/dnnl/JsonParser.cpp b/src/dnnl/JsonParser.cpp index f9b291cb7..1c0e15b94 100644 --- a/src/dnnl/JsonParser.cpp +++ b/src/dnnl/JsonParser.cpp @@ -144,7 +144,24 @@ void JsonParser::readOp() { _reader.begin_object(); while (_reader.next_object_item(&_str)) { auto name = mlir::StringAttr::get(_builder.getContext(), _str); - _attributes.emplace_back(name, readAttr()); + auto value = readAttr(); + + if (name == "auto_broadcast") { // Convert to boolean + if (value.getTypeID() != mlir::StringAttr::getTypeID()) { + _str = "auto_broadcast"; + throwErr("Invalid attribute type: "); + } + if (_str == "numpy") { + value = _builder.getBoolAttr(true); + } else if (_str == "none") { + value = _builder.getBoolAttr(false); + } else { + throwErr( + "Invalid auto_broadcast attribute value: "); + } + } + + _attributes.emplace_back(name, value); } } else if (_str == "inputs") { _reader.begin_array(); diff --git a/test/dnnl/TestJsonParser.cpp b/test/dnnl/TestJsonParser.cpp index 1d8093a4c..99e1c3b8f 100644 --- a/test/dnnl/TestJsonParser.cpp +++ b/test/dnnl/TestJsonParser.cpp @@ -92,8 +92,8 @@ TEST(TestJsonParser, AddRelu) { ASSERT_EQ(attrs.size(), 1); ASSERT_EQ(attrs.begin()->getName(), mlir::StringAttr::get(addOp->getContext(), "auto_broadcast")); - ASSERT_EQ(mlir::cast(attrs.begin()->getValue()).getValue(), - "numpy"); + ASSERT_FALSE( + mlir::cast(attrs.begin()->getValue()).getValue()); checkTensorType(addOp->getOperandTypes()[0]); checkTensorType(addOp->getOperandTypes()[1]); checkTensorType(addOp->getResultTypes()[0]); diff --git a/test/dnnl/resources/add_relu.json b/test/dnnl/resources/add_relu.json index bb1e725c6..1a68fb8c0 100644 --- a/test/dnnl/resources/add_relu.json +++ b/test/dnnl/resources/add_relu.json @@ -17,7 +17,7 @@ "attrs": { "auto_broadcast": { "type": "string", - "value": "numpy" + "value": "none" } }, "inputs": [