|
115 | 115 | Source, |
116 | 116 | SubclassAttrListSource, |
117 | 117 | TupleIteratorGetItemSource, |
| 118 | + UnspecializedBuiltinNNModuleSource, |
| 119 | + UnspecializedNNModuleSource, |
118 | 120 | ) |
119 | 121 | from ..utils import ( |
120 | 122 | _extract_tensor_dict, |
@@ -434,7 +436,10 @@ def __call__(self, value): |
434 | 436 | return cached_vt |
435 | 437 |
|
436 | 438 | vt = self._wrap(value) |
437 | | - vt.source = self.source |
| 439 | + |
| 440 | + if vt.source is None: |
| 441 | + vt.source = self.source |
| 442 | + |
438 | 443 | if ( |
439 | 444 | self._can_lift_attrs_to_inputs(vt) |
440 | 445 | and value not in self.tx.output.side_effects |
@@ -1714,7 +1719,6 @@ def wrap_module(self, value: torch.nn.Module): |
1714 | 1719 | value = value.get_base() |
1715 | 1720 | self.source = AttrProxySource(self.source) |
1716 | 1721 |
|
1717 | | - self.install_guards(GuardBuilder.TYPE_MATCH) |
1718 | 1722 | if torch._dynamo.config.inline_inbuilt_nn_modules: |
1719 | 1723 | freezing = is_parameter_freezing() |
1720 | 1724 |
|
@@ -1749,12 +1753,23 @@ def wrap_module(self, value: torch.nn.Module): |
1749 | 1753 | # this will get cleaned up once compile ends |
1750 | 1754 | self.tx.output.nn_modules[self.name] = value |
1751 | 1755 |
|
1752 | | - if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr( |
1753 | | - value.__class__, "_dynamo_marked_static", False |
1754 | | - ): |
1755 | | - result = UnspecializedBuiltinNNModuleVariable(value, source=self.source) |
| 1756 | + if ( |
| 1757 | + value.__module__.startswith(("torch.nn.modules", "torch.ao.")) |
| 1758 | + and not value.__module__.startswith("torch.nn.modules.container") |
| 1759 | + ) or getattr(value.__class__, "_dynamo_marked_static", False): |
| 1760 | + new_source = self.source |
| 1761 | + if config.inline_inbuilt_nn_modules: |
| 1762 | + # Export corner case - look at test_repros.py test_inlining_cornercase |
| 1763 | + new_source = UnspecializedBuiltinNNModuleSource(self.source) |
| 1764 | + result = UnspecializedBuiltinNNModuleVariable(value, source=new_source) |
| 1765 | + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) |
1756 | 1766 | else: |
1757 | | - result = UnspecializedNNModuleVariable(value, source=self.source) |
| 1767 | + new_source = self.source |
| 1768 | + if config.inline_inbuilt_nn_modules: |
| 1769 | + # Export corner case - look at test_repros.py test_inlining_cornercase |
| 1770 | + new_source = UnspecializedNNModuleSource(self.source) |
| 1771 | + result = UnspecializedNNModuleVariable(value, source=new_source) |
| 1772 | + install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) |
1758 | 1773 |
|
1759 | 1774 | if not SideEffects.cls_supports_mutation_side_effects(type(value)): |
1760 | 1775 | # don't allow STORE_ATTR mutation with custom __setattr__ |
@@ -2127,6 +2142,10 @@ def wrap_numpy_ndarray(self, value): |
2127 | 2142 | ) |
2128 | 2143 | proxy.node.meta["grapharg"] = grapharg |
2129 | 2144 |
|
| 2145 | + # TODO - Why do we need to set the source of the np ndarray vt back to |
| 2146 | + # original source. Many tests fails. |
| 2147 | + numpy_ndarray_variable.source = self.source |
| 2148 | + |
2130 | 2149 | return numpy_ndarray_variable |
2131 | 2150 |
|
2132 | 2151 | def wrap_symint( |
|
0 commit comments