Skip to content

Commit 69c8f26

Browse files
committed
use star import in ext dir
1 parent dac8023 commit 69c8f26

File tree

10 files changed

+51
-84
lines changed

10 files changed

+51
-84
lines changed

.github/workflows/test.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ on:
1414
# At minute 0 past hour 6. (see https://crontab.guru)
1515
- cron: '00 06 * * *'
1616

17+
concurrency:
18+
# A PR number if a pull request and otherwise the commit hash. This cancels
19+
# queued and in-progress runs for the same PR (presubmit) or commit
20+
# (postsubmit).
21+
group: ci-build-test-${{ github.event.number || github.sha }}
22+
cancel-in-progress: true
23+
1724
jobs:
1825

1926
test-mlir-bindings:

mlir/extras/dialects/ext/arith.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
from ....dialects import arith as arith_dialect
9+
from ....dialects.arith import *
910
from ....dialects import complex as complex_dialect
1011
from ....dialects._arith_enum_gen import (
1112
_arith_cmpfpredicateattr,

mlir/extras/dialects/ext/cf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ....dialects.cf import BranchOp, CondBranchOp
1+
from ....dialects.cf import *
22
from ....dialects._cf_ops_gen import _Dialect
33
from ....dialects._ods_common import (
44
get_op_result_or_value,

mlir/extras/dialects/ext/func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ...meta import make_maybe_no_args_decorator, op_region_builder
66
from ...util import get_user_code_loc
7-
from ....dialects.func import FuncOp, ReturnOp, CallOp
7+
from ....dialects.func import *
88
from ....ir import (
99
InsertionPoint,
1010
FunctionType,

mlir/extras/dialects/ext/gpu.py

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,7 @@
33
from typing import Optional, Any
44

55
from ....dialects._ods_common import get_default_loc_context, _cext
6-
from ....dialects.gpu import (
7-
AddressSpace,
8-
MappingId,
9-
GPUModuleOp,
10-
GPUFuncOp,
11-
LaunchFuncOp,
12-
LaunchOp,
13-
ReturnOp,
14-
AllReduceOp,
15-
YieldOp,
16-
TerminatorOp,
17-
WaitOp,
18-
)
6+
from ....dialects.gpu import *
197
from ....dialects._gpu_ops_gen import _Dialect
208
from ....ir import (
219
Type,
@@ -30,9 +18,8 @@
3018
)
3119

3220
from ... import types as T
33-
from ...dialects.ext.arith import constant
34-
from ...dialects.ext.func import FuncBase
35-
from ....dialects.gpu import block_id, module_end
21+
from .arith import constant
22+
from .func import FuncBase
3623
from ...meta import (
3724
ModuleMeta,
3825
make_maybe_no_args_decorator,
@@ -318,20 +305,18 @@ def __call__(
318305
size[i] = constant(s, index=True)
319306

320307
loc = get_user_code_loc()
321-
return (
322-
get_op_result_or_op_results(
323-
LaunchFuncOp(
324-
[self.qualname, self.func_name]
325-
if self.qualname is not None
326-
else [self.func_name],
327-
grid_size,
328-
block_size,
329-
kernel_operands,
330-
async_dependencies,
331-
dynamic_shared_memory_size,
332-
async_object=stream,
333-
loc=loc,
334-
)
308+
return get_op_result_or_op_results(
309+
LaunchFuncOp(
310+
[self.qualname, self.func_name]
311+
if self.qualname is not None
312+
else [self.func_name],
313+
grid_size,
314+
block_size,
315+
kernel_operands,
316+
async_dependencies,
317+
dynamic_shared_memory_size,
318+
async_object=stream,
319+
loc=loc,
335320
)
336321
)
337322

@@ -409,10 +394,8 @@ def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):
409394

410395

411396
def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
412-
return (
413-
get_op_result_or_op_results(
414-
all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
415-
)
397+
return get_op_result_or_op_results(
398+
all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
416399
)
417400

418401

@@ -425,6 +408,6 @@ def wait(async_dependencies: Optional[list[Value]] = None, *, loc=None, ip=None)
425408
if async_dependencies is None:
426409
async_dependencies = []
427410
async_token = gpu_async_token()
428-
return (
429-
get_op_result_or_op_results(WaitOp(async_token, async_dependencies, loc=loc, ip=ip))
411+
return get_op_result_or_op_results(
412+
WaitOp(async_token, async_dependencies, loc=loc, ip=ip)
430413
)

mlir/extras/dialects/ext/llvm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ....ir import Type
2+
from ....dialects.llvm import *
23

34

45
def llvm_ptr_t():

mlir/extras/dialects/ext/memref.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55
from ....ir import Type, Value, MemRefType, ShapedType, MLIRError
66

77
from ... import types as T
8+
from ....dialects.memref import *
89
from ....dialects import memref, arith
9-
from ...dialects.ext.arith import Scalar, constant
10-
from ...dialects.ext.tensor import (
11-
_indices_to_indexer,
12-
compute_result_shape_reassoc_list,
13-
)
10+
from .arith import Scalar, constant
11+
from .tensor import _indices_to_indexer, compute_result_shape_reassoc_list
1412
from ...meta import region_op
1513
from ...._mlir_libs._mlir import register_value_caster
1614
from ...util import get_user_code_loc
@@ -39,7 +37,7 @@ def _alloc(
3937
def alloc(sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None):
4038
if loc is None:
4139
loc = get_user_code_loc()
42-
return _alloc(memref.AllocOp, sizes, element_type, loc=loc, ip=ip)
40+
return _alloc(AllocOp, sizes, element_type, loc=loc, ip=ip)
4341

4442

4543
def alloca(
@@ -48,7 +46,7 @@ def alloca(
4846
if loc is None:
4947
loc = get_user_code_loc()
5048
return get_op_result_or_op_results(
51-
_alloc(memref.AllocaOp, sizes, element_type, loc=loc, ip=ip)
49+
_alloc(AllocaOp, sizes, element_type, loc=loc, ip=ip)
5250
)
5351

5452

@@ -59,7 +57,7 @@ def load(mem: Value, indices: Sequence[Value | int], *, loc=None, ip=None):
5957
for idx, i in enumerate(indices):
6058
if isinstance(i, int):
6159
indices[idx] = constant(i, index=True)
62-
return get_op_result_or_op_results(memref.LoadOp(mem, indices, loc=loc, ip=ip))
60+
return get_op_result_or_op_results(LoadOp(mem, indices, loc=loc, ip=ip))
6361

6462

6563
def store(
@@ -71,9 +69,7 @@ def store(
7169
for idx, i in enumerate(indices):
7270
if isinstance(i, int):
7371
indices[idx] = constant(i, index=True)
74-
return get_op_result_or_op_results(
75-
memref.StoreOp(value, mem, indices, loc=loc, ip=ip)
76-
)
72+
return get_op_result_or_op_results(StoreOp(value, mem, indices, loc=loc, ip=ip))
7773

7874

7975
def subview(
@@ -345,4 +341,4 @@ def _copy_to_subview(
345341
return memref.copy(source, dest_subview, loc=loc, ip=ip)
346342

347343

348-
alloca_scope = region_op(memref.AllocaScopeOp)
344+
alloca_scope = region_op(AllocaScopeOp)

mlir/extras/dialects/ext/scf.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,37 @@
44
from copy import deepcopy
55
from typing import Optional, Sequence, Union
66

7-
from bytecode import ConcreteBytecode, ConcreteInstr
7+
from bytecode import ConcreteBytecode
88

9-
from ... import types as T
109
from ...ast.canonicalize import (
1110
StrictTransformer,
1211
Canonicalizer,
1312
BytecodePatcher,
1413
)
1514
from ...ast.util import ast_call, set_lineno
16-
from ...dialects.ext.arith import constant, index_cast
17-
from ...dialects.ext.gpu import get_device_mapping_array_attr
18-
from ....dialects.scf import yield_ as yield__, reduce_return, condition
15+
from .gpu import get_device_mapping_array_attr
1916
from ...meta import region_adder, region_op
2017
from ...util import get_user_code_loc
2118
from ....dialects._ods_common import (
22-
get_op_results_or_values,
2319
get_op_result_or_op_results,
2420
get_default_loc_context,
2521
_cext,
2622
)
2723
from ....dialects.linalg.opdsl.lang.emitter import _is_index_type
28-
from ....dialects.scf import (
29-
_Dialect,
30-
IfOp,
31-
ForOp,
32-
ForallOp,
33-
ParallelOp,
34-
InParallelOp,
35-
ReduceOp,
36-
WhileOp,
37-
ExecuteRegionOp,
38-
)
24+
from ....dialects.scf import *
25+
from ....dialects.scf import _Dialect, yield_ as yield__, reduce_return, condition
3926
from ....ir import (
4027
InsertionPoint,
4128
Value,
4229
OpResultList,
43-
OpResult,
4430
Operation,
4531
OpView,
4632
IndexType,
4733
_denseI64ArrayAttr,
4834
Attribute,
4935
OpaqueType,
5036
)
37+
from .arith import constant, index_cast
5138

5239
logger = logging.getLogger(__name__)
5340

mlir/extras/dialects/ext/tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from ... import types as T
1818
from ....dialects import tensor
19-
from ...dialects.ext.arith import ArithValue, Scalar, constant
19+
from ....dialects.tensor import *
20+
from .arith import ArithValue, Scalar, constant
2021
from ...meta import region_op, _update_caller_vars
2122
from ...._mlir_libs._mlir import register_value_caster
2223
from ...util import get_user_code_loc

mlir/extras/dialects/ext/transform.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,23 @@
11
from typing import Optional, Union, Sequence
22

3-
from ... import types as T
43
from ...meta import region_op
54
from ...util import get_user_code_loc
6-
from ....dialects.transform.structured import _dispatch_mixed_values, TileUsingForOp
5+
from ....dialects import pdl
6+
from ....dialects.transform import *
7+
from ....dialects.transform.loop import *
8+
from ....dialects.transform.structured import *
9+
from ....dialects._ods_common import get_op_result_or_op_results
710
from ....dialects._structured_transform_ops_gen import (
811
TileUsingForallOp,
912
MatchOp,
1013
)
11-
from ....dialects.transform import ApplyPatternsOp
12-
from ....dialects.transform import (
13-
SequenceOp,
14-
FailurePropagationMode,
15-
YieldOp,
16-
AnyOpType,
17-
OperationType,
18-
)
19-
from ....dialects.transform.loop import LoopUnrollOp
20-
from ....dialects.transform import GetParentOp
14+
from ....dialects.transform.structured import _dispatch_mixed_values
2115
from ....ir import (
2216
Type,
2317
Value,
2418
Operation,
2519
StringAttr,
2620
)
27-
from ....dialects._ods_common import get_op_result_or_op_results
28-
from ....dialects import pdl
29-
3021

3122
pdl_operation_t = lambda: pdl.OperationType.get()
3223

0 commit comments

Comments
 (0)