Skip to content

Commit c6883b1

Browse files
committed
Add audio to multimodal runner
ghstack-source-id: 75f5ea3 Pull Request resolved: #13662
1 parent 6e520e2 commit c6883b1

File tree

3 files changed

+193
-19
lines changed

3 files changed

+193
-19
lines changed

extension/llm/runner/constants.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
2121

2222
// Multimodal method name conventions
2323
inline constexpr auto kImageEncoderMethod = "image_encoder";
24-
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
25-
inline constexpr auto kTextModelMethod = "text_model";
24+
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
25+
inline constexpr auto kTokenEmbeddingMethod = "token_embeddings";
26+
inline constexpr auto kTextModelMethod = "decoder";
2627

2728
} // namespace executorch::extension::llm

extension/llm/runner/multimodal_input.h

Lines changed: 144 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#pragma once
1313

14+
#include <executorch/extension/llm/runner/audio.h>
1415
#include <executorch/extension/llm/runner/image.h>
1516
#include <executorch/runtime/platform/compiler.h>
1617
#include <string>
@@ -19,19 +20,24 @@
1920
namespace executorch::extension::llm {
2021

2122
/**
22-
* A generic class to hold either image or text data for multimodal inputs.
23-
* This allows the generate() API to take a std::vector of these objects
24-
* instead of separate image and text parameters.
23+
* A generic class to hold either image, text, or audio data for multimodal
24+
* inputs. This allows the generate() API to take a std::vector of these objects
25+
* instead of separate image, text, and audio parameters.
2526
*/
2627
class ET_EXPERIMENTAL MultimodalInput {
2728
public:
28-
enum class Type { TEXT, IMAGE };
29+
enum class Type { TEXT, IMAGE, AUDIO, RAW_AUDIO };
2930

3031
// Constructors
3132
explicit MultimodalInput(const std::string& text) : data_(text) {}
3233
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
3334
explicit MultimodalInput(const Image& image) : data_(image) {}
3435
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
36+
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
37+
explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {}
38+
explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {}
39+
explicit MultimodalInput(RawAudio&& raw_audio)
40+
: data_(std::move(raw_audio)) {}
3541

3642
// Copy constructor and assignment
3743
MultimodalInput(const MultimodalInput& other) = default;
@@ -60,12 +66,35 @@ class ET_EXPERIMENTAL MultimodalInput {
6066
return std::holds_alternative<Image>(data_);
6167
}
6268

69+
/**
70+
* Check if this input contains audio data.
71+
* @return true if this input contains audio, false otherwise.
72+
*/
73+
bool is_audio() const noexcept {
74+
return std::holds_alternative<Audio>(data_);
75+
}
76+
77+
/**
78+
* Check if this input contains raw audio data.
79+
* @return true if this input contains raw audio, false otherwise.
80+
*/
81+
bool is_raw_audio() const noexcept {
82+
return std::holds_alternative<RawAudio>(data_);
83+
}
84+
6385
/**
6486
* Get the type of data stored in this input.
65-
* @return Type::TEXT if text data, Type::IMAGE if image data.
87+
* @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
88+
* audio data, Type::RAW_AUDIO if raw audio data.
6689
*/
6790
Type get_type() const noexcept {
68-
return is_text() ? Type::TEXT : Type::IMAGE;
91+
if (is_text())
92+
return Type::TEXT;
93+
if (is_image())
94+
return Type::IMAGE;
95+
if (is_audio())
96+
return Type::AUDIO;
97+
return Type::RAW_AUDIO;
6998
}
7099

71100
/**
@@ -122,6 +151,60 @@ class ET_EXPERIMENTAL MultimodalInput {
122151
return std::get<Image>(std::move(data_));
123152
}
124153

154+
/**
155+
* Get the audio data from this input.
156+
* @return Reference to the stored Audio object.
157+
* @throws std::bad_variant_access if this input doesn't contain audio.
158+
*/
159+
const Audio& get_audio() const& {
160+
return std::get<Audio>(data_);
161+
}
162+
163+
/**
164+
* Get the audio data from this input (mutable version).
165+
* @return Mutable reference to the stored Audio object.
166+
* @throws std::bad_variant_access if this input doesn't contain audio.
167+
*/
168+
Audio& get_audio() & {
169+
return std::get<Audio>(data_);
170+
}
171+
172+
/**
173+
* Get the audio data from this input (rvalue version).
174+
* @return Rvalue reference to the stored Audio object for efficient moves.
175+
* @throws std::bad_variant_access if this input doesn't contain audio.
176+
*/
177+
Audio&& get_audio() && {
178+
return std::get<Audio>(std::move(data_));
179+
}
180+
181+
/**
182+
* Get the raw audio data from this input.
183+
* @return Reference to the stored RawAudio object.
184+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
185+
*/
186+
const RawAudio& get_raw_audio() const& {
187+
return std::get<RawAudio>(data_);
188+
}
189+
190+
/**
191+
* Get the raw audio data from this input (mutable version).
192+
* @return Mutable reference to the stored RawAudio object.
193+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
194+
*/
195+
RawAudio& get_raw_audio() & {
196+
return std::get<RawAudio>(data_);
197+
}
198+
199+
/**
200+
* Get the raw audio data from this input (rvalue version).
201+
* @return Rvalue reference to the stored RawAudio object for efficient moves.
202+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
203+
*/
204+
RawAudio&& get_raw_audio() && {
205+
return std::get<RawAudio>(std::move(data_));
206+
}
207+
125208
/**
126209
* Try to get the text data from this input safely.
127210
* @return Pointer to the text string if this input contains text, nullptr
@@ -158,8 +241,44 @@ class ET_EXPERIMENTAL MultimodalInput {
158241
return std::get_if<Image>(&data_);
159242
}
160243

244+
/**
245+
* Try to get the audio data from this input safely.
246+
* @return Pointer to the Audio object if this input contains audio,
247+
* nullptr otherwise.
248+
*/
249+
const Audio* try_get_audio() const noexcept {
250+
return std::get_if<Audio>(&data_);
251+
}
252+
253+
/**
254+
* Try to get the audio data from this input safely (mutable version).
255+
* @return Pointer to the Audio object if this input contains audio,
256+
* nullptr otherwise.
257+
*/
258+
Audio* try_get_audio() noexcept {
259+
return std::get_if<Audio>(&data_);
260+
}
261+
262+
/**
263+
* Try to get the raw audio data from this input safely.
264+
* @return Pointer to the RawAudio object if this input contains raw audio,
265+
* nullptr otherwise.
266+
*/
267+
const RawAudio* try_get_raw_audio() const noexcept {
268+
return std::get_if<RawAudio>(&data_);
269+
}
270+
271+
/**
272+
* Try to get the raw audio data from this input safely (mutable version).
273+
* @return Pointer to the RawAudio object if this input contains raw audio,
274+
* nullptr otherwise.
275+
*/
276+
RawAudio* try_get_raw_audio() noexcept {
277+
return std::get_if<RawAudio>(&data_);
278+
}
279+
161280
private:
162-
std::variant<std::string, Image> data_;
281+
std::variant<std::string, Image, Audio, RawAudio> data_;
163282
};
164283

165284
// Convenience factory functions
@@ -179,4 +298,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
179298
return MultimodalInput(std::move(image));
180299
}
181300

182-
} // namespace executorch::extension::llm
301+
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
302+
return MultimodalInput(audio);
303+
}
304+
305+
inline MultimodalInput make_audio_input(Audio&& audio) noexcept {
306+
return MultimodalInput(std::move(audio));
307+
}
308+
309+
inline MultimodalInput make_raw_audio_input(
310+
const RawAudio& raw_audio) noexcept {
311+
return MultimodalInput(raw_audio);
312+
}
313+
314+
inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept {
315+
return MultimodalInput(std::move(raw_audio));
316+
}
317+
318+
} // namespace executorch::extension::llm

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ MultimodalPrefiller::MultimodalPrefiller(
3737
Result<uint64_t> MultimodalPrefiller::prefill(
3838
const MultimodalInput& input,
3939
int64_t& start_pos) {
40-
// Check if input is image
40+
// 1. Run encoder model.
4141
::executorch::runtime::EValue encoder_output;
4242
if (input.is_image()) {
4343
Image image = input.get_image();
@@ -51,33 +51,65 @@ Result<uint64_t> MultimodalPrefiller::prefill(
5151
ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor));
5252

5353
encoder_output = image_encoder_outputs[0];
54+
} else if (input.is_audio()) {
55+
Audio audio = input.get_audio();
56+
57+
// Use the original tensor shape as intended
58+
auto audio_tensor = executorch::extension::from_blob(
59+
audio.data.data(),
60+
{audio.batch_size, audio.n_bins, audio.n_frames},
61+
::executorch::aten::ScalarType::Float);
62+
63+
// Run audio encoder
64+
auto audio_encoder_result =
65+
module_->execute(kAudioEncoderMethod, audio_tensor);
66+
if (audio_encoder_result.error() != ::executorch::runtime::Error::Ok) {
67+
return ::executorch::runtime::Error::Internal;
68+
}
69+
auto audio_encoder_outputs = audio_encoder_result.get();
70+
71+
encoder_output = audio_encoder_outputs[0];
5472
} else if (input.is_text()) {
55-
// For text input, we don't need to run the image encoder.
56-
// Instead, we run the text encoder to get the encoder output.
5773
auto& text = input.get_text();
5874
std::vector<uint64_t> tokens =
5975
ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
76+
6077
auto text_tensor = executorch::extension::from_blob(
6178
tokens.data(),
6279
{1, static_cast<aten::SizesType>(tokens.size())},
6380
::executorch::aten::ScalarType::Long);
6481

65-
// Run token embedding
82+
// Run text encoder (token embeddings)
6683
auto token_embedding_outputs =
6784
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor));
6885

6986
encoder_output = token_embedding_outputs[0];
7087
} else {
7188
ET_LOG(Error, "Unsupported input type");
72-
// For all other input types (e.g., audio), return error
89+
// For any other input types, return error
7390
return ::executorch::runtime::Error::NotSupported;
7491
}
7592

76-
auto outputs_res =
77-
ET_UNWRAP(text_decoder_runner_->decode(encoder_output, start_pos));
93+
// 2. Run decoder model for prefill.
94+
// `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
95+
// e.g. if start_pos = 2 and encoder_output.size(1) = 5,
96+
// cache_position_tensor should be [2, 3, 4, 5, 6].
97+
int64_t seq_len = encoder_output.toTensor().size(1);
98+
std::vector<int64_t> cache_positions(seq_len);
99+
for (int64_t i = 0; i < seq_len; ++i) {
100+
cache_positions[i] = start_pos + i;
101+
}
102+
auto cache_position_tensor = ::executorch::extension::from_blob(
103+
cache_positions.data(), {seq_len}, executorch::aten::ScalarType::Long);
104+
auto prefill_result = module_->execute(
105+
kTextModelMethod, {cache_position_tensor, encoder_output});
106+
if (prefill_result.error() != ::executorch::runtime::Error::Ok) {
107+
return prefill_result.error();
108+
}
109+
auto prefill_outputs = prefill_result.get();
110+
auto outputs_res = prefill_outputs[0].toTensor();
78111

79-
// Update the start_pos, which is only available inside this function.
80-
// outputs_res can have only one logits.
112+
// Update start_pos, tracking the current cache position.
81113
start_pos += encoder_output.toTensor().size(1);
82114

83115
return static_cast<uint64_t>(
@@ -103,6 +135,11 @@ ::executorch::runtime::Error MultimodalPrefiller::load() {
103135
if (methods.find(kImageEncoderMethod) != methods.end()) {
104136
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod));
105137
}
138+
139+
if (methods.find(kAudioEncoderMethod) != methods.end()) {
140+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod));
141+
}
142+
106143
return ::executorch::runtime::Error::Ok;
107144
}
108145

0 commit comments

Comments
 (0)