diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 0b30d10209908..146b5f85d07f5 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -1,6 +1,7 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional from ._tensor_ops_gen import * from ._tensor_ops_gen import _Dialect @@ -25,6 +26,7 @@ def __init__( sizes: Sequence[Union[int, Value]], element_type: Type, *, + encoding: Optional[Attribute] = None, loc=None, ip=None, ): @@ -40,7 +42,7 @@ def __init__( else: static_sizes.append(ShapedType.get_dynamic_size()) dynamic_sizes.append(s) - result_type = RankedTensorType.get(static_sizes, element_type) + result_type = RankedTensorType.get(static_sizes, element_type, encoding) super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) @@ -48,11 +50,14 @@ def empty( sizes: Sequence[Union[int, Value]], element_type: Type, *, + encoding: Optional[Attribute] = None, loc=None, ip=None, ) -> _ods_cext.ir.Value: return _get_op_result_or_op_results( - EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip) + EmptyOp( + sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip + ) ) diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py index 3cc4575eb3e24..656979f3d9a1d 100644 --- a/mlir/test/python/dialects/sparse_tensor/dialect.py +++ b/mlir/test/python/dialects/sparse_tensor/dialect.py @@ -1,7 +1,7 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -from mlir.dialects import sparse_tensor as st +from mlir.dialects import sparse_tensor as st, tensor import textwrap @@ -225,3 +225,21 @@ def testEncodingAttrOnTensorType(): # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 64, crdWidth = 32 }> print(tt.encoding) assert tt.encoding == encoding + + +# CHECK-LABEL: TEST: testEncodingEmptyTensor +@run +def testEncodingEmptyTensor(): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + levels = [st.LevelFormat.compressed] + ordering = AffineMap.get_permutation([0]) + encoding = st.EncodingAttr.get(levels, ordering, ordering, 32, 32) + tensor.empty((1024,), F32Type.get(), encoding=encoding) + + # CHECK: #sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }> + # CHECK: module { + # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1024xf32, #sparse> + # CHECK: } + print(module)