
[Request Help] "torch._dynamo.exc.UserError" #8873

@liye0626


Background: I tried to integrate EAGLE-2 into ExecuTorch but ran into some errors.
Related code:

            if i not in noleaf_index: # An error occurred at this branch
                cid = i
                depth = position_ids_list[i]
                for j in reversed(range(depth + 1)):
                    retrieve_indices[rid][j] = cid
                    cid = mask_index_list[cid - 1]
                rid += 1

Specifically, I got an error after adding a branch to the model's forward method:

  • code:
if A_Tensor: # In forward, a Boolean value or bool tensor calculated using real-time data
	...
  • error:

torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use torch.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

Following the page above, I added the module "CondBranchNestedFunction" and then encountered a new error:

  • code:
from functorch.experimental.control_flow import cond
@torch.compile(dynamic=True, fullgraph=True) # dynamic=True and dynamic=False both result in an error
class CondBranchNestedFunction(torch.nn.Module):
    @torch.compile(dynamic=True, fullgraph=True)
    def forward(self, tmpA, i):
        def true_fn(i):
            i+=1
            return None

        def false_fn(i):
            return None

        return cond(tmpA, true_fn, false_fn, [i])


self.condition_func = CondBranchNestedFunction()
self.condition_func = torch.compile(self.condition_func, mode='max-autotune')

tmpA = True # or bool Tensor.  In forward, a Boolean value or bool tensor calculated using real-time data
tmpB = torch.tensor(i)
self.condition_func(tmpA, tmpB)
  • error:
    "torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile."

I have tried applying torch.compile in several places, but the error did not change. Which direction should I take to solve this problem?

Specifications are as follows:

My torch version:

torch                     2.6.0.dev20241224+cpu

torchaudio                2.6.0.dev20241224+cpu

torchvision               0.22.0.dev20241224+cpu

ExecuTorch commit: 86cb5d7

The changes to ExecuTorch's llama_transformer.py are as follows:

Add the following code at line 25 of llama_transformer.py:

from functorch.experimental.control_flow import cond
@torch.compile(dynamic=True, fullgraph=True) # dynamic=True and dynamic=False both result in an error
class CondBranchNestedFunction(torch.nn.Module):
    @torch.compile(dynamic=True, fullgraph=True)
    def forward(self, tmpA, i):
        def true_fn(i):
            i+=1
            return None

        def false_fn(i):
            return None

        return cond(tmpA, true_fn, false_fn, [i])

Add the following code at line 425 of llama_transformer.py:

        self.condition_func = CondBranchNestedFunction()
        self.condition_func = torch.compile(self.condition_func, mode='max-autotune')

Add the following code at line 544 of llama_transformer.py:

        for i in range(10):
            tmpA = (i == logits).int().sum() == 0
            tmpB = torch.tensor(i)
            self.condition_func(tmpA, tmpB)

Then run the export:

python -m examples.models.llama.export_llama \
    --model "llama3_2" \
    --checkpoint $MODEL_PATH \
    --params $PARAM_PATH \
    -kv \
    -X \
    -d bf16 \
    --output_name=$OUTPUT_NAME \
    --use_sdpa_with_kv_cache

The export then fails with the error above:

torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.
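
For the reproduction loop above, a branch-free formulation might avoid `cond` altogether. The sketch below is an assumption about what the branch is meant to compute (increment the index when no entry of `logits` equals it); the helper name `bump_if_absent` and the sample values are hypothetical. Whether this carries over to the real EAGLE-2 branch depends on whether both outcomes can be computed unconditionally.

```python
import torch


def bump_if_absent(logits: torch.Tensor, i: int) -> torch.Tensor:
    """Hypothetical helper: return i + 1 if no entry of logits equals i, else i."""
    absent = (logits == i).int().sum() == 0  # 0-dim bool tensor, as in the loop
    idx = torch.tensor(i)
    # torch.where computes both candidates and selects one, so the traced
    # graph contains no data-dependent Python control flow.
    return torch.where(absent, idx + 1, idx)


logits = torch.tensor([0.0, 2.0, 5.0])
results = [bump_if_absent(logits, i).item() for i in range(4)]
```

The trade-off is that both candidates are always evaluated, which is cheap here but may not be acceptable when a branch is expensive or has side effects.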

cc @JacobSzwejbka @angelayi @mergennachin @byjlw

Labels: module: exir (issues related to Export IR and the code under exir/), module: user experience (issues related to reducing friction for users), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
