Skip to content

Commit d020d8e

Browse files
authored
fix scf afer upstream ForAll (#171)
1 parent b6729e9 commit d020d8e

File tree

3 files changed

+7
-69
lines changed

3 files changed

+7
-69
lines changed

mlir/extras/dialects/ext/scf.py

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -108,68 +108,6 @@ def placeholder_opaque_t():
108108
for__ = region_op(_build_for, terminator=yield__)
109109

110110

111-
@_cext.register_operation(_Dialect, replace=True)
112-
class ForallOp(ForallOp):
113-
def __init__(
114-
self,
115-
lower_bounds,
116-
upper_bounds,
117-
steps,
118-
shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
119-
*,
120-
device_mapping: Optional[List[Attribute]] = None,
121-
loc=None,
122-
ip=None,
123-
):
124-
assert len(lower_bounds) == len(upper_bounds) == len(steps)
125-
if shared_outs is not None:
126-
results = [o.type for o in shared_outs]
127-
else:
128-
results = shared_outs = []
129-
iv_types = [IndexType.get()] * len(lower_bounds)
130-
context = get_default_loc_context(loc)
131-
mapping = None
132-
if device_mapping is not None:
133-
mapping = get_device_mapping_array_attr(device_mapping)
134-
135-
super().__init__(
136-
results_=results,
137-
dynamicLowerBound=[],
138-
dynamicUpperBound=[],
139-
dynamicStep=[],
140-
staticLowerBound=_denseI64ArrayAttr(lower_bounds, context),
141-
staticUpperBound=_denseI64ArrayAttr(upper_bounds, context),
142-
staticStep=_denseI64ArrayAttr(steps, context),
143-
outputs=shared_outs,
144-
mapping=mapping,
145-
loc=loc,
146-
ip=ip,
147-
)
148-
self.regions[0].blocks.append(*iv_types, *results)
149-
150-
@property
151-
def body(self):
152-
"""Returns the body (block) of the loop."""
153-
return self.regions[0].blocks[0]
154-
155-
@property
156-
def arguments(self):
157-
"""Returns the induction variable of the loop."""
158-
return self.body.arguments
159-
160-
161-
@_cext.register_operation(_Dialect, replace=True)
162-
class InParallelOp(InParallelOp):
163-
def __init__(self, *, loc=None, ip=None):
164-
super().__init__(loc=loc, ip=ip)
165-
self.regions[0].blocks.append(*[])
166-
167-
@property
168-
def body(self):
169-
"""Returns the body (block) of the loop."""
170-
return self.regions[0].blocks[0]
171-
172-
173111
def _parfor(op_ctor):
174112
def _base(
175113
lower_bounds, upper_bounds=None, steps=None, *, loc=None, ip=None, **kwargs

tests/test_gpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_forall_insert_slice_no_region_with_for_with_gpu_mapping(ctx: MLIRContex
8181
[1, 1],
8282
[2, 2],
8383
[3, 3],
84-
device_mapping=[thread("x"), thread("y")],
84+
mapping=[thread("x"), thread("y")],
8585
):
8686
a = memref.load(x, (i, j))
8787
b = memref.load(y, (i, j))
@@ -101,7 +101,7 @@ def test_forall_insert_slice_no_region_with_for_with_gpu_mapping(ctx: MLIRContex
101101
# CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
102102
# CHECK: %[[VAL_7:.*]] = arith.constant 3 : index
103103
# CHECK: %[[VAL_8:.*]] = arith.constant 3 : index
104-
# CHECK: scf.forall (%[[VAL_9:.*]], %[[VAL_10:.*]]) = (1, 1) to (2, 2) step (3, 3) {
104+
# CHECK: scf.forall (%[[VAL_9:.*]], %[[VAL_10:.*]]) = (%[[VAL_3]], %[[VAL_4]]) to (%[[VAL_5]], %[[VAL_6]]) step (%[[VAL_7]], %[[VAL_8]]) {
105105
# CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<10x10xf32>
106106
# CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<10x10xf32>
107107
# CHECK: %[[VAL_13:.*]] = math.fma %[[VAL_2]], %[[VAL_11]], %[[VAL_12]] : f32

tests/test_scf.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,7 +2133,7 @@ def forfoo(ivs):
21332133
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
21342134
# CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
21352135
# CHECK: %[[VAL_2:.*]] = arith.constant 3 : index
2136-
# CHECK: scf.forall (%[[VAL_3:.*]]) = (1) to (2) step (3) {
2136+
# CHECK: scf.forall (%[[VAL_3:.*]]) = (%[[VAL_0]]) to (%[[VAL_1]]) step (%[[VAL_2]]) {
21372137
# CHECK: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
21382138
# CHECK: }
21392139

@@ -2153,7 +2153,7 @@ def forfoo(iv1, iv2):
21532153
# CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
21542154
# CHECK: %[[VAL_4:.*]] = arith.constant 3 : index
21552155
# CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
2156-
# CHECK: scf.forall (%[[VAL_6:.*]], %[[VAL_7:.*]]) = (1, 1) to (2, 2) step (3, 3) {
2156+
# CHECK: scf.forall (%[[VAL_6:.*]], %[[VAL_7:.*]]) = (%[[VAL_0]], %[[VAL_1]]) to (%[[VAL_2]], %[[VAL_3]]) step (%[[VAL_4]], %[[VAL_5]]) {
21572157
# CHECK: %[[VAL_8:.*]] = arith.constant 1.000000e+00 : f32
21582158
# CHECK: }
21592159

@@ -2184,7 +2184,7 @@ def forfoo(i, j, shared_outs):
21842184
# CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
21852185
# CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
21862186
# CHECK: %[[VAL_6:.*]] = arith.constant 3 : index
2187-
# CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (1, 1) to (2, 2) step (3, 3) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
2187+
# CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
21882188
# CHECK: %[[VAL_11:.*]] = arith.constant 1.000000e+00 : f32
21892189
# CHECK: scf.forall.in_parallel {
21902190
# CHECK: tensor.parallel_insert_slice %[[VAL_0]] into %[[VAL_10]]{{\[}}%[[VAL_8]], %[[VAL_9]]] [10, 10] [1, 1] : tensor<10x10xi32> into tensor<10x10xi32>
@@ -2218,7 +2218,7 @@ def forfoo(i, j, shared_outs):
22182218
# CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
22192219
# CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
22202220
# CHECK: %[[VAL_6:.*]] = arith.constant 3 : index
2221-
# CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (1, 1) to (2, 2) step (3, 3) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
2221+
# CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
22222222
# CHECK: %[[VAL_11:.*]] = arith.constant 1.000000e+00 : f32
22232223
# CHECK: scf.forall.in_parallel {
22242224
# CHECK: tensor.parallel_insert_slice %[[VAL_0]] into %[[VAL_10]]{{\[}}%[[VAL_8]], %[[VAL_9]]] [10, 10] [1, 1] : tensor<10x10xi32> into tensor<10x10xi32>
@@ -2280,7 +2280,7 @@ def test_forall_insert_slice_no_region_with_for(ctx: MLIRContext):
22802280
# CHECK: %[[VAL_4:.*]] = arith.constant 2 : index
22812281
# CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
22822282
# CHECK: %[[VAL_6:.*]] = arith.constant 3 : index
2283-
# CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (1, 1) to (2, 2) step (3, 3) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
2283+
# CHECK: %[[VAL_7:.*]] = scf.forall (%[[VAL_8:.*]], %[[VAL_9:.*]]) = (%[[VAL_1]], %[[VAL_2]]) to (%[[VAL_3]], %[[VAL_4]]) step (%[[VAL_5]], %[[VAL_6]]) shared_outs(%[[VAL_10:.*]] = %[[VAL_0]]) -> (tensor<10x10xi32>) {
22842284
# CHECK: %[[VAL_11:.*]] = arith.constant 1.000000e+00 : f32
22852285
# CHECK: scf.forall.in_parallel {
22862286
# CHECK: tensor.parallel_insert_slice %[[VAL_0]] into %[[VAL_10]]{{\[}}%[[VAL_8]], %[[VAL_9]]] [10, 10] [1, 1] : tensor<10x10xi32> into tensor<10x10xi32>

0 commit comments

Comments
 (0)