
Commit 2735488

internal developer mryszt authored and committed
add new HOO hints_wrapper to annotate function with hints
Cherry-pick changes from upstream PR: pytorch#132860
Change-Id: I540d5a17534b8166a1a2ca49a8a96f475ee56a96
1 parent f150e55 commit 2735488
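
For orientation, a minimal usage sketch distilled from the tests added in this commit (the hint key "my_region" is an illustrative name, not part of the API). hints_wrapper(body_fn, args, kwargs, hints=...) runs body_fn(*args, **kwargs), and under torch.compile or export each call is captured as a single torch.ops.higher_order.hints_wrapper node carrying the hints dict, whose values must be int, float, bool or str:

import torch
from torch._higher_order_ops.hints_wrap import hints_wrapper


def fn(x, y):
    def body_fn(x, y):
        return torch.relu(x + y)

    # args is a tuple, kwargs must be passed explicitly (here {}),
    # and hints is a dict with int/float/bool/str values.
    return hints_wrapper(body_fn, (x, y), {}, hints={"my_region": True})


x, y = torch.randn(2, 4), torch.ones(4)
out = torch.compile(fn, backend="eager")(x, y)

In the compiled graph this shows up the same way as in the expected outputs below, with the body captured as a hints_wrapper_body submodule.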

File tree

6 files changed: +428 −2 lines changed

test/dynamo/test_higher_order_ops.py

Lines changed: 134 additions & 0 deletions
@@ -25,6 +25,7 @@
     normalize_gm,
 )
 from torch._dynamo.utils import counters, ifdynstaticdefault
+from torch._higher_order_ops.hints_wrap import hints_wrapper
 from torch._higher_order_ops.wrap import wrap
 from torch.testing._internal.common_utils import (
     munge_exc,
@@ -2425,6 +2426,139 @@ def fn(pred, pytree_in):
         ):
             torch.compile(fn, backend="eager")(pred, pytree_in)
 
+    def test_hints_wrapper(self):
+        def ref_fn(x, y):
+            x = x + y
+            x = torch.relu(x)
+            x = x + y
+            return torch.abs(x)
+
+        def fn_with_hints(x, y):
+            x = x + y
+
+            def inner_body_fn(x, y):
+                x = torch.relu(x)
+                x = x + y
+                return x
+
+            def outer_body_fn(x, y):
+                x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True})
+                x = torch.abs(x)
+                return x
+
+            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True})
+            return res
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        x = torch.randn(2, 4)
+        y = torch.ones(4)
+
+        eager_res = fn_with_hints(x, y)
+        compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
+        ref_res = ref_fn(x, y)
+        self.assertEqual(eager_res, ref_res)
+        self.assertEqual(compiled_res, ref_res)
+        self.assertEqual(len(cnt.graphs), 1)
+
+        # Dynamic shapes produce a slightly different graph.
+        if check_dynamic_shape_capture():
+            return
+
+        graph = backend.graphs[0]
+        self.assertExpectedInline(
+            normalize_gm(graph.print_readable(print_output=False)),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"):
+        l_x_ = L_x_
+        l_y_ = L_y_
+
+        x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None
+
+        hints_wrapper_body_1 = self.hints_wrapper_body_1
+        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None
+        res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+        return (res,)
+
+    class GraphModule(torch.nn.Module):
+        def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
+            hints_wrapper_body_0 = self.hints_wrapper_body_0
+            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None
+            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+
+            abs_1: "f32[2, 4]" = torch.abs(getitem); getitem = None
+            return (abs_1,)
+
+        class GraphModule(torch.nn.Module):
+            def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"):
+                relu: "f32[2, 4]" = torch.relu(x); x = None
+
+                add: "f32[2, 4]" = relu + l_y_; relu = l_y_ = None
+                return (add,)
+""",
+        )
+
+    def test_hints_wrapper_no_hints(self):
+        def fn_with_hints(x, y):
+            def outer_body_fn(x, y):
+                x = torch.add(x, y)
+                return x
+
+            res = hints_wrapper(outer_body_fn, (x, y), {})
+            return res
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        x = torch.randn(2, 4)
+        y = torch.ones(4)
+
+        msg = "hints_wrapper - key hints not provided"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
+
+    def test_hints_wrapper_incorrect_type(self):
+        def fn_with_hints(x, y):
+            def outer_body_fn(x, y):
+                x = torch.add(x, y)
+                return x
+
+            res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)})
+            return res
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        x = torch.randn(2, 4)
+        y = torch.ones(4)
+
+        msg = r"hints must be a dict containing int, float, bool or str value,"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y)
+
+    def test_hints_wrapper_pytree_inputs(self):
+        def fn_with_hints(x, y):
+            def outer_body_fn(x):
+                res = torch.add(x[0], x[1]["test"])
+                return res
+
+            res = hints_wrapper(
+                outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True}
+            )
+            return res
+
+        backend = EagerAndRecordGraphs()
+        cnt = CompileCounterWithBackend(backend)
+
+        x = torch.randn(2, 4)
+        y = torch.ones(4)
+
+        msg = r"args must be a tuple of tensors, ints, floats, or bools,"
+        with self.assertRaisesRegex(RuntimeError, msg):
+            fn_with_hints(x, y)
+
 
 class HigherOrderOpVmapGuardTests(LoggingTestCase):
     @make_logging_test(recompiles=True)

test/export/test_export.py

Lines changed: 62 additions & 0 deletions
@@ -18,7 +18,9 @@
 
 from functorch.experimental.control_flow import cond, map
 from torch import Tensor
+from torch._higher_order_ops.hints_wrap import hints_wrapper
 from torch._dynamo.test_case import TestCase
+from torch._dynamo.testing import normalize_gm
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
 from torch._export.utils import (
     get_buffer,
@@ -5927,6 +5929,66 @@ def forward(self, x):
         ):
             self.assertEqual(v1, v2)
 
+    def test_hints_wrapper(self):
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x, y):
+                x = x + y
+
+                def inner_body_fn(x, y):
+                    x = torch.relu(x)
+                    x = x + y
+                    return x
+
+                def outer_body_fn(x, y):
+                    x = hints_wrapper(
+                        inner_body_fn, (x, y), {}, hints={"inner_body": True}
+                    )
+                    x = torch.abs(x)
+                    return x
+
+                res = hints_wrapper(
+                    outer_body_fn, (x, y), {}, hints={"outer_body": True}
+                )
+                return res
+
+        x = torch.randn(2, 4)
+        y = torch.ones(4)
+
+        ep = export(M(), (x, y))
+        export_res = ep.module()(x, y)
+        ref_res = M()(x, y)
+        self.assertEqual(export_res, ref_res)
+        self.assertExpectedInline(
+            normalize_gm(ep.graph_module.print_readable(print_output=False)),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, x: "f32[2, 4]", y: "f32[4]"):
+        add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None
+
+        hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
+        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None
+        getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+        return (getitem,)
+
+    class <lambda>(torch.nn.Module):
+        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
+            hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
+            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None
+            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
+            abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None
+            return (abs_1,)
+
+        class <lambda>(torch.nn.Module):
+            def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
+                relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
+                add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
+                return (add,)
+""",
+        )
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
 class TestExportCustomClass(TorchTestCase):

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 77 additions & 2 deletions
@@ -24,7 +24,7 @@
 from torch.utils import _pytree as pytree
 from .. import variables
 
-from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported
+from ..exc import IncorrectUsage, UncapturedHigherOrderOpError, unimplemented, Unsupported
 from ..source import AttrSource
 from ..utils import proxy_args_kwargs
 from .dicts import ConstDictVariable
@@ -34,7 +34,6 @@
 if TYPE_CHECKING:
     from torch._dynamo.symbolic_convert import InstructionTranslator
 
-
 log = logging.getLogger(__name__)
 
 
@@ -543,6 +542,8 @@ def make(value, source=None, **kwargs):
             return HintedContextHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "wrap":
             return WrapHigherOrderVariable(value, source, **kwargs)
+        elif value.__name__ == "hints_wrapper":
+            return HintsWrapperHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ == "flex_attention":
             return TemplatedAttentionHigherOrderVariable(value, source, **kwargs)
         elif value.__name__ in (
@@ -1326,6 +1327,80 @@ def call_function(
         )
 
 
+class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable):
+    @raise_hard_error_if_graph_break(
+        reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile."
+    )
+    def call_function(
+        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
+    ) -> "VariableTracker":
+        _check_supported_callable_arg(tx, args[0], "body_fn")
+
+        # inputs
+        if len(args) != 3:
+            unimplemented(
+                f"Expected 3 arguments but got {len(args)}.\n"
+                f"Usage: hints_wrapper(body_fn, args, kwargs, hints).\n"
+                f"kwargs required to be provided explicitly."
+            )
+
+        if not isinstance(args[1], (ListVariable, TupleVariable)):
+            unimplemented(
+                f"Expected a tuple but got {args[1].python_type()}",
+            )
+        operands = args[1].unpack_var_sequence(tx)
+
+        if not isinstance(args[2], ConstDictVariable):
+            unimplemented(
+                f"Expected a dict but got {args[2].python_type()}",
+            )
+
+        if "hints" not in kwargs:
+            raise IncorrectUsage("hints_wrapper - key hints not provided")
+
+        (
+            (body_r, treespec),
+            body_graph,
+            body_lifted_freevars,
+        ) = speculate_subgraph(
+            tx,
+            args[0],  # function
+            operands,
+            args[2].as_python_constant(),
+            "hints_wrapper",
+            source_target=self.value,
+            should_flatten_outputs=True,
+        )
+
+        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
+        body_name = add_subgraph(
+            tx,
+            "hints_wrapper_body",
+            body_gmod,
+        )
+
+        body_node = make_attr(tx, body_name)
+
+        # Since we call `speculate_subgraph` with `set_subgraph_inputs="automatic"`,
+        # all the arguments are lifted.
+        lifted_args = tuple(arg for arg in body_lifted_freevars.keys())
+        p_args = (body_node, lifted_args, {})
+
+        p_kwargs = {}
+        # add hints into p_kwargs
+        p_kwargs["hints"] = kwargs["hints"].as_python_constant()
+
+        flat_example_value = pytree.tree_map_only(
+            torch.fx.Proxy,
+            lambda a: a.node.meta["example_value"],
+            body_r.as_proxy(),
+        )
+
+        return _call_function_and_unflatten_output(
+            tx, self.value, p_args, p_kwargs, flat_example_value, treespec
+        )
+
+
 class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .cond import cond
 from .while_loop import while_loop
 from .flex_attention import flex_attention, flex_attention_backward
+from torch._higher_order_ops.hints_wrap import hints_wrapper
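
Not part of this commit, but a sketch of how the annotation can be consumed downstream: in the captured graphs above each call is a call_function node whose target is the hints_wrapper op and whose "hints" kwarg holds the user-provided dict, so a pass or backend could collect the hints roughly as follows (collect_hints is a hypothetical helper, not an API added here):

import torch
from torch._higher_order_ops.hints_wrap import hints_wrapper


def collect_hints(gm: torch.fx.GraphModule):
    # Gather (node_name, hints) pairs from this graph and, recursively,
    # from the hints_wrapper body submodules (e.g. hints_wrapper_body_graph_0).
    found = []
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is hints_wrapper:
            found.append((node.name, dict(node.kwargs.get("hints", {}))))
    for _, submod in gm.named_children():
        if isinstance(submod, torch.fx.GraphModule):
            found.extend(collect_hints(submod))
    return found

Run on the exported graph_module from the test above, this would surface both the outer_body hint and the nested inner_body hint.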
