# mypy: allow-untyped-defs
import math
from enum import Enum
+from functools import wraps
from typing import List, Optional, Sequence, Tuple, Union

import torch
@@ -6451,6 +6452,76 @@ def _f(x, y): |
_create_binary_float_meta_func(aten.special_legendre_polynomial_p)


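+# Helper: given an out-of-place op, register a meta kernel for its in-place
+# variant ("<name>_"). The kernel runs the out-of-place meta to infer the
+# result shape, checks that the result broadcasts back into `self`, and
+# returns `self`.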
+def _register_inplace_meta(fn):
+    @wraps(fn)
+    def _fn(self, *args, **kwargs):
+        out = fn(self, *args, **kwargs)
+        check_inplace_broadcast(self.shape, out.shape)
+        return self
+
+    inplace_name = f"{fn.__name__}_"
+    _fn.__name__ = inplace_name
+    _fn = register_meta(getattr(aten, inplace_name))(_fn)  # type: ignore[assignment]
+
+    return _fn
+
+
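+# Meta for aten.lerp: `end` (and `weight`, when it is a tensor) must match
+# `start`'s dtype; the output shape and dtype come from default elementwise
+# type promotion over the tensor arguments.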
+@register_meta(aten.lerp)
+@out_wrapper()
+def lerp(start, end, weight):
+    torch._check(
+        start.dtype == end.dtype,
+        lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
+    )
+    args = [start, end]
+    if isinstance(weight, TensorLike):
+        torch._check(
+            start.dtype == weight.dtype,
+            lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}",
+        )
+        args.append(weight)
+    return elementwise_meta(
+        *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+
+
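+# Meta for aten.addcmul: shape and dtype follow default elementwise promotion
+# of the three tensor arguments; the scalar `value` does not affect the
+# inferred metadata.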
+@register_meta(aten.addcmul)
+@out_wrapper()
+def addcmul(input, tensor1, tensor2, *, value=1):
+    return elementwise_meta(
+        input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+
+
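+# Meta for aten.addcdiv: rejects the case where both tensor1 and tensor2 have
+# integer dtypes (integer division via addcdiv is no longer supported);
+# otherwise shape and dtype follow default elementwise promotion.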
+@register_meta(aten.addcdiv)
+@out_wrapper()
+def addcdiv(input, tensor1, tensor2, *, value=1):
+    torch._check(
+        not (
+            utils.is_integer_dtype(tensor1.dtype)
+            and utils.is_integer_dtype(tensor2.dtype)
+        ),
+        lambda: (
+            "Integer division with addcdiv is no longer supported, and in a future "
+            "release addcdiv will perform a true division of tensor1 and tensor2. "
+            "The historic addcdiv behavior can be implemented as "
+            "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) "
+            "for integer inputs and as "
+            "(input + value * tensor1 / tensor2) for float inputs. "
+            "The future addcdiv behavior is just the latter implementation: "
+            "(input + value * tensor1 / tensor2), for all dtypes."
+        ),
+    )
+    return elementwise_meta(
+        input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+
+
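+# Register meta kernels for the in-place variants (lerp_, addcmul_, addcdiv_)
+# by reusing the out-of-place ops above.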
+lerp_ = _register_inplace_meta(aten.lerp)
+addcmul_ = _register_inplace_meta(aten.addcmul)
+addcdiv_ = _register_inplace_meta(aten.addcdiv)
+
+
# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs