|
| 1 | +from . import arith |
1 | 2 | from ...util import get_user_code_loc |
2 | | - |
| 3 | +from ....dialects import linalg |
3 | 4 | # noinspection PyUnresolvedReferences |
4 | 5 | from ....dialects.linalg import * |
5 | | -from ....dialects import linalg |
| 6 | +from ....extras import types as T |
6 | 7 |
|
7 | 8 |
|
8 | 9 | def abs(I, O, *, loc=None, ip=None): |
@@ -263,16 +264,25 @@ def exp(I, O, *, loc=None, ip=None): |
263 | 264 | return linalg.exp(I, loc=loc, ip=ip, outs=[O]) |
264 | 265 |
|
265 | 266 |
|
266 | | -def fill(O, *, loc=None, ip=None): |
| 267 | +def fill(v, O, *, loc=None, ip=None): |
| 268 | + if isinstance(v, (float, int, bool)): |
| 269 | + v = arith.constant(v) |
267 | 270 | if loc is None: |
268 | 271 | loc = get_user_code_loc() |
269 | | - return linalg.fill(loc=loc, ip=ip, outs=[O]) |
| 272 | + return linalg.fill(v, loc=loc, ip=ip, outs=[O]) |
270 | 273 |
|
271 | 274 |
|
272 | | -def fill_rng_2d(O, *, loc=None, ip=None): |
| 275 | +def fill_rng_2d(min, max, seed, O, *, loc=None, ip=None): |
| 276 | + params = [min, max] |
| 277 | + for i, m in enumerate(params): |
| 278 | + if isinstance(m, (float, int)): |
| 279 | + params[i] = arith.constant(m, type=T.f64()) |
| 280 | + min, max = params |
| 281 | + if isinstance(seed, int): |
| 282 | + seed = arith.constant(seed, T.i32()) |
273 | 283 | if loc is None: |
274 | 284 | loc = get_user_code_loc() |
275 | | - return linalg.fill_rng_2d(loc=loc, ip=ip, outs=[O]) |
| 285 | + return linalg.fill_rng_2d(min, max, seed, loc=loc, ip=ip, outs=[O]) |
276 | 286 |
|
277 | 287 |
|
278 | 288 | def floor(I, O, *, loc=None, ip=None): |
|
0 commit comments