|
1 | | -import warnings |
2 | | - |
3 | 1 | # noinspection PyUnresolvedReferences |
4 | 2 | from .....dialects.llvm import * |
5 | | -from .....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType |
6 | | - |
7 | | -try: |
8 | | - from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string |
9 | | - from llvm import types_ |
10 | | - from llvm.context import context as llvm_context |
11 | | -except ImportError: |
12 | | - warnings.warn( |
13 | | - "llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly" |
14 | | - ) |
| 3 | +from .....ir import Type, Value |
15 | 4 |
|
| 5 | +ValueRef = Value |
16 | 6 |
|
17 | 7 | def llvm_ptr_t(): |
18 | 8 | return Type.parse("!llvm.ptr") |
19 | 9 |
|
20 | | - |
21 | | -def mlir_type_to_llvm_type(mlir_type, llvm_ctx): |
22 | | - if F16Type.isinstance(mlir_type): |
23 | | - return types_.half_type_in_context(llvm_ctx) |
24 | | - if F32Type.isinstance(mlir_type): |
25 | | - return types_.float_type_in_context(llvm_ctx) |
26 | | - if F64Type.isinstance(mlir_type): |
27 | | - return types_.double_type_in_context(llvm_ctx) |
28 | | - if BF16Type.isinstance(mlir_type): |
29 | | - return types_.b_float_type_in_context(llvm_ctx) |
30 | | - if IntegerType.isinstance(mlir_type): |
31 | | - return types_.int_type_in_context(llvm_ctx, mlir_type.width) |
32 | | - |
33 | | - raise NotImplementedError(f"{mlir_type} is not supported") |
34 | | - |
35 | | - |
36 | | -def llvm_type_str_to_mlir_type(llvm_type: str): |
37 | | - if llvm_type.startswith("<"): |
38 | | - return Type.parse(f"vector{llvm_type}") |
39 | | - if llvm_type == "float": |
40 | | - return F32Type.get() |
41 | | - raise NotImplementedError(f"{llvm_type} is not supported") |
42 | | - |
43 | | - |
44 | | -_call_intrinsic = call_intrinsic |
45 | | - |
46 | | - |
47 | | -def call_intrinsic(*args, **kwargs): |
48 | | - intr_id = kwargs.pop("intr_id") |
49 | | - intr_name = kwargs.pop("intr_name") |
50 | | - mlir_ret_type = kwargs.pop("return_type", None) |
51 | | - if mlir_ret_type: |
52 | | - return _call_intrinsic(mlir_ret_type, intr_name, args, [], []) |
53 | | - |
54 | | - is_overloaded = kwargs.pop("is_overloaded", None) |
55 | | - if is_overloaded is None: |
56 | | - is_overloaded = intrinsic_is_overloaded(intr_id) |
57 | | - with llvm_context() as ctx: |
58 | | - types = [] |
59 | | - if is_overloaded: |
60 | | - types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args] |
61 | | - intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types) |
62 | | - |
63 | | - ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip() |
64 | | - mlir_ret_type = None |
65 | | - if ret_type_str: |
66 | | - mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str) |
67 | | - |
68 | | - return _call_intrinsic(mlir_ret_type, intr_name, args, [], []) |
69 | | - |
70 | | - |
71 | | -call_intrinsic_ = call_intrinsic |
72 | | - |
73 | 10 | from . import amdgcn |
0 commit comments