Skip to content

Commit 5d59fa9

Browse files
authored
Reapply "[mlir][py] better support for arith.constant construction" (#84142)
Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as `array.array`. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly. Reverts #84103
1 parent 8aed911 commit 5d59fa9

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

mlir/python/mlir/dialects/arith.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from ._arith_ops_gen import *
66
from ._arith_ops_gen import _Dialect
77
from ._arith_enum_gen import *
8+
from array import array as _array
9+
from typing import overload
810

911
try:
1012
from ..ir import *
@@ -43,13 +45,37 @@ def _is_float_type(type: Type):
4345
class ConstantOp(ConstantOp):
4446
"""Specialization for the constant op class."""
4547

48+
@overload
49+
def __init__(self, value: Attribute, *, loc=None, ip=None):
50+
...
51+
52+
@overload
4653
def __init__(
47-
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
54+
self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
4855
):
56+
...
57+
58+
def __init__(self, result, value, *, loc=None, ip=None):
59+
if value is None:
60+
assert isinstance(result, Attribute)
61+
super().__init__(result, loc=loc, ip=ip)
62+
return
63+
4964
if isinstance(value, int):
5065
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
5166
elif isinstance(value, float):
5267
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
68+
elif isinstance(value, _array):
69+
if 8 * value.itemsize != result.element_type.width:
70+
raise ValueError(
71+
f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width."
72+
)
73+
if value.typecode in ["i", "l", "q"]:
74+
super().__init__(DenseIntElementsAttr.get(value, type=result))
75+
elif value.typecode in ["f", "d"]:
76+
super().__init__(DenseFPElementsAttr.get(value, type=result))
77+
else:
78+
raise ValueError(f'Unsupported typecode: "{value.typecode}".')
5379
else:
5480
super().__init__(value, loc=loc, ip=ip)
5581

@@ -79,6 +105,6 @@ def literal_value(self) -> Union[int, float]:
79105

80106

81107
def constant(
82-
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
108+
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
83109
) -> Value:
84110
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))

mlir/test/python/dialects/arith_dialect.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from mlir.ir import *
55
import mlir.dialects.arith as arith
66
import mlir.dialects.func as func
7+
from array import array
78

89

910
def run(f):
@@ -92,3 +93,42 @@ def __str__(self):
9293
b = a * a
9394
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
9495
print(b)
96+
97+
98+
# CHECK-LABEL: TEST: testArrayConstantConstruction
99+
@run
100+
def testArrayConstantConstruction():
101+
with Context(), Location.unknown():
102+
module = Module.create()
103+
with InsertionPoint(module.body):
104+
i32_array = array("i", [1, 2, 3, 4])
105+
i32 = IntegerType.get_signless(32)
106+
vec_i32 = VectorType.get([2, 2], i32)
107+
arith.constant(vec_i32, i32_array)
108+
arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))
109+
110+
# "q" is the equivalent of `long long` in C and requires at least
111+
# 64 bit width integers on both Linux and Windows.
112+
i64_array = array("q", [5, 6, 7, 8])
113+
i64 = IntegerType.get_signless(64)
114+
vec_i64 = VectorType.get([1, 4], i64)
115+
arith.constant(vec_i64, i64_array)
116+
arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))
117+
118+
f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
119+
f32 = F32Type.get()
120+
vec_f32 = VectorType.get([4, 1], f32)
121+
arith.constant(vec_f32, f32_array)
122+
arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))
123+
124+
f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
125+
f64 = F64Type.get()
126+
vec_f64 = VectorType.get([2, 1, 2], f64)
127+
arith.constant(vec_f64, f64_array)
128+
arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))
129+
130+
# CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
131+
# CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
132+
# CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
133+
# CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
134+
print(module)

0 commit comments

Comments
 (0)