Skip to content

Commit ce2d45d

Browse files
desertfiredmenig
authored andcommitted
[AOTI] Handle empty input args (pytorch#114682)
Summary: When the model takes no inputs, AOTInductor relies on checking weights to figure out which device to compile the model into. Currently recording buffer device type happens too late, and this PR fixes that. Pull Request resolved: pytorch#114682 Approved by: https://github.com/chenyang78
1 parent 516691c commit ce2d45d

File tree

4 files changed

+26
-16
lines changed

4 files changed

+26
-16
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,20 @@ def forward(self, inp):
14651465
inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
14661466
self.check_model(Model(4), inputs)
14671467

1468+
def test_no_args(self):
1469+
class Model(torch.nn.Module):
1470+
def __init__(self, m, n):
1471+
super().__init__()
1472+
self.weight = torch.nn.Parameter(
1473+
torch.randn(m, n),
1474+
)
1475+
self.alpha = torch.nn.Parameter(torch.randn(m, n))
1476+
1477+
def forward(self):
1478+
return self.weight * self.alpha
1479+
1480+
self.check_model(Model(6, 4), ())
1481+
14681482

14691483
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
14701484

torch/_inductor/graph.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,10 @@ def warn_fallback(self, name):
472472
self._warned_fallback.add(name)
473473
perf_hint_log.info("Using FallbackKernel: %s", name)
474474

475-
def add_device_idx(self, idx: Optional[int]):
476-
if idx is not None:
477-
self.device_idxs.add(idx)
475+
def add_device_info(self, device: torch.device):
476+
self.device_types.add(device.type)
477+
if device.index is not None:
478+
self.device_idxs.add(device.index)
478479

479480
@property
480481
def fake_mode(self):
@@ -521,6 +522,9 @@ def register_buffer(self, buffer: ir.Buffer):
521522
name = f"buf{len(self.buffers)}"
522523
self.buffers.append(buffer)
523524
self.name_to_buffer[name] = buffer
525+
# Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
526+
if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
527+
self.add_device_info(buffer.get_device())
524528
return name
525529

526530
def register_list(self, buffer_names: List[str]):
@@ -645,8 +649,7 @@ def placeholder(self, target: str, args, kwargs):
645649
)
646650
self.graph_inputs[target] = tensor
647651
self.graph_inputs_original[target] = tensor.data.data
648-
self.device_types.add(example.device.type)
649-
self.add_device_idx(example.device.index)
652+
self.add_device_info(example.device)
650653
return tensor
651654

652655
def call_function(self, target, args, kwargs):
@@ -979,10 +982,6 @@ def init_wrapper_code(self):
979982
return
980983

981984
device_types = self.device_types.copy()
982-
# In terms of some operations that don't have input tensors, we need to
983-
# check the device of the buffers.
984-
for buffer in self.buffers:
985-
device_types.add(buffer.get_device().type)
986985
device_types.discard("cpu")
987986
# TODO(Eikan): Only support mixing cpu and other device now.
988987
assert len(device_types) <= 1, "Does not support mixing {}".format(
@@ -1015,7 +1014,7 @@ def materialize(x):
10151014
else:
10161015
assert isinstance(
10171016
x, torch.Tensor
1018-
), "Unknown type when creating real inputs"
1017+
), "Unknown type when creating real inputs" + str(type(x))
10191018
return x
10201019

10211020
with torch.utils._python_dispatch._disable_current_modes():

torch/_inductor/ir.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4160,10 +4160,8 @@ def create(cls, x, device):
41604160
):
41614161
return x.constant_to_device(device)
41624162

4163-
V.graph.device_types.add(device.type)
4164-
V.graph.add_device_idx(device.index)
4165-
V.graph.device_types.add(x.get_device().type)
4166-
V.graph.add_device_idx(x.get_device().index)
4163+
V.graph.add_device_info(device)
4164+
V.graph.add_device_info(x.get_device())
41674165

41684166
developer_warning("DeviceCopy in input program")
41694167
return DeviceCopy(

torch/_inductor/scheduler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,8 +2130,7 @@ def create_backend(self, device: torch.device):
21302130
assert (
21312131
device.type != "cuda" or device.index is not None
21322132
), f"{device} should have been normalized in lowering"
2133-
V.graph.device_types.add(device.type)
2134-
V.graph.add_device_idx(device.index)
2133+
V.graph.add_device_info(device)
21352134

21362135
device_scheduling = get_scheduling_for_device(device.type)
21372136
if device_scheduling is None:

0 commit comments

Comments
 (0)