Merged
685 changes: 650 additions & 35 deletions examples/cuda_matmul_opt.py

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions examples/mlir_python_extras.ipynb
@@ -56,7 +56,7 @@
"from mlir.extras.dialects.ext.arith import constant\n",
"from mlir.extras.dialects.ext.memref import S\n",
"from mlir.extras.dialects.ext.func import func\n",
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range\n",
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_\n",
"from mlir.extras.runtime.passes import Pipeline, run_pipeline\n",
"from mlir.extras.runtime.refbackend import LLVMJITBackend\n",
"from mlir.ir import StridedLayoutAttr\n",
@@ -102,8 +102,8 @@
" if one > two:\n",
" C[0, 0] = constant(3, T.i64())\n",
" else:\n",
" for i in range(0, K):\n",
" for j in range(0, K):\n",
" for i in range_(0, K):\n",
" for j in range_(0, K):\n",
" C[i, j] = A[i, j] * B[i, j]"
]
},
@@ -457,17 +457,17 @@
"def tile(\n",
" A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32\n",
"):\n",
" for i in range(0, D):\n",
" for j in range(0, D):\n",
" for i in range_(0, D):\n",
" for j in range_(0, D):\n",
" C[i, j] = A[i, j] + B[i, j]\n",
"\n",
"@func(emit=True)\n",
"@canonicalize(using=scf)\n",
"def tiled_memfoo(\n",
" A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n",
"):\n",
" for i in range(0, F):\n",
" for j in range(0, F):\n",
" for i in range_(0, F):\n",
" for j in range_(0, F):\n",
" l = lambda l: l * D\n",
" r = lambda r: (r + 1) * D\n",
" a, b, c = (\n",
@@ -797,8 +797,8 @@
"def linalg_memfoo(\n",
" A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n",
"):\n",
" for i in range(0, F):\n",
" for j in range(0, F):\n",
" for i in range_(0, F):\n",
" for j in range_(0, F):\n",
" l = lambda l: l * D\n",
" r = lambda r: (r + 1) * D\n",
" a, b, c = (\n",
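The notebook edits above all make the same fix: loop bounds now go through the dialect's range_ helper instead of importing range_ as range and shadowing the builtin. A minimal sketch of the resulting pattern, assuming the notebook's D and ranked_memref_dxd_f32 definitions are in scope and that the canonicalize decorator lives in mlir.extras.ast.canonicalize (the kernel name here is made up):

from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.dialects.ext.func import func
from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_

@func(emit=True)
@canonicalize(using=scf)
def elementwise_add(
    A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32
):
    for i in range_(0, D):          # emitted as scf.for
        for j in range_(0, D):
            C[i, j] = A[i, j] + B[i, j]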
2 changes: 1 addition & 1 deletion mlir/extras/ast/canonicalize.py
@@ -116,7 +116,7 @@ def transform_ast(
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
> n_lines
) or (f.__code__.co_firstlineno != min([l for _, l in line_starts])):
warnings.warn(
logger.debug(
"something went wrong with the line numbers for the rewritten/canonicalized function"
)
f.__code__ = new_f_code_o
3 changes: 3 additions & 0 deletions mlir/extras/ast/util.py
@@ -143,6 +143,9 @@ def copy_func(f, new_closure: Dict = None):

def append_hidden_node(node_body, new_node):
last_statement = node_body[-1]
assert (
last_statement.end_lineno is not None
), f"last_statement {ast.unparse(last_statement)} must have end_lineno"
new_node = ast.fix_missing_locations(
set_lineno(new_node, last_statement.end_lineno)
)
6 changes: 5 additions & 1 deletion mlir/extras/dialects/ext/arith.py
@@ -1,4 +1,5 @@
import ast
import copy
import operator
from abc import abstractmethod
from copy import deepcopy
@@ -513,17 +514,20 @@ def visit_AugAssign(
and isinstance(updated_node.value, ast.BinOp)
and isinstance(updated_node.value.op, ast.Mult)
):
target = copy.deepcopy(updated_node.target)
target.ctx = ast.Load()
updated_node = ast.Assign(
targets=[updated_node.target],
value=ast_call(
_FMA_BUILDER_NAME,
[
updated_node.value.left,
updated_node.value.right,
ast.Name(updated_node.target.id, ast.Load()),
target,
],
),
)
updated_node = ast.fix_missing_locations(updated_node)

return updated_node

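The second arith.py hunk hardens the a += b * c rewrite: the assignment target is deep-copied and given a Load context before being reused as the accumulator operand, and the rebuilt node gets its source locations filled in. A standalone sketch of the same AST pattern, with a hypothetical fma name standing in for the library's internal _FMA_BUILDER_NAME:

import ast
import copy

class FMARewriter(ast.NodeTransformer):
    def visit_AugAssign(self, node):
        # Only rewrite "x += a * b"; everything else passes through untouched.
        if (
            isinstance(node.op, ast.Add)
            and isinstance(node.value, ast.BinOp)
            and isinstance(node.value.op, ast.Mult)
        ):
            target = copy.deepcopy(node.target)
            target.ctx = ast.Load()  # the copy is read, the original target is written
            new_node = ast.Assign(
                targets=[node.target],
                value=ast.Call(
                    func=ast.Name("fma", ast.Load()),
                    args=[node.value.left, node.value.right, target],
                    keywords=[],
                ),
            )
            return ast.fix_missing_locations(ast.copy_location(new_node, node))
        return node

tree = FMARewriter().visit(ast.parse("C[i, j] += A[i, k] * B[k, j]"))
print(ast.unparse(tree))  # C[i, j] = fma(A[i, k], B[k, j], C[i, j])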
45 changes: 35 additions & 10 deletions mlir/extras/dialects/ext/gpu.py
@@ -2,6 +2,8 @@
from functools import partial
from typing import Any, List, Optional, Tuple, Union

from mlir.dialects._gpu_enum_gen import AddressSpace

from .arith import constant
from .func import FuncBase
from ... import types as T
@@ -117,32 +119,39 @@ def get_device_mapping_array_attr(
return ArrayAttr.get(mapping, context=context)


def device_mapping_attr(mnemonic, mapping_id_enum: MappingId):
def gpu_attr(mnemonic, mapping_id_enum: MappingId):
return Attribute.parse(f"#gpu.{mnemonic}<{mapping_id_enum}>")


def thread_attr(thread):
return device_mapping_attr("thread", thread)
return gpu_attr("thread", thread)


def block_attr(block):
return device_mapping_attr("block", block)
return gpu_attr("block", block)


def warp_attr(warp):
return device_mapping_attr("warp", warp)
return gpu_attr("warp", warp)


def warpgroup_attr(warpgroup):
return device_mapping_attr("warpgroup", warpgroup)
return gpu_attr("warpgroup", warpgroup)


def address_space_attr(address_space: AddressSpace):
return device_mapping_attr("address_space", address_space)
return gpu_attr("address_space", address_space)


_int = int


def smem_space(int=False):
a = AddressSpace.Workgroup
if int:
return _int(a)

def smem_space():
return address_space_attr(AddressSpace.Workgroup)
return address_space_attr(a)


@_cext.register_operation(_Dialect, replace=True)
@@ -577,13 +586,29 @@ def printf(format, *args):
_dynamic_shared_memory = dynamic_shared_memory


def dynamic_shared_memory(*, loc=None, ip=None):
def dynamic_shared_memory(*, int=False, loc=None, ip=None):
return _dynamic_shared_memory(
T.memref(
ShapedType.get_dynamic_size(),
element_type=T.i8(),
memory_space=smem_space(),
memory_space=smem_space(int),
),
loc=loc,
ip=ip,
)


_memset = memset


def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
if async_dependencies is None:
async_dependencies = []
async_token = None
if len(async_dependencies):
async_token = gpu_async_token()
if isinstance(value, (int, float, bool)):
value = constant(value, type=dst.type.element_type)
return _memset(async_token, async_dependencies, dst, value, loc=loc, ip=ip)
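Taken together, the gpu.py changes rename device_mapping_attr to gpu_attr, add an integer escape hatch for the workgroup (shared-memory) address space, and add a thin gpu.memset wrapper that promotes Python scalars to constants. A rough usage sketch, assuming an active MLIR context and insertion point as the repo's examples set up (T.f32 and the alloc spelling are assumptions, not shown in this diff):

from mlir.extras.dialects.ext import gpu, memref
from mlir.extras import types as T

smem_attr = gpu.smem_space()                 # #gpu.address_space<workgroup> attribute
smem_int = gpu.smem_space(int=True)          # the same space as a plain integer

base = gpu.dynamic_shared_memory(int=True)   # memref<?xi8> placed in the workgroup space
buf = memref.alloc((16, 16), T.f32())
gpu.memset(buf, 0.0)                         # 0.0 is wrapped in constant(..., type=buf.type.element_type)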
1 change: 1 addition & 0 deletions mlir/extras/dialects/ext/llvm.py
@@ -5,3 +5,4 @@

def llvm_ptr_t():
return Type.parse("!llvm.ptr")

63 changes: 49 additions & 14 deletions mlir/extras/dialects/ext/memref.py
@@ -33,6 +33,7 @@ def _alloc(
sizes: Sequence[Union[int, Value]],
element_type: Type,
memory_space=None,
alignment=None,
loc=None,
ip=None,
):
@@ -52,21 +53,56 @@

symbol_operands = []
return get_op_result_or_op_results(
op_ctor(result_type, dynamic_sizes, symbol_operands, loc=loc, ip=ip)
op_ctor(
result_type,
dynamic_sizes,
symbol_operands,
alignment=alignment,
loc=loc,
ip=ip,
)
)


def alloc(sizes: Union[int, Value], element_type: Type = None, memory_space=None):
loc = get_user_code_loc()
def alloc(
sizes: Union[int, Value],
element_type: Type = None,
memory_space=None,
alignment=None,
loc=None,
ip=None,
):
if loc is None:
loc = get_user_code_loc()
return _alloc(
AllocOp, sizes, element_type, memory_space=memory_space, loc=loc, ip=None
AllocOp,
sizes,
element_type,
memory_space=memory_space,
alignment=alignment,
loc=loc,
ip=ip,
)


def alloca(sizes: Union[int, Value], element_type: Type = None, memory_space=None):
loc = get_user_code_loc()
def alloca(
sizes: Union[int, Value],
element_type: Type = None,
memory_space=None,
alignment=None,
loc=None,
ip=None,
):
if loc is None:
loc = get_user_code_loc()
return _alloc(
AllocaOp, sizes, element_type, memory_space=memory_space, loc=loc, ip=None
AllocaOp,
sizes,
element_type,
memory_space=memory_space,
alignment=alignment,
loc=loc,
ip=ip,
)
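alloc and alloca now surface the underlying ops' alignment attribute and accept explicit loc/ip. A small sketch under the same context assumptions as the other examples in this repo (the element-type helpers are assumptions):

from mlir.extras.dialects.ext.memref import alloc, alloca
from mlir.extras import types as T

heap_buf = alloc((128, 128), T.f32(), alignment=64)   # memref.alloc {alignment = 64}
stack_buf = alloca((32,), T.i8(), alignment=16)       # memref.alloca {alignment = 16}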


@@ -113,8 +149,9 @@ def __getitem__(self, idx: tuple) -> "MemRef":
if idx is None:
return expand_shape(self, (0,), loc=loc)

idx = list((idx,) if isinstance(idx, (int, slice)) else idx)
idx = list((idx,) if isinstance(idx, (int, Scalar, slice)) else idx)
for i, d in enumerate(idx):
# TODO(max): rethink this since subview and etc probably take constant attributes?
if isinstance(d, int):
idx[i] = constant(d, index=True, loc=loc)

Expand All @@ -123,7 +160,7 @@ def __getitem__(self, idx: tuple) -> "MemRef":
else:
return _subview(self, tuple(idx), loc=loc)

def __setitem__(self, idx, source):
def __setitem__(self, idx, val):
loc = get_user_code_loc()

if not self.has_rank():
Expand All @@ -135,12 +172,10 @@ def __setitem__(self, idx, source):
idx[i] = constant(d, index=True, loc=loc)

if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape):
assert isinstance(
source, Scalar
), "coordinate insert requires scalar element"
store(source, self, idx, loc=loc)
assert isinstance(val, Scalar), "coordinate insert requires scalar element"
store(val, self, idx, loc=loc)
else:
_copy_to_subview(self, source, tuple(idx), loc=loc)
_copy_to_subview(self, val, tuple(idx), loc=loc)


def expand_shape(
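For context, the __getitem__/__setitem__ sugar being renamed and extended here maps Python indexing onto loads, stores, and subview copies. A sketch of the behavior, again assuming an active MLIR context; the exact ops emitted are the library's choice:

from mlir.extras.dialects.ext.memref import alloc
from mlir.extras import types as T

A = alloc((16, 16), T.f32())
B = alloc((16, 16), T.f32())

x = A[0, 0]                 # all-scalar index -> memref.load, yields a Scalar
A[0, 0] = x                 # scalar value at a full coordinate -> memref.store
A[0:8, 0:8] = B[0:8, 0:8]   # slice on either side -> subview + copy into the subview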
21 changes: 16 additions & 5 deletions mlir/extras/dialects/ext/scf.py
@@ -2,7 +2,7 @@
import logging
from contextlib import contextmanager
from copy import deepcopy
from typing import List
from typing import List, Union, Optional, Sequence

from bytecode import ConcreteBytecode

@@ -18,6 +18,7 @@
get_op_result_or_op_results,
)
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type

# gotta come first
from ....dialects.scf import *
from ....dialects.scf import _Dialect, yield_ as yield__
@@ -432,13 +433,18 @@ def visit_If(self, updated_node: ast.If) -> ast.If:
updated_node.orelse, deepcopy(new_yield)
)

updated_node = ast.fix_missing_locations(updated_node)
return updated_node

def visit_For(self, updated_node: ast.For) -> ast.For:
updated_node = self.generic_visit(updated_node)
new_yield = ast.Expr(ast.Yield(value=None))
if not is_yield(updated_node.body[-1]):
updated_node.body = append_hidden_node(updated_node.body, new_yield)
# TODO(max): this isn't robust at all...
line = ast.dump(updated_node.iter.func)
if "range_" in line or "for_" in line:
updated_node = self.generic_visit(updated_node)
new_yield = ast.Expr(ast.Yield(value=None))
if not is_yield(updated_node.body[-1]):
updated_node.body = append_hidden_node(updated_node.body, new_yield)
updated_node = ast.fix_missing_locations(updated_node)
return updated_node


@@ -480,6 +486,7 @@

if needs_forward(updated_node.orelse):
updated_node.orelse = forward_yield_from_nested_if(updated_node.orelse)
updated_node = ast.fix_missing_locations(updated_node)
return updated_node


@@ -515,6 +522,10 @@ def visit_While(self, updated_node: ast.While) -> List[ast.AST]:
)
new_test = ast.copy_location(new_test, updated_node)
updated_node.test = new_test

updated_node = ast.fix_missing_locations(updated_node)
assign = ast.fix_missing_locations(assign)

return [assign, updated_node]


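The guard added to visit_For means only loops driven by the dialect's range_/for_ helpers get the hidden yield and scf.for treatment; a loop over plain Python range is now left alone and simply executes (i.e. unrolls) while the body is traced. A sketch of the distinction, reusing the notebook's D and memref annotation and the same imports as the earlier example:

@func(emit=True)
@canonicalize(using=scf)
def mixed_loops(A: ranked_memref_dxd_f32):
    for i in range_(0, D):          # rewritten: scf.for with an implicit yield
        A[i, 0] = A[i, 0] + A[i, 0]
    for j in range(4):              # untouched: plain Python loop, unrolled at trace time
        A[0, j] = A[0, j] + A[0, j]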
1 change: 1 addition & 0 deletions mlir/extras/dialects/ext/tensor.py
@@ -102,6 +102,7 @@ def insert_slice(
)


# TODO(max): unify vector/memref/tensor
@register_value_caster(RankedTensorType.static_typeid)
class Tensor(ShapedValue, ArithValue):
def __getitem__(self, idx: tuple) -> "Tensor":