From d38a980013bfa294c7c056932a5161000eeeb04b Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Tue, 18 Jul 2023 11:39:35 -0700 Subject: [PATCH] Add unlifting pass under private config (#4) Summary: X-link: https://github.com/pytorch/pytorch/pull/104897 Pull Request resolved: https://github.com/pytorch/executorch/pull/4 We wanna do this little by little. For now, I tried only on DissectedPartsModel which needs to use aot_export version. Reviewed By: JacobSzwejbka Differential Revision: D46785735 fbshipit-source-id: 9357b25615d97be2426bf74164b9995c57c3b187 --- backends/test/test_backends.py | 11 +- exir/__init__.py | 160 +++++++++++++++++- exir/dialects/edge/edge.yaml | 82 ++++++--- exir/dialects/edge/yaml_generator.py | 2 + ...test_quant_lowering_custom_backend_pass.py | 17 +- exir/tests/test_tracer.py | 134 +++++++++++++++ 6 files changed, 359 insertions(+), 47 deletions(-) diff --git a/backends/test/test_backends.py b/backends/test/test_backends.py index 0a03de50d13..f4b9c2c31e8 100644 --- a/backends/test/test_backends.py +++ b/backends/test/test_backends.py @@ -618,11 +618,7 @@ def forward(self, x_raw, h, c): ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) program_without_delegates = ( - exir.capture( - composite_m, - (input_x, input_h, input_c), - exir.CaptureConfig(pt2_mode=True), - ) + exir.capture(CompositeModel(3), inputs) .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) .to_executorch( config=exir.ExecutorchBackendConfig(extract_segments=extract_segments), @@ -726,7 +722,7 @@ def forward(self, x_raw, h, c): program_without_delegates = ( exir.capture( - composite_m, + CompositeModel(3), (input_x, input_h, input_c), exir.CaptureConfig(pt2_mode=True), ) @@ -962,7 +958,8 @@ def test_quantized_with_delegate(self) -> None: example_inputs, exir.CaptureConfig( pt2_mode=True, - enable_functionalization=False, + enable_aot=True, + _unlift=True, ), ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) FileCheck().check_count("quantize_per_tensor.default", 3).check("addmm").run( diff --git a/exir/__init__.py b/exir/__init__.py index 68079c91f6c..60fedc816f4 100644 --- a/exir/__init__.py +++ b/exir/__init__.py @@ -5,9 +5,11 @@ from collections import namedtuple from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from unittest.mock import patch import sympy import torch +import torch._export from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.emit import emit_program, EmitterOutput from executorch.exir.error import ExportError, ExportErrorType, InternalError @@ -25,6 +27,7 @@ from executorch.exir.schema import Program from executorch.exir.serialize import serialize_to_flatbuffer from executorch.exir.tracer import ( + _default_decomposition_table, dispatch_trace, dynamo_trace, ExirDynamoConfig, @@ -41,6 +44,7 @@ from torch._dynamo.eval_frame import Constraint from torch._export import CallSpec, export, ExportGraphSignature from torch._export.exported_program import ExportedProgram +from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( InputDim, RangeConstraint, @@ -49,12 +53,156 @@ from torch.fx._compatibility import compatibility from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.utils import _pytree as pytree Val 
= Any +def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict): + count = 0 + # Step 1: make lifted params as get_attr + for node in gm.graph.nodes: + if node.op == "placeholder": + if count in inp_pos_to_param_buffer_name: + with gm.graph.inserting_after(node): + getattr_node = gm.graph.get_attr( + inp_pos_to_param_buffer_name[count] + ) + node.replace_all_uses_with(getattr_node) + metadata = node.meta + gm.graph.erase_node(node) + getattr_node.meta = metadata + count += 1 + + # Step 2: Fix the input/output of the graph now that we deleted + # some args. + gm.graph.lint() + names = [f"arg_{i}" for i in range(len(in_spec.children_specs))] + gm.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + names, + in_spec, + out_spec, + ) + ) + gm.recompile() + + # Step 3: Find state references in HigherOrderOps and recursively + # fix them. + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.cond: + pred, true_graph, false_graph, operands = node.args + true_gm = getattr(gm, true_graph.name) + false_gm = getattr(gm, false_graph.name) + inp_pos_to_param_buffer_name_for_submod = {} + real_operands = [] + for ix, operand in enumerate(operands): + if operand.target in inp_pos_to_param_buffer_name.values(): + inp_pos_to_param_buffer_name_for_submod[ix] = operand.target + true_gm.register_buffer(operand.target, state_dict[operand.target]) + false_gm.register_buffer(operand.target, state_dict[operand.target]) + else: + real_operands.append(operand) + node.args = (pred, true_graph, false_graph, real_operands) + + _, in_spec = pytree.tree_flatten(real_operands) + + _unlift( + true_gm, + inp_pos_to_param_buffer_name_for_submod, + in_spec, + None, + state_dict, + ) + _unlift( + false_gm, + inp_pos_to_param_buffer_name_for_submod, + in_spec, + None, + state_dict, + ) + if node.op == "call_function" and node.target.__name__ == "map_impl": + body_graph, num_mapped, *operands = node.args + body_gm = getattr(gm, body_graph.name) + inp_pos_to_buffer_name_for_submod = {} + real_operands = [] + for ix, operand in enumerate(operands): + if operand.target in inp_pos_to_param_buffer_name.values(): + inp_pos_to_buffer_name_for_submod[ix] = operand.target + body_gm.register_buffer(operand.target, state_dict[operand.target]) + else: + real_operands.append(operand) + node.args = (body_graph, num_mapped, *real_operands) + + _, in_spec = pytree.tree_flatten(real_operands) + + _unlift( + body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict + ) + gm.graph.lint() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def unlift_exported_program_lifted_states( + ep: torch._export.exported_program.ExportedProgram, +): + new_gm = copy.deepcopy(ep.graph_module) + + # TODO Fix the period in params/buffers names later + # maybe a pass to replace graph signature with fixed names + param_buffer_name_to_corrected_name = {} + + for name, stuff in ep.state_dict.items(): + if name in ep.graph_signature.buffers: + if "." in name: + new_gm.register_buffer(name.replace(".", "_"), stuff) + param_buffer_name_to_corrected_name[name] = name.replace(".", "_") + else: + new_gm.register_buffer(name, stuff) + elif name in ep.graph_signature.parameters: + if "." 
in name: + new_gm.register_parameter(name.replace(".", "_"), stuff) + param_buffer_name_to_corrected_name[name] = name.replace(".", "_") + else: + new_gm.register_parameter(name, stuff) + else: + raise AssertionError("encountered not registered param/buffer") + + count = 0 + inp_pos_to_param_buffer_name = {} + for node in new_gm.graph.nodes: + if node.op == "placeholder": + if node.name in ep.graph_signature.inputs_to_buffers: + buffer_name = ep.graph_signature.inputs_to_buffers[node.name] + if buffer_name in param_buffer_name_to_corrected_name: + inp_pos_to_param_buffer_name[ + count + ] = param_buffer_name_to_corrected_name[buffer_name] + else: + inp_pos_to_param_buffer_name[count] = buffer_name + if node.name in ep.graph_signature.inputs_to_parameters: + param_name = ep.graph_signature.inputs_to_parameters[node.name] + if param_name in param_buffer_name_to_corrected_name: + inp_pos_to_param_buffer_name[ + count + ] = param_buffer_name_to_corrected_name[param_name] + else: + inp_pos_to_param_buffer_name[count] = param_name + count += 1 + new_gm = _unlift( + new_gm, + inp_pos_to_param_buffer_name, + ep.call_spec.in_spec, + ep.call_spec.out_spec, + ep.state_dict, + ) + return new_gm + + @compatibility(is_backward_compatible=False) @dataclass class CaptureConfig: @@ -63,6 +211,7 @@ class CaptureConfig: enable_dynamic_shape: bool = False enable_aot: bool = False _dynamo_config: "ExirDynamoConfig" = ExirDynamoConfig() + _unlift: bool = False @compatibility(is_backward_compatible=False) @@ -400,8 +549,15 @@ def capture( "Functionalization is required for enable_aot.", ) - ep = export(f, args, _add_runtime_assertions=False, constraints=constraints) - return ep # pyre-ignore + # TODO remove this later + with patch("torch._export.DECOMP_TABLE", _default_decomposition_table()): + ep = export( + f, args, _add_runtime_assertions=False, constraints=constraints + ) + ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass()) + if not config._unlift: + return ep # pyre-ignore + graph_module = unlift_exported_program_lifted_states(ep) elif config.enable_dynamic_shape: if not config._dynamo_config.dynamic_shapes: diff --git a/exir/dialects/edge/edge.yaml b/exir/dialects/edge/edge.yaml index 3c8ee2f1a47..0653944eab6 100644 --- a/exir/dialects/edge/edge.yaml +++ b/exir/dialects/edge/edge.yaml @@ -89,6 +89,14 @@ mat2: T0 __ret_0: T0 +- func: aten::arange.start_step + namespace: edge + inherits: aten::arange.start_step + type_alias: + T0: [Byte, Char, Double, Float, Int, Long, Short] + type_constraint: + - __ret_0: T0 + - func: aten::bmm namespace: edge inherits: aten::bmm @@ -198,14 +206,43 @@ - self: T0 __ret_0: T0 -- func: aten::lift_fresh_copy +- func: aten::index_select namespace: edge - inherits: aten::lift_fresh_copy + inherits: aten::index_select type_alias: - T0: [Bool, Byte, Char, Double, Float, Int, Long, Short] + T0: [Bool] + T1: [Byte] + T2: [Char] + T3: [Double] + T4: [Float] + T5: [Int] + T6: [Long] + T7: [Short] type_constraint: - self: T0 + index: T6 __ret_0: T0 + - self: T1 + index: T6 + __ret_0: T1 + - self: T2 + index: T6 + __ret_0: T2 + - self: T3 + index: T6 + __ret_0: T3 + - self: T4 + index: T6 + __ret_0: T4 + - self: T5 + index: T6 + __ret_0: T5 + - self: T6 + index: T6 + __ret_0: T6 + - self: T7 + index: T6 + __ret_0: T7 - func: aten::masked_fill.Scalar namespace: edge @@ -245,16 +282,6 @@ mask: T0 __ret_0: T7 -- func: aten::minimum - namespace: edge - inherits: aten::minimum - type_alias: - T0: [Bool, Byte, Char, Double, Float, Int, Long, Short] - type_constraint: - - self: T0 - other: 
T0 - __ret_0: T0 - - func: aten::mm namespace: edge inherits: aten::mm @@ -324,15 +351,6 @@ - self: T0 __ret_0: T0 -- func: aten::select_copy.int - namespace: edge - inherits: aten::select_copy.int - type_alias: - T0: [Bool, Byte, Char, Double, Float, Int, Long, Short] - type_constraint: - - self: T0 - __ret_0: T0 - - func: aten::sigmoid namespace: edge inherits: aten::sigmoid @@ -383,9 +401,25 @@ other: T0 __ret_0: T0 -- func: aten::t +- func: aten::sym_numel + namespace: edge + inherits: aten::sym_numel + type_alias: + T0: [Bool, Byte, Char, Double, Float, Int, Long, Short] + type_constraint: + - self: T0 + +- func: aten::sym_size.int + namespace: edge + inherits: aten::sym_size.int + type_alias: + T0: [Bool, Byte, Char, Double, Float, Int, Long, Short] + type_constraint: + - self: T0 + +- func: aten::t_copy namespace: edge - inherits: aten::t + inherits: aten::t_copy type_alias: T0: [Bool, Byte, Char, Double, Float, Int, Long, Short] type_constraint: diff --git a/exir/dialects/edge/yaml_generator.py b/exir/dialects/edge/yaml_generator.py index 41f85ac9b76..473dd5bf29b 100644 --- a/exir/dialects/edge/yaml_generator.py +++ b/exir/dialects/edge/yaml_generator.py @@ -143,6 +143,8 @@ def get_test_gen_key(op_name: str) -> str: opdb_key = opdb_key[:-5] elif opdb_key == "sym_size": opdb_key = "resize_" + elif opdb_key == "sym_numel": + opdb_key = "abs" elif opdb_key == "convolution": opdb_key = "conv_transpose2d" elif opdb_key == "embedding": diff --git a/exir/tests/test_quant_lowering_custom_backend_pass.py b/exir/tests/test_quant_lowering_custom_backend_pass.py index 4dd2ee0a63e..a8398f06794 100644 --- a/exir/tests/test_quant_lowering_custom_backend_pass.py +++ b/exir/tests/test_quant_lowering_custom_backend_pass.py @@ -642,11 +642,7 @@ def test_quantized_linear_dynamic(self) -> None: ) # Step 2: EXIR capturing - dynamo_config = ExirDynamoConfig( - dynamic_shapes=False, - ) - - capture_config = CaptureConfig(pt2_mode=True, _dynamo_config=dynamo_config) + capture_config = CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True) captured_mod = ( exir.capture(converted_mod, example_inputs, config=capture_config) .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) @@ -777,11 +773,7 @@ def test_quantized_linear_dynamic_symmetric_act_per_channel_weight(self) -> None print("converted:", converted_mod) # Step 2: EXIR capturing - dynamo_config = ExirDynamoConfig( - dynamic_shapes=False, - ) - - capture_config = CaptureConfig(pt2_mode=True, _dynamo_config=dynamo_config) + capture_config = CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True) captured_mod = exir.capture( converted_mod, example_inputs, config=capture_config ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) @@ -839,10 +831,7 @@ def test_quantized_linear_dynamic_symmetric_act_per_tensor_weight(self) -> None: print("converted:", converted_mod) # Step 2: EXIR capturing - dynamo_config = ExirDynamoConfig( - dynamic_shapes=False, - ) - capture_config = CaptureConfig(pt2_mode=True, _dynamo_config=dynamo_config) + capture_config = CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True) captured_mod = exir.capture( converted_mod, example_inputs, config=capture_config ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) diff --git a/exir/tests/test_tracer.py b/exir/tests/test_tracer.py index eca9af76ed7..58e07b5972c 100644 --- a/exir/tests/test_tracer.py +++ b/exir/tests/test_tracer.py @@ -12,10 +12,12 @@ import torch import torch.utils._pytree as pytree +from executorch.exir import CaptureConfig from 
executorch.exir.error import ExportError from executorch.exir.passes import DebugPass from executorch.exir.tests.common import register_additional_test_aten_ops from executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo +from functorch.experimental.control_flow import cond, map from parameterized import parameterized from torch._export.verifier import SpecViolationError @@ -408,3 +410,135 @@ def forward(self, x): ) self.assertEqual(len(placeholder_nodes), 2) + + def test_export_unlift(self) -> None: + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.ones(6, 4)) + + def forward(self, x): + return x.cos() + self.buffer.sin() + + ep = exir.capture( + Foo(), + (torch.ones(6, 4),), + exir.CaptureConfig(enable_aot=True, pt2_mode=True, _unlift=True), + ) + + self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))) + + def test_export_container_unlift(self) -> None: + class FooContainerInputOutput(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.ones(6, 4)) + + def forward(self, x): + return x[0][0].cos() + x[0][1].sin() + self.buffer.sin() + + inp = ((torch.ones(6, 4), torch.ones(6, 4)),) + ep = exir.capture( + FooContainerInputOutput(), + (inp,), + CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True), + ) + self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp))) + + def test_export_container_input_unlift(self) -> None: + class FooContainerInputOutputV2(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.ones(6, 4)) + + def forward(self, x, y): + return x[0].cos() + y[0].sin() + self.buffer.sin() + + inp = ((torch.ones(6, 4),), (torch.ones(6, 4),)) + ep = exir.capture( + FooContainerInputOutputV2(), + inp, + CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True), + ) + self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp))) + + def test_export_cond(self) -> None: + class A(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.ones(6, 4)) + + def forward(self): + return self.buffer.cos() + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = A() + + def forward(self, x): + def true_fn(x): + return x.cos() + self.a().sum() + + def false_fn(x): + return x.sin() + + return cond(x.shape[0] > 4, true_fn, false_fn, [x]) + + inp = torch.ones(6, 4) + ep = exir.capture( + Foo(), + (inp,), + CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True), + ) + self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))) + + def test_export_cond_map(self) -> None: + class A(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.ones(6, 4)) + + def forward(self): + return self.buffer.sum() + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = A() + + def inner(self, x, pred): + def true_fn(x): + return x + x + self.a() + + def false_fn(x): + return x * x - self.a() + + return cond(pred, true_fn, false_fn, [x]) + + def forward(self, pred, xs): + def body(x, pred): + return self.inner(x, pred) + self.a() + + return map(body, xs, pred) + + inp = torch.randn(3, 2, 1) + ep = exir.capture( + Module(), + (torch.tensor(True), inp), + CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True), + ) + + inp_test = torch.randn(3, 2, 1) + self.assertTrue( + torch.allclose( + ep(torch.tensor(True), 
inp_test), + Module()(torch.tensor(True), inp_test), + ) + ) + self.assertTrue( + torch.allclose( + ep(torch.tensor(False), inp_test), + Module()(torch.tensor(False), inp_test), + ) + )
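
A minimal usage sketch of the capture path this patch enables, distilled from the test_export_unlift test added above (the Foo module, shapes, and config values are taken directly from that test; the only new surface area is the private _unlift flag on CaptureConfig):

    import torch
    import executorch.exir as exir

    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("buffer", torch.ones(6, 4))

        def forward(self, x):
            return x.cos() + self.buffer.sin()

    # enable_aot routes capture through torch._export.export(), which lifts
    # parameters/buffers into extra placeholder inputs; the private _unlift
    # flag then runs unlift_exported_program_lifted_states() to turn those
    # placeholders back into get_attr nodes, so the captured module is
    # callable with the original (unlifted) inputs.
    ep = exir.capture(
        Foo(),
        (torch.ones(6, 4),),
        exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True),
    )

    assert torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))

Without _unlift=True, capture in this mode returns the lifted ExportedProgram unchanged, whose graph takes the buffer as a leading placeholder argument; the unlifting pass (and its recursive handling of cond and map_impl submodules in _unlift) exists so callers that expect the eager calling convention, such as the updated quantization and delegate tests in this diff, keep working.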