Skip to content

Commit d33986c

Browse files
authored
Update module wrapper so that params are explicitly registered to the wrapper (#10357)
Pull Request resolved: #10305 Seeing issue with linear where the fqns for constants disappear. Registering self.method_name as a submodule of wrapper means that the parameters are registered to the wrapper. thanks @angelayi for the fix! ``` File "/data/users/lfq/fbsource/buck-out/v2/gen/fbcode/1af94fa701700343/executorch/test/models/__export_delegated_program__/export_delegated_program#link-tree/torch/export/_trace.py", line 1980, in _export_for_training export_artifact = export_func( File "/data/users/lfq/fbsource/buck-out/v2/gen/fbcode/1af94fa701700343/executorch/test/models/__export_delegated_program__/export_delegated_program#link-tree/torch/export/_trace.py", line 1473, in _strict_export _replace_param_buffer_names(param_buffer_table, export_graph_signature) File "/data/users/lfq/fbsource/buck-out/v2/gen/fbcode/1af94fa701700343/executorch/test/models/__export_delegated_program__/export_delegated_program#link-tree/torch/export/_trace.py", line 272, in _replace_param_buffer_names spec.target = param_buffer_table[spec.target] KeyError: 'L__self___fn___self___linear.weight' ``` ghstack-source-id: 279346028 @exported-using-ghexport Differential Revision: [D73279618](https://our.internmc.facebook.com/intern/diff/D73279618/)
1 parent c0593ff commit d33986c

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

test/models/export_delegated_program.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ def get_random_inputs(self) -> Sequence[torch.Tensor]:
9090
return (torch.ones(n, n, n), 2 * torch.ones(n, n, n), 3 * torch.ones(n, n, n))
9191

9292

93+
class ModuleLinear(torch.nn.Module):
94+
def __init__(self):
95+
super().__init__()
96+
self.linear = torch.nn.Linear(3, 3)
97+
98+
def forward(self, x: torch.Tensor):
99+
return self.linear(x)
100+
101+
def get_random_inputs(self):
102+
return (torch.randn(3),)
103+
104+
93105
#
94106
# Backends
95107
#
@@ -116,24 +128,23 @@ def export_module_to_program(
116128
extract_delegate_segments: bool,
117129
constant_tensor_alignment: Optional[int] = None,
118130
delegate_alignment: Optional[int] = None,
119-
method: str = "forward",
131+
method_name: str = "forward",
120132
) -> ExecutorchProgramManager:
121133
eager_module = module_class().eval()
122134
inputs = ()
123135
if hasattr(eager_module, "get_random_inputs"):
124136
inputs = eager_module.get_random_inputs() # type: ignore[operator]
125137

126138
class WrapperModule(torch.nn.Module):
127-
def __init__(self, fn):
139+
def __init__(self, fn, method_name=method_name):
128140
super().__init__()
129141
self.fn = fn
142+
self.method_name = method_name
130143

131144
def forward(self, *args, **kwargs):
132-
return self.fn(*args, **kwargs)
145+
return getattr(self.fn, self.method_name)(*args, **kwargs)
133146

134-
exported_program = export(
135-
WrapperModule(getattr(eager_module, method)), args=inputs, strict=True
136-
)
147+
exported_program = export(WrapperModule(eager_module), args=inputs, strict=True)
137148

138149
edge_config = EdgeCompileConfig(_check_ir_validity=False)
139150
et_config = exir.ExecutorchBackendConfig(

test/models/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def define_common_targets():
156156
"ModuleAddMul",
157157
"ModuleAddLarge",
158158
"ModuleSubLarge",
159+
"ModuleLinear",
159160
]
160161

161162
# Name of the backend to use when exporting delegated programs.

0 commit comments

Comments (0)