|  | 
| 59 | 59 |     "from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range\n", | 
| 60 | 60 |     "from mlir.extras.runtime.passes import Pipeline, run_pipeline\n", | 
| 61 | 61 |     "from mlir.extras.runtime.refbackend import LLVMJITBackend\n", | 
|  | 62 | +    "from mlir.ir import StridedLayoutAttr\n", | 
| 62 | 63 |     "\n", | 
| 63 | 64 |     "# you need this to register the memref value caster\n", | 
| 64 | 65 |     "# noinspection PyUnresolvedReferences\n", | 
|  | 
| 91 | 92 |    "outputs": [], | 
| 92 | 93 |    "source": [ | 
| 93 | 94 |     "K = 10\n", | 
| 94 |  | -    "memref_i64 = T.memref(K, K, T.i64)\n", | 
|  | 95 | +    "memref_i64 = T.memref(K, K, T.i64())\n", | 
| 95 | 96 |     "\n", | 
| 96 | 97 |     "@func(emit=True)\n", | 
| 97 | 98 |     "@canonicalize(using=scf)\n", | 
| 98 | 99 |     "def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):\n", | 
| 99 | 100 |     "    one = constant(1)\n", | 
| 100 | 101 |     "    two = constant(2)\n", | 
| 101 | 102 |     "    if one > two:\n", | 
| 102 |  | -    "        C[0, 0] = constant(3, T.i64)\n", | 
|  | 103 | +    "        C[0, 0] = constant(3, T.i64())\n", | 
| 103 | 104 |     "    else:\n", | 
| 104 | 105 |     "        for i in range(0, K):\n", | 
| 105 | 106 |     "            for j in range(0, K):\n", | 
|  | 
| 447 | 448 |     "D = 32\n", | 
| 448 | 449 |     "\n", | 
| 449 | 450 |     "F = K // D\n", | 
| 450 |  | -    "ranked_memref_kxk_f32 = T.memref(K, K, T.f32)\n", | 
| 451 |  | -    "ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))\n", | 
|  | 451 | +    "ranked_memref_kxk_f32 = T.memref(K, K, T.f32())\n", | 
|  | 452 | +    "layout = StridedLayoutAttr.get(S, (K, 1))\n", | 
|  | 453 | +    "ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)\n", | 
| 452 | 454 |     "\n", | 
| 453 | 455 |     "@func(emit=True)\n", | 
| 454 | 456 |     "@canonicalize(using=scf)\n", | 
|  | 
| 784 | 786 |     "ctx_man = mlir_mod_ctx()\n", | 
| 785 | 787 |     "ctx = ctx_man.__enter__()\n", | 
| 786 | 788 |     "\n", | 
| 787 |  | -    "ranked_memref_kxk_f32 = T.memref(K, K, T.f32)\n", | 
| 788 |  | -    "ranked_memref_dxd_f32 = T.memref(D, D, T.f32, layout=((K, 1), S))\n", | 
|  | 789 | +    "ranked_memref_kxk_f32 = T.memref(K, K, T.f32())\n", | 
|  | 790 | +    "layout = StridedLayoutAttr.get(S, (K, 1))\n", | 
|  | 791 | +    "ranked_memref_dxd_f32 = T.memref(D, D, T.f32(), layout=layout)\n", | 
| 789 | 792 |     "\n", | 
| 790 | 793 |     "from mlir.extras.dialects.ext import linalg\n", | 
| 791 | 794 |     "\n", | 
|  | 
0 commit comments