Skip to content

Commit a974667

Browse files
authored
[MLIR][Python] Add encoding argument to tensor.empty Python function (#110656)
Hi @xurui1995 @makslevental, I think in #103087 there's unintended regression where user can no longer create sparse tensors with `tensor.empty`. Previously I could pass: ```python out = tensor.empty(tensor_type, []) ``` where `tensor_type` contained `shape`, `dtype`, and `encoding`. With the latest ```python tensor.empty(sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None) ``` it's no longer possible. I propose to add `encoding` argument which is passed to `RankedTensorType.get(static_sizes, element_type, encoding)` (I updated one of the tests to check it).
1 parent f3c408d commit a974667

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

mlir/python/mlir/dialects/tensor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from typing import Optional
45

56
from ._tensor_ops_gen import *
67
from ._tensor_ops_gen import _Dialect
@@ -25,6 +26,7 @@ def __init__(
2526
sizes: Sequence[Union[int, Value]],
2627
element_type: Type,
2728
*,
29+
encoding: Optional[Attribute] = None,
2830
loc=None,
2931
ip=None,
3032
):
@@ -40,19 +42,22 @@ def __init__(
4042
else:
4143
static_sizes.append(ShapedType.get_dynamic_size())
4244
dynamic_sizes.append(s)
43-
result_type = RankedTensorType.get(static_sizes, element_type)
45+
result_type = RankedTensorType.get(static_sizes, element_type, encoding)
4446
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
4547

4648

4749
def empty(
4850
sizes: Sequence[Union[int, Value]],
4951
element_type: Type,
5052
*,
53+
encoding: Optional[Attribute] = None,
5154
loc=None,
5255
ip=None,
5356
) -> _ods_cext.ir.Value:
5457
return _get_op_result_or_op_results(
55-
EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)
58+
EmptyOp(
59+
sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip
60+
)
5661
)
5762

5863

mlir/test/python/dialects/sparse_tensor/dialect.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4-
from mlir.dialects import sparse_tensor as st
4+
from mlir.dialects import sparse_tensor as st, tensor
55
import textwrap
66

77

@@ -225,3 +225,21 @@ def testEncodingAttrOnTensorType():
225225
# CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }>
226226
print(tt.encoding)
227227
assert tt.encoding == encoding
228+
229+
230+
# CHECK-LABEL: TEST: testEncodingEmptyTensor
231+
@run
232+
def testEncodingEmptyTensor():
233+
with Context(), Location.unknown():
234+
module = Module.create()
235+
with InsertionPoint(module.body):
236+
levels = [st.LevelFormat.compressed]
237+
ordering = AffineMap.get_permutation([0])
238+
encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32)
239+
tensor.empty((1024,), F32Type.get(), encoding=encoding)
240+
241+
# CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
242+
# CHECK: module {
243+
# CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse>
244+
# CHECK: }
245+
print(module)

0 commit comments

Comments
 (0)