Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ Practically speaking that means you need to have *some* package installed that i
So

```shell
$ YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> pip install git+https://github.com/makslevental/mlir-python-extras
$ HOST_MLIR_PYTHON_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> pip install git+https://github.com/makslevental/mlir-python-extras
```

where `YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the package prefix for your chosen host bindings.
Expand Down
22 changes: 16 additions & 6 deletions mlir/extras/dialects/ext/linalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from . import arith
from ...util import get_user_code_loc

from ....dialects import linalg
# noinspection PyUnresolvedReferences
from ....dialects.linalg import *
from ....dialects import linalg
from ....extras import types as T


def abs(I, O, *, loc=None, ip=None):
Expand Down Expand Up @@ -263,16 +264,25 @@ def exp(I, O, *, loc=None, ip=None):
return linalg.exp(I, loc=loc, ip=ip, outs=[O])


def fill(O, *, loc=None, ip=None):
def fill(v, O, *, loc=None, ip=None):
if isinstance(v, (float, int, bool)):
v = arith.constant(v)
if loc is None:
loc = get_user_code_loc()
return linalg.fill(loc=loc, ip=ip, outs=[O])
return linalg.fill(v, loc=loc, ip=ip, outs=[O])


def fill_rng_2d(O, *, loc=None, ip=None):
def fill_rng_2d(min, max, seed, O, *, loc=None, ip=None):
params = [min, max]
for i, m in enumerate(params):
if isinstance(m, (float, int)):
params[i] = arith.constant(m, type=T.f64())
min, max = params
if isinstance(seed, int):
seed = arith.constant(seed, T.i32())
if loc is None:
loc = get_user_code_loc()
return linalg.fill_rng_2d(loc=loc, ip=ip, outs=[O])
return linalg.fill_rng_2d(min, max, seed, loc=loc, ip=ip, outs=[O])


def floor(I, O, *, loc=None, ip=None):
Expand Down
44 changes: 44 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from textwrap import dedent

import pytest

import mlir.extras.types as T
from mlir.extras.dialects.ext import linalg, memref, tensor

# noinspection PyUnresolvedReferences
from mlir.extras.testing import MLIRContext, filecheck, mlir_ctx as ctx

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")


def test_np_constructor(ctx: MLIRContext):
x = memref.alloc(10, 10, T.i32())
linalg.fill(5, x)
linalg.fill_rng_2d(0.0, 10.0, 1, x)

x = tensor.empty(10, 10, T.i32())
y = linalg.fill_rng_2d(0.0, 10.0, 1, x)
z = linalg.fill(5, x)

correct = dedent(
"""\
module {
%alloc = memref.alloc() : memref<10x10xi32>
%c5_i32 = arith.constant 5 : i32
linalg.fill ins(%c5_i32 : i32) outs(%alloc : memref<10x10xi32>)
%cst = arith.constant 0.000000e+00 : f64
%cst_0 = arith.constant 1.000000e+01 : f64
%c1_i32 = arith.constant 1 : i32
linalg.fill_rng_2d ins(%cst, %cst_0, %c1_i32 : f64, f64, i32) outs(%alloc : memref<10x10xi32>)
%0 = tensor.empty() : tensor<10x10xi32>
%cst_1 = arith.constant 0.000000e+00 : f64
%cst_2 = arith.constant 1.000000e+01 : f64
%c1_i32_3 = arith.constant 1 : i32
%1 = linalg.fill_rng_2d ins(%cst_1, %cst_2, %c1_i32_3 : f64, f64, i32) outs(%0 : tensor<10x10xi32>) -> tensor<10x10xi32>
%c5_i32_4 = arith.constant 5 : i32
%2 = linalg.fill ins(%c5_i32_4 : i32) outs(%0 : tensor<10x10xi32>) -> tensor<10x10xi32>
}
"""
)
filecheck(correct, ctx.module)