Skip to content

Commit ef323be

Browse files
committed
Update mlir lit tests impacted by mlir update.
llvm/llvm-project#107109 The core dialect conversion no longer tries to materialize ops on the fly during conversion, but rather inserts unrealized conversion casts to be safe, and removes these inverse cast pairs later after the entire conversion is done. The old analysis is removed. Personally I think this is fine, as we end our pipeline with --reconcil-unrealized-casts already. Changes: Catalyst/ConversionTest.mlir and Catalyst/MemrefLoadStoreLoweringTBAA.mlir: changing op order Quantum/ConversionTest.mlir: changing op order; adding two unrealized conversion cast pairs Gradient/ConversionTest.mlir: 5 unused Values are removed
1 parent d5cdb13 commit ef323be

File tree

4 files changed

+15
-17
lines changed

4 files changed

+15
-17
lines changed

mlir/test/Catalyst/ConversionTest.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ module @test1 {
147147
// CHECK-SAME:)
148148
func.func private @foo(%arg0: tensor<f64>) -> tensor<f64> {
149149
// CHECK: [[memref0:%.+]] = bufferization.to_memref [[arg0]]
150-
// CHECK: [[struct0:%.+]] = builtin.unrealized_conversion_cast [[memref0]]
151150
// CHECK: [[ptr0:%.+]] = llvm.alloca {{.*}}
152151
// CHECK: [[ptr1:%.+]] = llvm.alloca {{.*}}
152+
// CHECK: [[struct0:%.+]] = builtin.unrealized_conversion_cast [[memref0]]
153153

154154
// CHECK: [[tensor1:%.+]] = bufferization.alloc_tensor()
155155
// CHECK: [[memref1:%.+]] = bufferization.to_memref [[tensor1]]

mlir/test/Catalyst/MemrefLoadStoreLoweringTBAA.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ module @my_model {
3535
llvm.func @my_func(...)
3636
llvm.func @__enzyme_autodiff0(...)
3737
func.func @func_i32(%arg0: memref<i32>, %arg1: memref<4xi32>) -> (memref<i32>, memref<4xi32>) {
38-
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<i32> to !llvm.struct<(ptr, ptr, i64)>
3938
// CHECK: [[castArg1:%.+]] = builtin.unrealized_conversion_cast %arg1 : memref<4xi32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
39+
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<i32> to !llvm.struct<(ptr, ptr, i64)>
4040
// CHECK: [[extract0:%.+]] = llvm.extractvalue [[castArg0]][1] : !llvm.struct<(ptr, ptr, i64)>
4141
// CHECK: [[load:%.+]] = llvm.load [[extract0]] {tbaa = [[[tag]]]} : !llvm.ptr -> i32
4242
// CHECK: [[idx:%.+]] = index.constant 0
@@ -59,8 +59,8 @@ module @my_model {
5959
module @my_model {
6060
llvm.func @__enzyme_autodiff1(...)
6161
func.func @func_f32(%arg0: memref<f32>, %arg1: memref<4xf32>) -> (memref<f32>, memref<4xf32>) {
62-
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
6362
// CHECK: [[castArg1:%.+]] = builtin.unrealized_conversion_cast %arg1 : memref<4xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
63+
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
6464
// CHECK: [[extract0:%.+]] = llvm.extractvalue [[castArg0]][1] : !llvm.struct<(ptr, ptr, i64)>
6565
// CHECK: [[load:%.+]] = llvm.load [[extract0]] {tbaa = [[[tag]]]} : !llvm.ptr -> f32
6666
// CHECK: [[idx:%.+]] = index.constant 0
@@ -83,8 +83,8 @@ module @my_model {
8383
module @my_model {
8484
llvm.func @__enzyme_autodiff1(...)
8585
func.func @func_f64(%arg0: memref<f64>, %arg1: memref<4xf64>) -> (memref<f64>, memref<4xf64>) {
86-
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<f64> to !llvm.struct<(ptr, ptr, i64)>
8786
// CHECK: [[castArg1:%.+]] = builtin.unrealized_conversion_cast %arg1 : memref<4xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
87+
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<f64> to !llvm.struct<(ptr, ptr, i64)>
8888
// CHECK: [[extract0:%.+]] = llvm.extractvalue [[castArg0]][1] : !llvm.struct<(ptr, ptr, i64)>
8989
// CHECK: [[load:%.+]] = llvm.load [[extract0]] {tbaa = [[[tag]]]} : !llvm.ptr -> f64
9090
// CHECK: [[idx:%.+]] = index.constant 0
@@ -109,10 +109,10 @@ module @my_model {
109109
module @my_model {
110110
llvm.func @__enzyme_autodiff2(...)
111111
func.func @func_mix_f64_index(%arg0: memref<f64>, %arg1: memref<4xf64>, %arg2: memref<index>, %arg3: memref<3xindex>) -> (memref<4xf64>, memref<3xindex>) {
112-
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<f64> to !llvm.struct<(ptr, ptr, i64)>
113-
// CHECK: [[castArg1:%.+]] = builtin.unrealized_conversion_cast %arg1 : memref<4xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
114-
// CHECK: [[castArg2:%.+]] = builtin.unrealized_conversion_cast %arg2 : memref<index> to !llvm.struct<(ptr, ptr, i64)>
115112
// CHECK: [[castArg3:%.+]] = builtin.unrealized_conversion_cast %arg3 : memref<3xindex> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
113+
// CHECK: [[castArg2:%.+]] = builtin.unrealized_conversion_cast %arg2 : memref<index> to !llvm.struct<(ptr, ptr, i64)>
114+
// CHECK: [[castArg1:%.+]] = builtin.unrealized_conversion_cast %arg1 : memref<4xf64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
115+
// CHECK: [[castArg0:%.+]] = builtin.unrealized_conversion_cast %arg0 : memref<f64> to !llvm.struct<(ptr, ptr, i64)>
116116
// CHECK: [[extract0:%.+]] = llvm.extractvalue [[castArg0]][1] : !llvm.struct<(ptr, ptr, i64)>
117117
// CHECK: [[load0:%.+]] = llvm.load [[extract0]] {tbaa = [[[tagdouble]]]} : !llvm.ptr -> f64
118118
// CHECK: [[idx:%.+]] = index.constant 0

mlir/test/Gradient/ConversionTest.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,8 @@ module @test0 {
6262
// CHECK: [[in0struct:%.+]] = llvm.load [[in0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>
6363
// CHECK: [[in0memref:%.+]] = builtin.unrealized_conversion_cast [[in0struct]] : !llvm.struct<(ptr, ptr, i64)> to memref<f64>
6464
// CHECK: [[diff0struct:%.+]] = llvm.load [[diff0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>
65-
// CHECK: [[diff0memref:%.+]] = builtin.unrealized_conversion_cast [[diff0struct]] : !llvm.struct<(ptr, ptr, i64)> to memref<f64>
6665
// CHECK: [[out0struct:%.+]] = llvm.load [[out0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>
67-
// CHECK: [[out0memref:%.+]] = builtin.unrealized_conversion_cast [[out0struct]] : !llvm.struct<(ptr, ptr, i64)> to memref<f64>
6866
// CHECK: [[cotan0struct:%.+]] = llvm.load [[cotan0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>
69-
// CHECK: [[cotan0memref:%.+]] = builtin.unrealized_conversion_cast [[cotan0struct]] : !llvm.struct<(ptr, ptr, i64)> to memref<f64>
7067
// CHECK: [[results:%.+]]:2 = call @fwd([[in0memref]])
7168

7269
%1:2 = func.call @fwd(%arg0) : (memref<f64>) -> (memref<f64>, memref<f64>)
@@ -117,9 +114,7 @@ module @test1 {
117114
gradient.reverse @rev.rev(%arg0: memref<f64>, %arg1: memref<f64>, %arg2: memref<f64>, %arg3: memref<f64>, %arg4: memref<f64>) attributes {argc = 1 : i64, implementation = @rev, resc = 1 : i64, tape = 1 : i64} {
118115

119116
// CHECK: [[in0struct:%.+]] = llvm.load [[in0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>
120-
// CHECK: [[in0memref:%.+]] = builtin.unrealized_conversion_cast [[in0struct]] : !llvm.struct<(ptr, ptr, i64)> to memref<f64>
121117
// CHECK: [[diff0struct:%.+]] = llvm.load [[diff0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>
122-
// CHECK: [[diff0memref:%.+]] = builtin.unrealized_conversion_cast [[diff0struct]] : !llvm.struct<(ptr, ptr, i64)> to memref<f64>
123118
// CHECK: [[out0struct:%.+]] = llvm.load [[out0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>
124119
// CHECK: [[out0memref:%.+]] = builtin.unrealized_conversion_cast [[out0struct]] : !llvm.struct<(ptr, ptr, i64)> to memref<f64>
125120
// CHECK: [[cotan0struct:%.+]] = llvm.load [[cotan0ptr]] : !llvm.ptr -> !llvm.struct<(ptr, ptr, i64)>

mlir/test/Quantum/ConversionTest.mlir

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,9 @@ module @custom_gate {
221221
// CHECK: llvm.func @__catalyst__qis__RX(f64, !llvm.ptr, !llvm.ptr)
222222
// CHECK-LABEL: @test
223223
func.func @test(%q0: !quantum.bit, %p: f64) -> () {
224-
// CHECK: [[nullptr:%.+]] = llvm.mlir.zero
225224
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
226225
// CHECK: [[alloca:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(i1, i64, ptr, ptr)>
226+
// CHECK: [[nullptr:%.+]] = llvm.mlir.zero
227227
// CHECK: [[true:%.+]] = llvm.mlir.constant(true)
228228
// CHECK: [[off0:%.+]] = llvm.getelementptr inbounds [[alloca]][0, 0]
229229
// CHECK: [[off1:%.+]] = llvm.getelementptr inbounds [[alloca]][0, 1]
@@ -382,6 +382,8 @@ func.func @hamiltonian(%obs : !quantum.obs, %p1 : memref<1xf64>, %p2 : memref<3x
382382
// CHECK: [[memrefvar2:%.+]] = llvm.insertvalue %arg3, [[memrefvar1]][2]
383383
// CHECK: [[memrefvar3:%.+]] = llvm.insertvalue %arg4, [[memrefvar2]][3, 0]
384384
// CHECK: [[memrefvar4:%.+]] = llvm.insertvalue %arg5, [[memrefvar3]][4, 0]
385+
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast [[memrefvar4]]
386+
// CHECK: [[memrefvar4:%.+]] = builtin.unrealized_conversion_cast [[cast]]
385387
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
386388
// CHECK: llvm.store [[memrefvar4]], [[alloca]]
387389
// CHECK: llvm.call @__catalyst__qis__HamiltonianObs([[alloca]], [[c1]], %arg0)
@@ -401,8 +403,10 @@ func.func @hamiltonian(%obs : !quantum.obs, %p1 : memref<1xf64>, %p2 : memref<3x
401403
// CHECK: [[memrefvar2:%.+]] = llvm.insertvalue %arg8, [[memrefvar1]][2]
402404
// CHECK: [[memrefvar3:%.+]] = llvm.insertvalue %arg9, [[memrefvar2]][3, 0]
403405
// CHECK: [[memrefvar4:%.+]] = llvm.insertvalue %arg10, [[memrefvar3]][4, 0]
406+
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast [[memrefvar4]]
404407
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
405408
// CHECK: [[alloca:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
409+
// CHECK: [[memrefvar4:%.+]] = builtin.unrealized_conversion_cast [[cast]]
406410
// CHECK: [[c3:%.+]] = llvm.mlir.constant(3 : i64)
407411
// CHECK: llvm.store [[memrefvar4]], [[alloca]]
408412
// CHECK: llvm.call @__catalyst__qis__HamiltonianObs([[alloca]], [[c3]], %arg0, %arg0, %arg0)
@@ -580,12 +584,11 @@ func.func @probs(%q : !quantum.bit) {
580584

581585
// CHECK-LABEL: @state
582586
func.func @state(%q : !quantum.bit) {
583-
// CHECK: [[qb:%.+]] = builtin.unrealized_conversion_cast %arg0
584-
585587
%o1 = quantum.compbasis qubits %q : !quantum.obs
586588

587589
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
588590
// CHECK: [[ptr:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
591+
// CHECK: [[qb:%.+]] = builtin.unrealized_conversion_cast %arg0
589592
// CHECK: [[c0:%.+]] = llvm.mlir.constant(0 : i64)
590593
// CHECK: llvm.call @__catalyst__qis__State([[ptr]], [[c0]])
591594
%alloc1 = memref.alloc() : memref<2xcomplex<f64>>
@@ -618,13 +621,13 @@ func.func @controlled_circuit(%1 : !quantum.bit, %2 : !quantum.bit, %3 : !quantu
618621
%cst_0 = llvm.mlir.constant (9.000000e-01 : f64) : f64
619622
%cst_1 = llvm.mlir.constant (3.000000e-01 : f64) : f64
620623

621-
// CHECK: [[true:%.+]] = llvm.mlir.constant(true)
622624
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
623625
// CHECK: [[alloca0:%.+]] = llvm.alloca [[c1]] x i1
624626
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
625627
// CHECK: [[alloca1:%.+]] = llvm.alloca [[c1]] x !llvm.ptr
626628
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
627629
// CHECK: [[mod:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(i1, i64, ptr, ptr)>
630+
// CHECK: [[true:%.+]] = llvm.mlir.constant(true)
628631

629632

630633
// CHECK-DAG: [[cst6:%.+]] = llvm.mlir.constant(6.0
@@ -658,13 +661,13 @@ func.func @controlled_circuit(%1 : !quantum.bit, %2 : !quantum.bit, %3 : !quantu
658661
%cst = llvm.mlir.constant (6.000000e-01 : f64) : f64
659662
%true = llvm.mlir.constant (1 : i1) :i1
660663

661-
// CHECK-DAG: [[cst6:%.+]] = llvm.mlir.constant(6.0
662664
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
663665
// CHECK: [[alloca0:%.+]] = llvm.alloca [[c1]] x i1
664666
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
665667
// CHECK: [[alloca1:%.+]] = llvm.alloca [[c1]] x !llvm.ptr
666668
// CHECK: [[c1:%.+]] = llvm.mlir.constant(1 : i64)
667669
// CHECK: [[mod:%.+]] = llvm.alloca [[c1]] x !llvm.struct<(i1, i64, ptr, ptr)>
670+
// CHECK: [[cst6:%.+]] = llvm.mlir.constant(6.0
668671
// CHECK: [[true:%.+]] = llvm.mlir.constant(true)
669672

670673
// CHECK: [[offset0:%.+]] = llvm.getelementptr inbounds [[mod]][0, 0]

0 commit comments

Comments
 (0)