1111
1212#pragma once
1313
14+ #include < executorch/extension/llm/runner/audio.h>
1415#include < executorch/extension/llm/runner/image.h>
1516#include < executorch/runtime/platform/compiler.h>
1617#include < string>
1920namespace executorch ::extension::llm {
2021
/**
 * A generic class to hold either image, text, or audio data for multimodal
 * inputs. This allows the generate() API to take a std::vector of these objects
 * instead of separate image, text, and audio parameters.
 */
2627class ET_EXPERIMENTAL MultimodalInput {
2728 public:
/**
 * Kind of payload carried by an input.
 *
 * Only one definition of this enum may exist; the four enumerators below
 * supersede the earlier TEXT/IMAGE-only form.
 */
enum class Type {
  TEXT, ///< Text data.
  IMAGE, ///< Image data.
  AUDIO, ///< Audio data.
  RAW_AUDIO, ///< Raw audio data.
};
2930
3031 // Constructors
3132 explicit MultimodalInput (const std::string& text) : data_(text) {}
3233 explicit MultimodalInput (std::string&& text) : data_(std::move(text)) {}
3334 explicit MultimodalInput (const Image& image) : data_(image) {}
3435 explicit MultimodalInput (Image&& image) : data_(std::move(image)) {}
36+ explicit MultimodalInput (const Audio& audio) : data_(audio) {}
37+ explicit MultimodalInput (Audio&& audio) : data_(std::move(audio)) {}
38+ explicit MultimodalInput (const RawAudio& raw_audio) : data_(raw_audio) {}
39+ explicit MultimodalInput (RawAudio&& raw_audio)
40+ : data_(std::move(raw_audio)) {}
3541
3642 // Copy constructor and assignment
3743 MultimodalInput (const MultimodalInput& other) = default ;
@@ -60,12 +66,35 @@ class ET_EXPERIMENTAL MultimodalInput {
6066 return std::holds_alternative<Image>(data_);
6167 }
6268
69+ /* *
70+ * Check if this input contains audio data.
71+ * @return true if this input contains audio, false otherwise.
72+ */
73+ bool is_audio () const noexcept {
74+ return std::holds_alternative<Audio>(data_);
75+ }
76+
77+ /* *
78+ * Check if this input contains raw audio data.
79+ * @return true if this input contains raw audio, false otherwise.
80+ */
81+ bool is_raw_audio () const noexcept {
82+ return std::holds_alternative<RawAudio>(data_);
83+ }
84+
6385 /* *
6486 * Get the type of data stored in this input.
65- * @return Type::TEXT if text data, Type::IMAGE if image data.
87+ * @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
88+ * audio data, Type::RAW_AUDIO if raw audio data.
6689 */
6790 Type get_type () const noexcept {
68- return is_text () ? Type::TEXT : Type::IMAGE;
91+ if (is_text ())
92+ return Type::TEXT;
93+ if (is_image ())
94+ return Type::IMAGE;
95+ if (is_audio ())
96+ return Type::AUDIO;
97+ return Type::RAW_AUDIO;
6998 }
7099
71100 /* *
@@ -122,6 +151,60 @@ class ET_EXPERIMENTAL MultimodalInput {
122151 return std::get<Image>(std::move (data_));
123152 }
124153
154+ /* *
155+ * Get the audio data from this input.
156+ * @return Reference to the stored Audio object.
157+ * @throws std::bad_variant_access if this input doesn't contain audio.
158+ */
159+ const Audio& get_audio () const & {
160+ return std::get<Audio>(data_);
161+ }
162+
163+ /* *
164+ * Get the audio data from this input (mutable version).
165+ * @return Mutable reference to the stored Audio object.
166+ * @throws std::bad_variant_access if this input doesn't contain audio.
167+ */
168+ Audio& get_audio () & {
169+ return std::get<Audio>(data_);
170+ }
171+
172+ /* *
173+ * Get the audio data from this input (rvalue version).
174+ * @return Rvalue reference to the stored Audio object for efficient moves.
175+ * @throws std::bad_variant_access if this input doesn't contain audio.
176+ */
177+ Audio&& get_audio() && {
178+ return std::get<Audio>(std::move (data_));
179+ }
180+
181+ /* *
182+ * Get the raw audio data from this input.
183+ * @return Reference to the stored RawAudio object.
184+ * @throws std::bad_variant_access if this input doesn't contain raw audio.
185+ */
186+ const RawAudio& get_raw_audio () const & {
187+ return std::get<RawAudio>(data_);
188+ }
189+
190+ /* *
191+ * Get the raw audio data from this input (mutable version).
192+ * @return Mutable reference to the stored RawAudio object.
193+ * @throws std::bad_variant_access if this input doesn't contain raw audio.
194+ */
195+ RawAudio& get_raw_audio () & {
196+ return std::get<RawAudio>(data_);
197+ }
198+
199+ /* *
200+ * Get the raw audio data from this input (rvalue version).
201+ * @return Rvalue reference to the stored RawAudio object for efficient moves.
202+ * @throws std::bad_variant_access if this input doesn't contain raw audio.
203+ */
204+ RawAudio&& get_raw_audio() && {
205+ return std::get<RawAudio>(std::move (data_));
206+ }
207+
125208 /* *
126209 * Try to get the text data from this input safely.
127210 * @return Pointer to the text string if this input contains text, nullptr
@@ -158,8 +241,44 @@ class ET_EXPERIMENTAL MultimodalInput {
158241 return std::get_if<Image>(&data_);
159242 }
160243
244+ /* *
245+ * Try to get the audio data from this input safely.
246+ * @return Pointer to the Audio object if this input contains audio,
247+ * nullptr otherwise.
248+ */
249+ const Audio* try_get_audio () const noexcept {
250+ return std::get_if<Audio>(&data_);
251+ }
252+
253+ /* *
254+ * Try to get the audio data from this input safely (mutable version).
255+ * @return Pointer to the Audio object if this input contains audio,
256+ * nullptr otherwise.
257+ */
258+ Audio* try_get_audio () noexcept {
259+ return std::get_if<Audio>(&data_);
260+ }
261+
262+ /* *
263+ * Try to get the raw audio data from this input safely.
264+ * @return Pointer to the RawAudio object if this input contains raw audio,
265+ * nullptr otherwise.
266+ */
267+ const RawAudio* try_get_raw_audio () const noexcept {
268+ return std::get_if<RawAudio>(&data_);
269+ }
270+
271+ /* *
272+ * Try to get the raw audio data from this input safely (mutable version).
273+ * @return Pointer to the RawAudio object if this input contains raw audio,
274+ * nullptr otherwise.
275+ */
276+ RawAudio* try_get_raw_audio () noexcept {
277+ return std::get_if<RawAudio>(&data_);
278+ }
279+
161280 private:
162- std::variant<std::string, Image> data_;
281+ std::variant<std::string, Image, Audio, RawAudio > data_;
163282};
164283
165284// Convenience factory functions
@@ -179,4 +298,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
179298 return MultimodalInput (std::move (image));
180299}
181300
182- } // namespace executorch::extension::llm
301+ inline MultimodalInput make_audio_input (const Audio& audio) noexcept {
302+ return MultimodalInput (audio);
303+ }
304+
305+ inline MultimodalInput make_audio_input (Audio&& audio) noexcept {
306+ return MultimodalInput (std::move (audio));
307+ }
308+
309+ inline MultimodalInput make_raw_audio_input (
310+ const RawAudio& raw_audio) noexcept {
311+ return MultimodalInput (raw_audio);
312+ }
313+
314+ inline MultimodalInput make_raw_audio_input (RawAudio&& raw_audio) noexcept {
315+ return MultimodalInput (std::move (raw_audio));
316+ }
317+
318+ } // namespace executorch::extension::llm
0 commit comments