
Commit 3e05a48

yushangdi authored and pytorchmergebot committed
Fix clamp type promotion in inductor decomposition (pytorch#154471)
Summary: As the title says, the clamp type promotion should take the min/max args into consideration as well.

Test Plan:
```
buck run fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_clamp_decomposition_cpu
python test/inductor/test_torchinductor.py -k test_clamp -v
```

Differential Revision: D75490124

Pull Request resolved: pytorch#154471

Approved by: https://github.com/desertfire, https://github.com/chenyang78
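For context, a minimal eager-mode sketch of the promotion rule the decomposition has to match: clamp participates in type promotion with its scalar min/max args, so a float min promotes an integer tensor to float32, while an integer min keeps the integer dtype (the tensor below mirrors the new tests):

```
import torch

x = torch.randint(4, (4,))     # int64 tensor, as in the new tests

# A float scalar min participates in type promotion, so eager clamp
# promotes the result to float32...
print(x.clamp(min=1.5).dtype)  # torch.float32

# ...while an int min keeps the integer dtype.
print(x.clamp(min=2).dtype)    # torch.int64
```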
1 parent d865b78 commit 3e05a48

4 files changed: +36 -2 lines changed


test/inductor/test_aot_inductor.py

Lines changed: 16 additions & 0 deletions
@@ -5851,6 +5851,22 @@ def wrapped(**kwargs):
         # compare against eager
         self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
 
+    def test_clamp_decomposition(self):
+        class Model1(torch.nn.Module):
+            def forward(self, x):
+                return x.clamp(min=1.5)
+
+        class Model2(torch.nn.Module):
+            def forward(self, x):
+                return x.clamp(min=2)
+
+        x = torch.randint(4, (4,))
+
+        # the output should have float32 type, not int
+        self.check_model(Model1(), (x,))
+        # the output should have int type
+        self.check_model(Model2(), (x,))
+
 
 class AOTInductorLoggingTest(LoggingTestCase):
     @make_logging_test(dynamic=logging.DEBUG)

test/inductor/test_torchinductor.py

Lines changed: 6 additions & 0 deletions
@@ -2696,6 +2696,12 @@ def fn(a):
 
         self.common(fn, (torch.randint(4, (4,)),))
 
+    def test_clamp_type_promotion_non_tensor(self):
+        def fn(a):
+            return a.clamp(min=1.5), a.clamp(min=2)
+
+        self.common(fn, (torch.randint(4, (4,)),))
+
     @skip_if_gpu_halide
     @xfail_if_triton_cpu
     def test_dist(self):

torch/_decomp/decompositions.py

Lines changed: 12 additions & 1 deletion
@@ -58,11 +58,17 @@ def type_casts(
     f: Callable,
     type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
     compute_dtype_only: bool = False,
+    include_non_tensor_args: bool = False,
 ):
     @functools.wraps(f)
     def inner(*args, **kwargs):
+        allowed_types = (
+            (Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
+        )  # type: ignore[arg-type]
         flat_args = [
-            x for x in pytree.arg_tree_leaves(*args, **kwargs) if isinstance(x, Tensor)
+            x
+            for x in pytree.arg_tree_leaves(*args, **kwargs)
+            if isinstance(x, allowed_types)
         ]
         computation_dtype, result_dtype = utils.elementwise_dtypes(
             *flat_args, type_promotion_kind=type_promotion
@@ -98,6 +104,11 @@ def decrease_prec(x):
 pw_cast_for_opmath = partial(
     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
 )
+pw_cast_for_opmath_non_tensor_args = partial(
+    type_casts,
+    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+    include_non_tensor_args=True,
+)
 pw_cast_for_int_to_real = partial(
     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
 )
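To see what the new include_non_tensor_args filter changes, here is a small standalone check using the public torch.result_type API as a stand-in for the internal utils.elementwise_dtypes call (an illustrative approximation, not the exact helper):

```
import torch

x = torch.randint(4, (4,))        # int64, as in the tests above

# Old filter: only tensors survive, so promotion sees int64 alone and
# the decomposed clamp stayed integer even for a float min.
print(torch.result_type(x, x))    # torch.int64

# New filter: the scalar min joins the promotion, so the computed
# result dtype is float32, matching eager clamp.
print(torch.result_type(x, 1.5))  # torch.float32
```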

torch/_inductor/decomposition.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
     _index_add,
     embedding_dense_backward as decomp_embedding_dense_backward,
     pw_cast_for_opmath,
+    pw_cast_for_opmath_non_tensor_args,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps
 from torch._dynamo.utils import counters
@@ -181,7 +182,7 @@ def sym_constrain_range_for_size(
 
 
 @register_decomposition([aten.clamp])
-@pw_cast_for_opmath
+@pw_cast_for_opmath_non_tensor_args
 def clamp(
     x: torch.Tensor,
     min: Optional[torch.types.Number] = None,