Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions mlir/extras/dialects/ext/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from ...meta import (
region_op,
)
from ...util import ModuleMeta, get_user_code_loc, make_maybe_no_args_decorator
from ...util import (
ModuleMeta,
_get_previous_frame_idents,
get_user_code_loc,
make_maybe_no_args_decorator,
)
from ....dialects._gpu_ops_gen import _Dialect
from ....dialects._ods_common import (
_cext,
Expand Down Expand Up @@ -327,14 +332,7 @@ def __init__(self, func):

def __getitem__(self, item):
previous_frame = inspect.currentframe().f_back
var_names = [
[
var_name
for var_name, var_val in previous_frame.f_locals.items()
if var_val is arg
]
for arg in item
]
var_names = [_get_previous_frame_idents(arg, previous_frame) for arg in item]
kwargs = {}
for i, it in enumerate(item):
assert len(var_names[i]) == 1, "expected unique kwarg"
Expand Down
17 changes: 9 additions & 8 deletions mlir/extras/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ def memref_type_to_np_dtype(memref_type):
return _memref_type_to_np_dtype.get(memref_type)


def _get_previous_frame_idents(val, previous_frame):
return [
var_name
for var_name, var_val in previous_frame.f_locals.items()
if var_val is val
]


def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence):
"""Update caller vars passed as args.

Expand All @@ -249,14 +257,7 @@ def _update_caller_vars(previous_frame, args: Sequence, replacements: Sequence):
if len(args) != len(replacements):
raise ValueError(f"updates must be 1-1: {args=} {replacements=}")
# find the name of the iter args in the previous frame
var_names = [
[
var_name
for var_name, var_val in previous_frame.f_locals.items()
if var_val is arg
]
for arg in args
]
var_names = [_get_previous_frame_idents(arg, previous_frame) for arg in args]
for i, var_names in enumerate(var_names):
for var_name in var_names:
previous_frame.f_locals[var_name] = replacements[i]
Expand Down