diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 285398311fd19..49a391938eaf6 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -165,8 +165,8 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { if (srcInt.getType().isInteger(1)) return srcInt; - auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder); - return builder.createOrFold(loc, srcInt, one); + auto one = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder); + return builder.createOrFold(loc, srcInt, one); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 6dd5b1988e2a2..8906de9db3724 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -76,8 +76,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class>, %i : // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] : i8 - // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 - // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class> // CHECK: return %[[BOOL]] return %0: i1 @@ -234,8 +234,8 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class>, %i // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[IDX_CAST]]] // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8 - // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8 - // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 %0 = memref.load %src[%i] : memref<4xi1, #spirv.storage_class> // CHECK: return %[[BOOL]] return %0: i1