
Commit e4b98b1

ysiraichi authored and pytorchmergebot committed
Add meta functions for lerp, addcmul, and addcdiv. (pytorch#136909)

This PR adds new meta functions for `lerp`, `addcmul`, and `addcdiv` (including their respective in-place versions). These functions previously had only refs implementations, which was the root cause of significant overhead ([issue][1]) when running the `AdamW` optimizer step on the PyTorch/XLA backend. Using the meta functions resulted in the following improvements:

- `lerp` calls: 1,550ms to 140ms (10x)
- `addcdiv` calls: 640ms to 350ms (1.8x)
- `addcmul` calls: 620ms to 300ms (2.05x)

[1]: https://github.com/pytorch/xla/issues/7923

Pull Request resolved: pytorch#136909
Approved by: https://github.com/jansel
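For context, a meta function computes only the output's metadata (shape, dtype, device) and never touches data, which is what lets a tracing backend like PyTorch/XLA skip the more expensive refs decomposition. A minimal sketch of what the new `lerp` registration enables, using only the public API (the shapes and dtypes here are illustrative):

```python
import torch

# Meta tensors carry shape/dtype/device metadata but allocate no storage.
start = torch.empty(4, 1, device="meta", dtype=torch.float32)
end = torch.empty(1, 8, device="meta", dtype=torch.float32)

# With a meta kernel registered for aten.lerp, a single call infers the
# broadcast output (4, 8) without computing or allocating anything.
out = torch.lerp(start, end, 0.5)
print(out.shape, out.device)  # torch.Size([4, 8]) meta
```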
1 parent a1f1f58 · commit e4b98b1

File tree: 1 file changed (+71, −0 lines)


torch/_meta_registrations.py

Lines changed: 71 additions & 0 deletions
@@ -2,6 +2,7 @@
 # mypy: allow-untyped-defs
 import math
 from enum import Enum
+from functools import wraps
 from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
@@ -6451,6 +6452,76 @@ def _f(x, y):
 _create_binary_float_meta_func(aten.special_legendre_polynomial_p)
 
 
+def _register_inplace_meta(fn):
+    @wraps(fn)
+    def _fn(self, *args, **kwargs):
+        out = fn(self, *args, **kwargs)
+        check_inplace_broadcast(self.shape, out.shape)
+        return self
+
+    inplace_name = f"{fn.__name__}_"
+    _fn.__name__ = inplace_name
+    _fn = register_meta(getattr(aten, inplace_name))(_fn)  # type: ignore[assignment]
+
+    return _fn
+
+
+@register_meta(aten.lerp)
+@out_wrapper()
+def lerp(start, end, weight):
+    torch._check(
+        start.dtype == end.dtype,
+        lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
+    )
+    args = [start, end]
+    if isinstance(weight, TensorLike):
+        torch._check(
+            start.dtype == weight.dtype,
+            lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}",
+        )
+        args.append(weight)
+    return elementwise_meta(
+        *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+
+
+@register_meta(aten.addcmul)
+@out_wrapper()
+def addcmul(input, tensor1, tensor2, *, value=1):
+    return elementwise_meta(
+        input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+
+
+@register_meta(aten.addcdiv)
+@out_wrapper()
+def addcdiv(input, tensor1, tensor2, *, value=1):
+    torch._check(
+        not (
+            utils.is_integer_dtype(tensor1.dtype)
+            and utils.is_integer_dtype(tensor2.dtype)
+        ),
+        lambda: (
+            "Integer division with addcdiv is no longer supported, and in a future ",
+            "release addcdiv will perform a true division of tensor1 and tensor2. ",
+            "The historic addcdiv behavior can be implemented as ",
+            "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
+            "for integer inputs and as ",
+            "(input + value * tensor1 / tensor2) for float inputs. ",
+            "The future addcdiv behavior is just the latter implementation: ",
+            "(input + value * tensor1 / tensor2), for all dtypes.",
+        ),
+    )
+    return elementwise_meta(
+        input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+
+
+lerp_ = _register_inplace_meta(aten.lerp)
+addcmul_ = _register_inplace_meta(aten.addcmul)
+addcdiv_ = _register_inplace_meta(aten.addcdiv)
+
+
 # We must also trigger meta registrations from PrimTorch ref
 # decompositions
 import torch._refs
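One way to sanity-check these registrations from the public API is to run the ops on meta tensors. A hedged sketch (the shapes and dtypes are illustrative; the error text is the one produced by the `torch._check` above):

```python
import torch

inp = torch.empty(2, 3, device="meta")
t1 = torch.empty(2, 3, device="meta")
t2 = torch.empty(3, device="meta")

# addcmul/addcdiv on meta tensors are metadata-only: both broadcast to (2, 3).
assert torch.addcmul(inp, t1, t2, value=2.0).shape == (2, 3)
assert torch.addcdiv(inp, t1, t2).shape == (2, 3)

# addcdiv rejects integer-by-integer division, per the torch._check above.
ints = torch.empty(2, 3, device="meta", dtype=torch.int64)
try:
    torch.addcdiv(inp, ints, ints)
except RuntimeError:
    pass  # "Integer division with addcdiv is no longer supported..."

# The in-place variants return `self` after check_inplace_broadcast has
# verified that the broadcast result still fits `self`'s shape.
assert inp.lerp_(t1, 0.5) is inp
```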
