Whisper + Torch.Compile: torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(EncoderDecoderCache) #31987

@kadirnar

System Info

- `transformers` version: 4.43.0.dev0
- Platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.23.4
- Safetensors version: 0.4.3
- Accelerate version: 0.32.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA GeForce RTX 4090

Who can help?

@sanchit-gandhi, @Narsil, @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch

import torch._dynamo
torch._dynamo.config.suppress_errors = True
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_sample = ds[0]["audio"]

processor = WhisperProcessor.from_pretrained("distil-whisper/distil-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-large-v3", attn_implementation="sdpa")

model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

import time

start = time.time()
input_features = processor(
    audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
).input_features

_ = model.generate(input_features)

predicted_ids = model.generate(input_features)

transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

end = time.time()
print(f"Elapsed time: {end - start} seconds")
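A side note on the timing in the script above: the measured interval covers both `generate` calls, so any compilation done during the first call is included in the reported number. Below is a minimal sketch of separating warm-up from measurement. The helper name `time_after_warmup` and the stand-in workload are hypothetical and not part of the original script; with the model, the callable would be something like `lambda: model.generate(input_features)`.

```python
import time
from typing import Any, Callable, Tuple


def time_after_warmup(
    fn: Callable[[], Any], warmup: int = 1, runs: int = 1
) -> Tuple[Any, float]:
    """Call fn `warmup` times untimed, then time `runs` calls.

    Returns (result of the last timed call, mean seconds per timed call).
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    for _ in range(warmup):
        fn()  # untimed: with torch.compile, tracing/compilation would happen here
    start = time.perf_counter()
    for _ in range(runs):
        result = fn()
    elapsed = time.perf_counter() - start
    return result, elapsed / runs


# Hypothetical stand-in workload so the sketch runs without a model:
# each call records itself and returns the total number of calls so far.
calls = []
result, seconds = time_after_warmup(
    lambda: calls.append(1) or len(calls), warmup=2, runs=3
)
print(f"result={result}, mean seconds per call={seconds:.6f}")
```

If the model runs on a GPU, a `torch.cuda.synchronize()` before each clock read would also be needed, since CUDA kernels are launched asynchronously.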

Expected behavior

I want to speed up the Whisper model with torch.compile, and I expect the script above to run without raising an error. I would also like to use the pipeline function together with torch.compile. Could you also add sample code for .mp3 input to the docs?

Error Message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 161, in __call__
    output.extend(value.reconstruct(self))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/base.py", line 277, in reconstruct
    raise NotImplementedError()
NotImplementedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/whisper-plus/test.py", line 23, in <module>
    _ = model.generate(input_features)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py", line 587, in generate
    outputs = super().generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1969, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2912, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2243, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 927, in compile_subgraph
    pass1.restore_stack(stack_values)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 68, in restore_stack
    self.foreach(stack_values)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 193, in foreach
    self(i)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 161, in __call__
    output.extend(value.reconstruct(self))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/dicts.py", line 701, in reconstruct
    codegen(self.items[key])
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/codegen.py", line 163, in __call__
    unimplemented(f"reconstruct: {value}")
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 193, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(EncoderDecoderCache)

from user code:
   File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1751, in forward
    return Seq2SeqLMOutput(

[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] Function                           Runtimes (s)
[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] -------------------------------  --------------
[2024-07-15 21:02:30,139] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner 
