Skip to content

Commit bef085f

Browse files
authored
add support for multimodal content in chat completions (#49)
* add support to multimodal in chat completions Signed-off-by: Juanma Barea <[email protected]> * fix lint errors Signed-off-by: Juanma Barea <[email protected]> * add image_url as a block Signed-off-by: Juanma Barea <[email protected]> * fix lint errors Signed-off-by: Juanma Barea <[email protected]> * add text separator to allow better tokens number calculation Signed-off-by: Juanma Barea <[email protected]> --------- Signed-off-by: Juanma Barea <[email protected]>
1 parent b72af25 commit bef085f

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

pkg/llm-d-inference-sim/defs.go

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ limitations under the License.
1919
package llmdinferencesim
2020

2121
import (
22+
"encoding/json"
23+
"errors"
2224
"fmt"
2325
"strings"
2426
"sync"
@@ -175,7 +177,65 @@ type message struct {
175177
// Role is the message Role, optional values are 'user', 'assistant', ...
176178
Role string `json:"role,omitempty"`
177179
// Content defines text of this message
178-
Content string `json:"content,omitempty"`
180+
Content content `json:"content,omitempty"`
181+
}
182+
183+
type content struct {
184+
Raw string
185+
Structured []contentBlock
186+
}
187+
188+
type contentBlock struct {
189+
Type string `json:"type"`
190+
Text string `json:"text,omitempty"`
191+
ImageURL ImageBlock `json:"image_url,omitempty"`
192+
}
193+
194+
type ImageBlock struct {
195+
Url string `json:"url,omitempty"`
196+
}
197+
198+
// UnmarshalJSON allow use both format
199+
func (mc *content) UnmarshalJSON(data []byte) error {
200+
// Raw format
201+
var str string
202+
if err := json.Unmarshal(data, &str); err == nil {
203+
mc.Raw = str
204+
return nil
205+
}
206+
207+
// Block format
208+
var blocks []contentBlock
209+
if err := json.Unmarshal(data, &blocks); err == nil {
210+
mc.Structured = blocks
211+
return nil
212+
}
213+
214+
return errors.New("content format not supported")
215+
}
216+
217+
func (mc content) MarshalJSON() ([]byte, error) {
218+
if mc.Raw != "" {
219+
return json.Marshal(mc.Raw)
220+
}
221+
if mc.Structured != nil {
222+
return json.Marshal(mc.Structured)
223+
}
224+
return json.Marshal("")
225+
}
226+
227+
func (mc content) PlainText() string {
228+
if mc.Raw != "" {
229+
return mc.Raw
230+
}
231+
var sb strings.Builder
232+
for _, block := range mc.Structured {
233+
if block.Type == "text" {
234+
sb.WriteString(block.Text)
235+
sb.WriteString(" ")
236+
}
237+
}
238+
return sb.String()
179239
}
180240

181241
// chatCompletionRequest defines structure of /chat/completion request
@@ -200,7 +260,7 @@ type chatCompletionRequest struct {
200260
// getNumberOfPromptTokens returns the number of whitespace-separated
// tokens (words) across the text content of all messages in the request.
func (c *chatCompletionRequest) getNumberOfPromptTokens() int {
	// Count fields per message instead of concatenating every message
	// into one growing string first: string += in a loop is quadratic,
	// and the joined string itself is never needed. PlainText already
	// terminates each text block with a space, so per-message counting
	// yields the same total as splitting the concatenation.
	tokens := 0
	for _, msg := range c.Messages {
		tokens += len(strings.Fields(msg.Content.PlainText()))
	}
	return tokens
}
@@ -328,7 +388,7 @@ func (req textCompletionRequest) createResponseText(mode string) (string, string
328388
func (req *chatCompletionRequest) getLastUserMsg() string {
329389
for i := len(req.Messages) - 1; i >= 0; i-- {
330390
if req.Messages[i].Role == roleUser {
331-
return req.Messages[i].Content
391+
return req.Messages[i].Content.PlainText()
332392
}
333393
}
334394

pkg/llm-d-inference-sim/simulator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respText
446446
baseResp.Object = chatCompletionObject
447447
return &chatCompletionResponse{
448448
baseCompletionResponse: baseResp,
449-
Choices: []chatRespChoice{{Message: message{Role: roleAssistant, Content: respText}, baseResponseChoice: baseChoice}},
449+
Choices: []chatRespChoice{{Message: message{Role: roleAssistant, Content: content{Raw: respText}}, baseResponseChoice: baseChoice}},
450450
}
451451
}
452452

pkg/llm-d-inference-sim/streaming.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func (s *VllmSimulator) createCompletionChunk(isChatCompletion bool, creationTim
156156
chunk.Choices[0].Delta.Role = role
157157
}
158158
if len(token) > 0 {
159-
chunk.Choices[0].Delta.Content = token
159+
chunk.Choices[0].Delta.Content.Raw = token
160160
}
161161

162162
return &chunk

0 commit comments

Comments
 (0)