diff --git a/torchtitan/experiments/compiler_toolkit/cudagraph.py b/torchtitan/experiments/compiler_toolkit/cudagraph.py index cd6e4cfc22..d008e5d455 100644 --- a/torchtitan/experiments/compiler_toolkit/cudagraph.py +++ b/torchtitan/experiments/compiler_toolkit/cudagraph.py @@ -98,7 +98,7 @@ def check_input_types(self, inputs) -> None: def check_static_inputs_address(self) -> None: for i in self.static_input_indices: - actual = args[i].data_ptr() + actual = self.args[i].data_ptr() expected = self.input_addresses[i] assert expected == actual, ( "Expected the same static tensor address but found "