Skip to content

Commit c2e4ac0

Browse files
committed
[ez][release blocker fix] Insert linalg_vector_norm into decomp table used for Edge export
Summary: ## Context Addresses this [release blocker](https://github.com/orgs/pytorch/projects/99/views/1?pane=issue&itemId=104088363&issue=pytorch%7Cpytorch%7C150207) issue. Some models cannot export because they use `linalg_vector_norm` which is not currently an ATen operator. I initially tried adding the op to the core decomp table, but the decomp is not passing pytorch correctness tests. Please see pytorch/pytorch#150241 for more details. ## Changes Since we currently cannot include the op in PyTorch's decomp table, instead we can insert the op into the edge decomp table directly. This PR is a simple change to add `linalg_vector_norm` directly to the edge decomp table. Test Plan: Tested exporting and running a model with the `linalg_vector_norm` op via the following script. ``` import torch from executorch.exir import to_edge_transform_and_lower, EdgeCompileConfig from torch.export import Dim, export from executorch.extension.pybindings.portable_lib import ( # @Manual _load_for_executorch_from_buffer, ) class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.linalg.vector_norm(x, 2) model = Model() inputs = (torch.randn(1,1,16,16),) dynamic_shapes = { "x": { 2: Dim("h", min=16, max=1024), 3: Dim("w", min=16, max=1024), } } exported_program = export(model, inputs, dynamic_shapes=dynamic_shapes) executorch_program = to_edge_transform_and_lower( exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False), ).to_executorch() executorch_module = _load_for_executorch_from_buffer( executorch_program.buffer ) model_output = executorch_module.run_method( "forward", tuple(inputs) ) print(model_output) ```
1 parent 1f5ca0c commit c2e4ac0

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

exir/tracer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,16 @@ def _default_decomposition_table(
631631
]
632632
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
633633
return get_decompositions(decomp_opset)
634+
635+
decomps = default_decompositions()
636+
# Add edge specific decompositions
637+
additional_decomp_ops = [
638+
torch.ops.aten.linalg_vector_norm.default,
639+
]
640+
additional_decomps = get_decompositions(additional_decomp_ops)
641+
decomps.update(additional_decomps)
634642
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
635-
return default_decompositions()
643+
return decomps
636644

637645

638646
def dynamo_trace(

0 commit comments

Comments
 (0)