Background: I tried to integrate EAGLE-2 into ExecuTorch, but encountered some errors.
Related code:
if i not in noleaf_index:  # An error occurred at this branch
    cid = i
    depth = position_ids_list[i]
    for j in reversed(range(depth + 1)):
        retrieve_indices[rid][j] = cid
        cid = mask_index_list[cid - 1]
    rid += 1
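Since this loop only manipulates Python lists, one workaround I have considered is to build retrieve_indices in eager Python outside the exported forward and pass the result in as a tensor input. A rough sketch of my own (num_paths and max_depth are placeholder parameters I introduced, not EAGLE-2 names):

import torch

def build_retrieve_indices(noleaf_index, position_ids_list, mask_index_list, num_paths, max_depth):
    # Runs in eager Python before export, so the data-dependent
    # `if i not in noleaf_index` branch never reaches the tracer.
    retrieve_indices = [[-1] * (max_depth + 1) for _ in range(num_paths)]
    rid = 0
    for i in range(len(position_ids_list)):
        if i not in noleaf_index:
            cid = i
            depth = position_ids_list[i]
            for j in reversed(range(depth + 1)):
                retrieve_indices[rid][j] = cid
                cid = mask_index_list[cid - 1]
            rid += 1
    return torch.tensor(retrieve_indices)

That said, other data-dependent branches remain in forward, so I still need an export-time solution.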
Specifically, I got an error after adding a branch to the model's forward:
- code:
if A_Tensor:  # a Boolean value or bool tensor computed from runtime data inside forward
...
- error:
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use torch.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
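My reading of the linked page (its CondOperands example) is that cond expects a scalar boolean tensor as the predicate and two branches that take and return tensors of matching structure, roughly like this (the class name is mine):

import torch
from functorch.experimental.control_flow import cond

class CondOperandsExample(torch.nn.Module):
    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        # predicate is a scalar bool tensor; both branches return one tensor
        return cond(x.sum() > 0, true_fn, false_fn, [x, y])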
Based on that, I added a CondBranchNestedFunction module, but encountered a new error:
- code:
from functorch.experimental.control_flow import cond

@torch.compile(dynamic=True, fullgraph=True)  # both dynamic=True and dynamic=False result in an error
class CondBranchNestedFunction(torch.nn.Module):
    @torch.compile(dynamic=True, fullgraph=True)
    def forward(self, tmpA, i):
        def true_fn(i):
            i += 1
            return None

        def false_fn(i):
            return None

        return cond(tmpA, true_fn, false_fn, [i])
self.condition_func = CondBranchNestedFunction()
self.condition_func = torch.compile(self.condition_func, mode='max-autotune')

tmpA = True  # or a bool tensor; in forward, computed from runtime data
tmpB = torch.tensor(i)
self.condition_func(tmpA, tmpB)
- error:
"torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile."
I have tried applying torch.compile in several other locations, but the error did not change. Which direction should I take to solve this problem?
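For reference, my current guess is that the branches themselves are part of the problem: as I understand it, cond forbids in-place mutation of its operands and requires both branches to return tensors of matching structure rather than None, and the nested torch.compile decorators may conflict with the export flow. A variant I intend to try (an unverified sketch):

import torch
from functorch.experimental.control_flow import cond

class CondBranchNestedFunction(torch.nn.Module):  # no torch.compile decorators
    def forward(self, tmpA, i):
        def true_fn(i):
            return i + 1      # return a new tensor instead of mutating and returning None

        def false_fn(i):
            return i.clone()  # same structure and dtype as true_fn's output

        return cond(tmpA, true_fn, false_fn, [i])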
Specifications are as follows:
My torch versions:
torch 2.6.0.dev20241224+cpu
torchaudio 2.6.0.dev20241224+cpu
torchvision 0.22.0.dev20241224+cpu
ExecuTorch commit: 86cb5d7
The changes to ExecuTorch's llama_transformer.py (link: llama_transformer.py) are as follows:
Add the following code at line 25 (link: llama_transformer.py):
from functorch.experimental.control_flow import cond

@torch.compile(dynamic=True, fullgraph=True)  # setting dynamic to True or False both result in the error
class CondBranchNestedFunction(torch.nn.Module):
    @torch.compile(dynamic=True, fullgraph=True)
    def forward(self, tmpA, i):
        def true_fn(i):
            i += 1
            return None

        def false_fn(i):
            return None

        return cond(tmpA, true_fn, false_fn, [i])
Add the following code at line 425 (link: llama_transformer.py):
self.condition_func = CondBranchNestedFunction()
self.condition_func = torch.compile(self.condition_func, mode='max-autotune')
Add the following code at line 544 (link: llama_transformer.py#L544):

for i in range(10):
    tmpA = (i == logits).int().sum() == 0
    tmpB = torch.tensor(i)
    self.condition_func(tmpA, tmpB)
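To isolate the cond behavior from the full llama export, I also put together a minimal standalone test with torch.export (a sketch of mine; MiniCond is a toy module, and I have not yet confirmed how it behaves on my nightly build):

import torch
from torch.export import export
from functorch.experimental.control_flow import cond

class MiniCond(torch.nn.Module):
    def forward(self, pred, x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x - 1

        return cond(pred, true_fn, false_fn, [x])

ep = export(MiniCond(), (torch.tensor(True), torch.tensor(0)))
print(ep)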
Then run the export:

python -m examples.models.llama.export_llama \
  --model "llama3_2" \
  --checkpoint $MODEL_PATH \
  --params $PARAM_PATH \
  -kv \
  -X \
  -d bf16 \
  --output_name=$OUTPUT_NAME \
  --use_sdpa_with_kv_cache
The error shown above then occurs:
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.