@@ -19,6 +19,8 @@ limitations under the License.
1919package llmdinferencesim
2020
2121import (
22+ "encoding/json"
23+ "errors"
2224 "fmt"
2325 "strings"
2426 "sync"
@@ -175,7 +177,65 @@ type message struct {
175177 // Role is the message Role, optional values are 'user', 'assistant', ...
176178 Role string `json:"role,omitempty"`
177179 // Content defines text of this message
178- Content string `json:"content,omitempty"`
180+ Content content `json:"content,omitempty"`
181+ }
182+
183+ type content struct {
184+ Raw string
185+ Structured []contentBlock
186+ }
187+
188+ type contentBlock struct {
189+ Type string `json:"type"`
190+ Text string `json:"text,omitempty"`
191+ ImageURL ImageBlock `json:"image_url,omitempty"`
192+ }
193+
194+ type ImageBlock struct {
195+ Url string `json:"url,omitempty"`
196+ }
197+
198+ // UnmarshalJSON allow use both format
199+ func (mc * content ) UnmarshalJSON (data []byte ) error {
200+ // Raw format
201+ var str string
202+ if err := json .Unmarshal (data , & str ); err == nil {
203+ mc .Raw = str
204+ return nil
205+ }
206+
207+ // Block format
208+ var blocks []contentBlock
209+ if err := json .Unmarshal (data , & blocks ); err == nil {
210+ mc .Structured = blocks
211+ return nil
212+ }
213+
214+ return errors .New ("content format not supported" )
215+ }
216+
217+ func (mc content ) MarshalJSON () ([]byte , error ) {
218+ if mc .Raw != "" {
219+ return json .Marshal (mc .Raw )
220+ }
221+ if mc .Structured != nil {
222+ return json .Marshal (mc .Structured )
223+ }
224+ return json .Marshal ("" )
225+ }
226+
227+ func (mc content ) PlainText () string {
228+ if mc .Raw != "" {
229+ return mc .Raw
230+ }
231+ var sb strings.Builder
232+ for _ , block := range mc .Structured {
233+ if block .Type == "text" {
234+ sb .WriteString (block .Text )
235+ sb .WriteString (" " )
236+ }
237+ }
238+ return sb .String ()
179239}
180240
181241// chatCompletionRequest defines structure of /chat/completion request
@@ -200,7 +260,7 @@ type chatCompletionRequest struct {
200260func (c * chatCompletionRequest ) getNumberOfPromptTokens () int {
201261 var messages string
202262 for _ , message := range c .Messages {
203- messages += message .Content + " "
263+ messages += message .Content . PlainText () + " "
204264 }
205265 return len (strings .Fields (messages ))
206266}
@@ -328,7 +388,7 @@ func (req textCompletionRequest) createResponseText(mode string) (string, string
328388func (req * chatCompletionRequest ) getLastUserMsg () string {
329389 for i := len (req .Messages ) - 1 ; i >= 0 ; i -- {
330390 if req .Messages [i ].Role == roleUser {
331- return req .Messages [i ].Content
391+ return req .Messages [i ].Content . PlainText ()
332392 }
333393 }
334394
0 commit comments