|
5 | 5 | from ._arith_ops_gen import *
|
6 | 6 | from ._arith_ops_gen import _Dialect
|
7 | 7 | from ._arith_enum_gen import *
|
| 8 | +from array import array as _array |
| 9 | +from typing import overload |
8 | 10 |
|
9 | 11 | try:
|
10 | 12 | from ..ir import *
|
@@ -43,13 +45,37 @@ def _is_float_type(type: Type):
|
43 | 45 | class ConstantOp(ConstantOp):
|
44 | 46 | """Specialization for the constant op class."""
|
45 | 47 |
|
| 48 | + @overload |
| 49 | + def __init__(self, value: Attribute, *, loc=None, ip=None): |
| 50 | + ... |
| 51 | + |
| 52 | + @overload |
46 | 53 | def __init__(
|
47 |
| - self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None |
| 54 | + self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None |
48 | 55 | ):
|
| 56 | + ... |
| 57 | + |
| 58 | + def __init__(self, result, value, *, loc=None, ip=None): |
| 59 | + if value is None: |
| 60 | + assert isinstance(result, Attribute) |
| 61 | + super().__init__(result, loc=loc, ip=ip) |
| 62 | + return |
| 63 | + |
49 | 64 | if isinstance(value, int):
|
50 | 65 | super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
|
51 | 66 | elif isinstance(value, float):
|
52 | 67 | super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
|
| 68 | + elif isinstance(value, _array): |
| 69 | + if 8 * value.itemsize != result.element_type.width: |
| 70 | + raise ValueError( |
| 71 | + f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width." |
| 72 | + ) |
| 73 | + if value.typecode in ["i", "l", "q"]: |
| 74 | + super().__init__(DenseIntElementsAttr.get(value, type=result)) |
| 75 | + elif value.typecode in ["f", "d"]: |
| 76 | + super().__init__(DenseFPElementsAttr.get(value, type=result)) |
| 77 | + else: |
| 78 | + raise ValueError(f'Unsupported typecode: "{value.typecode}".') |
53 | 79 | else:
|
54 | 80 | super().__init__(value, loc=loc, ip=ip)
|
55 | 81 |
|
@@ -79,6 +105,6 @@ def literal_value(self) -> Union[int, float]:
|
79 | 105 |
|
80 | 106 |
|
81 | 107 | def constant(
|
82 |
| - result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None |
| 108 | + result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None |
83 | 109 | ) -> Value:
|
84 | 110 | return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
|
0 commit comments