Skip to content

Commit 78cd144

Browse files
committed
Add audio to multimodal runner
ghstack-source-id: 98599b7 Pull Request resolved: #13662
1 parent 6e520e2 commit 78cd144

File tree

4 files changed

+266
-20
lines changed

4 files changed

+266
-20
lines changed

extension/llm/runner/audio.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// A simple audio struct.
10+
11+
#pragma once
12+
#include <executorch/runtime/platform/compiler.h>
13+
#include <cstdint>
14+
#include <vector>
15+
16+
namespace executorch {
17+
namespace extension {
18+
namespace llm {
19+
20+
/**
21+
* Audio inputs as a raw audio tensor, for use when the audio processing
22+
* into a mel spectrogram is baked into the audio encoder with torch.export.
23+
*/
24+
struct ET_EXPERIMENTAL RawAudio {
25+
std::vector<uint8_t> data;
26+
int32_t batch_size;
27+
int32_t n_channels; // For mono, use n_channels = 1.
28+
int32_t n_samples;
29+
};
30+
31+
/**
32+
* Audio inputs as a mel spectrogram, ready to feed directly into an audio
33+
* encoder.
34+
*/
35+
struct ET_EXPERIMENTAL Audio {
36+
std::vector<uint8_t> data;
37+
int32_t batch_size;
38+
int32_t n_bins;
39+
int32_t n_frames;
40+
};
41+
42+
} // namespace llm
43+
} // namespace extension
44+
} // namespace executorch
45+
46+
namespace torch {
47+
namespace executor {
48+
// TODO(T197294990): Remove these deprecated aliases once all users have moved
49+
// to the new `::executorch` namespaces.
50+
using ::executorch::extension::llm::Audio;
51+
} // namespace executor
52+
} // namespace torch

extension/llm/runner/constants.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
2121

2222
// Multimodal method name conventions
2323
inline constexpr auto kImageEncoderMethod = "image_encoder";
24-
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
25-
inline constexpr auto kTextModelMethod = "text_model";
24+
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
25+
inline constexpr auto kTokenEmbeddingMethod = "token_embeddings";
26+
inline constexpr auto kTextModelMethod = "decoder";
2627

2728
} // namespace executorch::extension::llm

extension/llm/runner/multimodal_input.h

Lines changed: 153 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#pragma once
1313

14+
#include <executorch/extension/llm/runner/audio.h>
1415
#include <executorch/extension/llm/runner/image.h>
1516
#include <executorch/runtime/platform/compiler.h>
1617
#include <string>
@@ -19,19 +20,31 @@
1920
namespace executorch::extension::llm {
2021

2122
/**
22-
* A generic class to hold either image or text data for multimodal inputs.
23-
* This allows the generate() API to take a std::vector of these objects
24-
* instead of separate image and text parameters.
23+
* A generic class to hold either image, text, or audio data for multimodal
24+
* inputs. This allows the generate() API to take a std::vector of these objects
25+
* instead of separate image, text, and audio parameters.
2526
*/
2627
class ET_EXPERIMENTAL MultimodalInput {
2728
public:
28-
enum class Type { TEXT, IMAGE };
29+
/// Type of multimodal input data
30+
enum class Type {
31+
TEXT, ///< Text string input
32+
IMAGE, ///< Processed image input
33+
AUDIO, ///< Processed audio input (post-mel spectrogram processing)
34+
RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
35+
UNSUPPORTED ///< Unsupported input type
36+
};
2937

3038
// Constructors
3139
explicit MultimodalInput(const std::string& text) : data_(text) {}
3240
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
3341
explicit MultimodalInput(const Image& image) : data_(image) {}
3442
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
43+
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
44+
explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {}
45+
explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {}
46+
explicit MultimodalInput(RawAudio&& raw_audio)
47+
: data_(std::move(raw_audio)) {}
3548

3649
// Copy constructor and assignment
3750
MultimodalInput(const MultimodalInput& other) = default;
@@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput {
6073
return std::holds_alternative<Image>(data_);
6174
}
6275

76+
/**
77+
* Check if this input contains audio data.
78+
* @return true if this input contains audio, false otherwise.
79+
*/
80+
bool is_audio() const noexcept {
81+
return std::holds_alternative<Audio>(data_);
82+
}
83+
84+
/**
85+
* Check if this input contains raw audio data.
86+
* @return true if this input contains raw audio, false otherwise.
87+
*/
88+
bool is_raw_audio() const noexcept {
89+
return std::holds_alternative<RawAudio>(data_);
90+
}
91+
6392
/**
6493
* Get the type of data stored in this input.
65-
* @return Type::TEXT if text data, Type::IMAGE if image data.
94+
* @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
95+
* audio data, Type::RAW_AUDIO if raw audio data.
6696
*/
6797
Type get_type() const noexcept {
68-
return is_text() ? Type::TEXT : Type::IMAGE;
98+
if (is_text())
99+
return Type::TEXT;
100+
if (is_image())
101+
return Type::IMAGE;
102+
if (is_audio())
103+
return Type::AUDIO;
104+
if (is_raw_audio())
105+
return Type::RAW_AUDIO;
106+
return Type::UNSUPPORTED;
69107
}
70108

71109
/**
@@ -122,6 +160,60 @@ class ET_EXPERIMENTAL MultimodalInput {
122160
return std::get<Image>(std::move(data_));
123161
}
124162

163+
/**
164+
* Get the audio data from this input.
165+
* @return Reference to the stored Audio object.
166+
* @throws std::bad_variant_access if this input doesn't contain audio.
167+
*/
168+
const Audio& get_audio() const& {
169+
return std::get<Audio>(data_);
170+
}
171+
172+
/**
173+
* Get the audio data from this input (mutable version).
174+
* @return Mutable reference to the stored Audio object.
175+
* @throws std::bad_variant_access if this input doesn't contain audio.
176+
*/
177+
Audio& get_audio() & {
178+
return std::get<Audio>(data_);
179+
}
180+
181+
/**
182+
* Get the audio data from this input (rvalue version).
183+
* @return Rvalue reference to the stored Audio object for efficient moves.
184+
* @throws std::bad_variant_access if this input doesn't contain audio.
185+
*/
186+
Audio&& get_audio() && {
187+
return std::get<Audio>(std::move(data_));
188+
}
189+
190+
/**
191+
* Get the raw audio data from this input.
192+
* @return Reference to the stored RawAudio object.
193+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
194+
*/
195+
const RawAudio& get_raw_audio() const& {
196+
return std::get<RawAudio>(data_);
197+
}
198+
199+
/**
200+
* Get the raw audio data from this input (mutable version).
201+
* @return Mutable reference to the stored RawAudio object.
202+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
203+
*/
204+
RawAudio& get_raw_audio() & {
205+
return std::get<RawAudio>(data_);
206+
}
207+
208+
/**
209+
* Get the raw audio data from this input (rvalue version).
210+
* @return Rvalue reference to the stored RawAudio object for efficient moves.
211+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
212+
*/
213+
RawAudio&& get_raw_audio() && {
214+
return std::get<RawAudio>(std::move(data_));
215+
}
216+
125217
/**
126218
* Try to get the text data from this input safely.
127219
* @return Pointer to the text string if this input contains text, nullptr
@@ -158,8 +250,44 @@ class ET_EXPERIMENTAL MultimodalInput {
158250
return std::get_if<Image>(&data_);
159251
}
160252

253+
/**
254+
* Try to get the audio data from this input safely.
255+
* @return Pointer to the Audio object if this input contains audio,
256+
* nullptr otherwise.
257+
*/
258+
const Audio* try_get_audio() const noexcept {
259+
return std::get_if<Audio>(&data_);
260+
}
261+
262+
/**
263+
* Try to get the audio data from this input safely (mutable version).
264+
* @return Pointer to the Audio object if this input contains audio,
265+
* nullptr otherwise.
266+
*/
267+
Audio* try_get_audio() noexcept {
268+
return std::get_if<Audio>(&data_);
269+
}
270+
271+
/**
272+
* Try to get the raw audio data from this input safely.
273+
* @return Pointer to the RawAudio object if this input contains raw audio,
274+
* nullptr otherwise.
275+
*/
276+
const RawAudio* try_get_raw_audio() const noexcept {
277+
return std::get_if<RawAudio>(&data_);
278+
}
279+
280+
/**
281+
* Try to get the raw audio data from this input safely (mutable version).
282+
* @return Pointer to the RawAudio object if this input contains raw audio,
283+
* nullptr otherwise.
284+
*/
285+
RawAudio* try_get_raw_audio() noexcept {
286+
return std::get_if<RawAudio>(&data_);
287+
}
288+
161289
private:
162-
std::variant<std::string, Image> data_;
290+
std::variant<std::string, Image, Audio, RawAudio> data_;
163291
};
164292

165293
// Convenience factory functions
@@ -179,4 +307,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
179307
return MultimodalInput(std::move(image));
180308
}
181309

182-
} // namespace executorch::extension::llm
310+
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
311+
return MultimodalInput(audio);
312+
}
313+
314+
inline MultimodalInput make_audio_input(Audio&& audio) noexcept {
315+
return MultimodalInput(std::move(audio));
316+
}
317+
318+
inline MultimodalInput make_raw_audio_input(
319+
const RawAudio& raw_audio) noexcept {
320+
return MultimodalInput(raw_audio);
321+
}
322+
323+
inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept {
324+
return MultimodalInput(std::move(raw_audio));
325+
}
326+
327+
} // namespace executorch::extension::llm

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ MultimodalPrefiller::MultimodalPrefiller(
3737
Result<uint64_t> MultimodalPrefiller::prefill(
3838
const MultimodalInput& input,
3939
int64_t& start_pos) {
40-
// Check if input is image
40+
// 1. Run encoder model.
4141
::executorch::runtime::EValue encoder_output;
4242
if (input.is_image()) {
4343
Image image = input.get_image();
@@ -51,34 +51,77 @@ Result<uint64_t> MultimodalPrefiller::prefill(
5151
ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor));
5252

5353
encoder_output = image_encoder_outputs[0];
54+
} else if (input.is_audio()) {
55+
Audio audio = input.get_audio();
56+
57+
// Use the original tensor shape as intended
58+
auto audio_tensor = executorch::extension::from_blob(
59+
audio.data.data(),
60+
{audio.batch_size, audio.n_bins, audio.n_frames},
61+
::executorch::aten::ScalarType::Float);
62+
63+
// Run audio encoder
64+
auto audio_encoder_result =
65+
module_->execute(kAudioEncoderMethod, audio_tensor);
66+
if (audio_encoder_result.error() != ::executorch::runtime::Error::Ok) {
67+
return ::executorch::runtime::Error::Internal;
68+
}
69+
auto audio_encoder_outputs = audio_encoder_result.get();
70+
71+
encoder_output = audio_encoder_outputs[0];
5472
} else if (input.is_text()) {
55-
// For text input, we don't need to run the image encoder.
56-
// Instead, we run the text encoder to get the encoder output.
5773
auto& text = input.get_text();
5874
std::vector<uint64_t> tokens =
5975
ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
76+
6077
auto text_tensor = executorch::extension::from_blob(
6178
tokens.data(),
6279
{1, static_cast<aten::SizesType>(tokens.size())},
6380
::executorch::aten::ScalarType::Long);
6481

65-
// Run token embedding
82+
// Run text encoder (token embeddings)
6683
auto token_embedding_outputs =
6784
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor));
6885

6986
encoder_output = token_embedding_outputs[0];
7087
} else {
7188
ET_LOG(Error, "Unsupported input type");
72-
// For all other input types (e.g., audio), return error
89+
// For any other input types, return error
7390
return ::executorch::runtime::Error::NotSupported;
7491
}
7592

76-
auto outputs_res =
77-
ET_UNWRAP(text_decoder_runner_->decode(encoder_output, start_pos));
93+
// 2. Run decoder model for prefill.
94+
// `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
95+
// e.g. if start_pos = 2 and encoder_output.size(1) = 5,
96+
// cache_position_tensor should be [2, 3, 4, 5, 6].
97+
int64_t seq_len = encoder_output.toTensor().size(1);
98+
if (seq_len == 0) {
99+
ET_LOG(Error, "The encoder returned an empty output.");
100+
return ::executorch::runtime::Error::InvalidState;
101+
}
102+
std::vector<int64_t> cache_positions(seq_len);
103+
for (int64_t i = 0; i < seq_len; ++i) {
104+
cache_positions[i] = start_pos + i;
105+
}
106+
auto cache_position_tensor = ::executorch::extension::from_blob(
107+
cache_positions.data(), {seq_len}, executorch::aten::ScalarType::Long);
108+
auto prefill_result = module_->execute(
109+
kTextModelMethod, {cache_position_tensor, encoder_output});
110+
if (prefill_result.error() != ::executorch::runtime::Error::Ok) {
111+
return prefill_result.error();
112+
}
113+
// Check if prefill_outputs is empty, if it is return error and log that the
114+
// specified encoder returned empty results when used to prefill decoder.
115+
auto prefill_outputs = prefill_result.get();
116+
if (prefill_outputs.empty()) {
117+
ET_LOG(
118+
Error, "Encoder returned empty results when used to prefill decoder");
119+
return ::executorch::runtime::Error::InvalidState;
120+
}
121+
auto outputs_res = prefill_outputs[0].toTensor();
78122

79-
// Update the start_pos, which is only available inside this function.
80-
// outputs_res can have only one logits.
81-
start_pos += encoder_output.toTensor().size(1);
123+
// Update start_pos, tracking the current cache position.
124+
start_pos += seq_len;
82125

83126
return static_cast<uint64_t>(
84127
text_decoder_runner_->logits_to_token(outputs_res));
@@ -103,6 +146,11 @@ ::executorch::runtime::Error MultimodalPrefiller::load() {
103146
if (methods.find(kImageEncoderMethod) != methods.end()) {
104147
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod));
105148
}
149+
150+
if (methods.find(kAudioEncoderMethod) != methods.end()) {
151+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod));
152+
}
153+
106154
return ::executorch::runtime::Error::Ok;
107155
}
108156

0 commit comments

Comments
 (0)