Commit ef6dcf8
Move torch.cond predicate non-persistent buffer to CPU
Avoid device-to-host memory copies when evaluating `torch.cond` predicates. When a GPU buffer (e.g., a KV cache `initialized` flag) is used as the predicate for `torch.cond`, the runtime must synchronize and copy the predicate value from GPU to CPU on every forward pass to evaluate the condition, adding latency and synchronization overhead. `MoveCondPredicateToCpuPass` moves non-persistent buffer predicates to CPU at export time, eliminating the per-inference D2H transfer. The predicate is typically a small scalar (e.g., a boolean flag), so keeping it on CPU has negligible memory impact.

- Add `MoveCondPredicateToCpuPass` in `backends/cuda/passes/`
- Add unit tests covering:
  - GPU buffer predicates moved to CPU
  - CPU buffer predicates unchanged
  - Computed predicates unaffected
  - Multiple `torch.cond` calls
  - Cross-attention cache pattern
  - Persistent buffers (state_dict) not moved
- Add Python tests to the `unittest-cuda` CI job in `cuda.yml`

ghstack-source-id: ff22758
ghstack-comment-id: 3687889864
Pull-Request: #16378
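For illustration, a minimal sketch (not part of this commit) of the per-call synchronization that a GPU-resident predicate incurs; the tensor names are hypothetical and a CUDA device is assumed:

```
import torch

# Hypothetical illustration: reading a boolean out of a CUDA tensor forces a
# stream synchronization plus a device-to-host copy, which is what happens
# each time torch.cond evaluates a GPU-resident predicate.
gpu_flag = torch.tensor([True], device="cuda")
cpu_flag = torch.tensor([True])

bool(gpu_flag)  # synchronizes and copies the scalar to the host on every call
bool(cpu_flag)  # plain host memory read, no synchronization
```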

File tree: 5 files changed, +655 −6 lines

.github/workflows/cuda.yml

Lines changed: 9 additions & 6 deletions
@@ -87,8 +87,8 @@ jobs:
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
         PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda

-  test-cuda-shims:
-    name: test-cuda-shims
+  unittest-cuda:
+    name: unittest-cuda
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
       id-token: write
@@ -97,23 +97,26 @@
       timeout: 90
       runner: linux.g5.4xlarge.nvidia.gpu
       gpu-arch-type: cuda
-      gpu-arch-version: 12.6
+      gpu-arch-version: 12.9
       use-custom-docker-registry: false
       submodules: recursive
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |
         set -eux
-        # Install requirements
-        bash ./install_requirements.sh
+        # Install executorch
+        bash ./install_executorch.sh

         # Build ExecuTorch with CUDA support
         cmake --workflow --preset llm-release-cuda

-        # Build and run CUDA shim tests
+        # Build and run CUDA shim tests (C++)
         pushd backends/cuda/runtime/shims/tests
         cmake --workflow --preset default
         popd

+        # Run CUDA backend Python tests
+        python -m pytest backends/cuda/tests backends/cuda/passes/tests extension/llm/custom_ops -v
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)

backends/cuda/passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+
Lines changed: 88 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.export import ExportedProgram


class MoveCondPredicateToCpuPass:
    """
    A pass that moves the predicate of torch.cond to CPU if the predicate is a
    constant buffer. This is useful for models that use a constant buffer as
    the predicate, such as an `initialized` flag for a cross-attention KV
    cache.

    Example:
    ```
    class CrossAttentionWithCache(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
            self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
            self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
            self.out_proj = torch.nn.Linear(hidden_size, hidden_size)
            # Buffer used as predicate for torch.cond
            self.register_buffer("initialized", torch.tensor([False]))
            self.register_buffer("k_cache", torch.zeros(1, 10, hidden_size))
            self.register_buffer("v_cache", torch.zeros(1, 10, hidden_size))

        def compute_kv(self, encoder_hidden_states):
            k = self.k_proj(encoder_hidden_states)
            v = self.v_proj(encoder_hidden_states)
            self.k_cache.copy_(k)
            self.v_cache.copy_(v)
            self.initialized.fill_(True)
            return k, v

        def use_cached_kv(self, encoder_hidden_states):
            return self.k_cache.clone(), self.v_cache.clone()

        def forward(self, hidden_states, encoder_hidden_states):
            q = self.q_proj(hidden_states)
            # Use torch.cond with the initialized buffer as the predicate
            k, v = torch.cond(
                self.initialized,
                self.use_cached_kv,
                self.compute_kv,
                (encoder_hidden_states,),
            )
            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            return self.out_proj(attn_output)
    ```

    In this example, if we keep `self.initialized` on GPU, it must be copied
    to CPU on every forward pass. Moving the predicate to CPU avoids these
    device-to-host copies. This pass only applies to models that use
    torch.cond with a constant-buffer predicate.
    """

    requires_exported_program = True

    def __call__(self, exported_program: ExportedProgram):
        graph_module = exported_program.graph_module

        # Map input names to buffer names
        inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers

        for node in graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.higher_order.cond
            ):
                pred_node = node.args[0]
                if pred_node.op == "placeholder" and pred_node.name in inputs_to_buffers:
                    buffer_name = inputs_to_buffers[pred_node.name]

                    # Non-persistent buffers live in the exported program's
                    # constants table, not in the state_dict.
                    if buffer_name in exported_program.constants:
                        tensor = exported_program._constants[buffer_name]
                        if tensor.device.type != "cpu":
                            exported_program._constants[buffer_name] = tensor.to(
                                "cpu"
                            )

                            # Also update the placeholder metadata
                            if "val" in pred_node.meta:
                                fake_tensor = pred_node.meta["val"]
                                if isinstance(fake_tensor, torch.Tensor):
                                    pred_node.meta["val"] = fake_tensor.to("cpu")
        exported_program.validate()
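For context, a sketch of how the pass might be applied to an `ExportedProgram`. The `Gate` module below is hypothetical, a CUDA device is assumed, and it is assumed that non-persistent buffers are keyed by fully qualified name in `exported_program.constants`:

```
import torch
from torch.export import export

# Hypothetical toy module (not from this commit) with a non-persistent
# buffer used as a torch.cond predicate.
class Gate(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("flag", torch.tensor([False]), persistent=False)

    def forward(self, x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x - 1

        return torch.cond(self.flag, true_fn, false_fn, (x,))

m = Gate().cuda()
ep = export(m, (torch.randn(4, device="cuda"),))

# Non-persistent buffers land in the exported program's constants table
# (assumed keyed by fully qualified name), so the pass can relocate the
# predicate without touching the state_dict.
MoveCondPredicateToCpuPass()(ep)
assert ep.constants["flag"].device.type == "cpu"
```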

backends/cuda/passes/tests/__init__.py

Whitespace-only changes.
