Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ type GenerateResponse struct {
Usage *GenerationUsage `json:"usage,omitempty"`
}

// A GenerateResponseChunk is the portion of the [GenerateResponse]
// that is passed to a streaming callback.
type GenerateResponseChunk struct {
Content []*Part `json:"content,omitempty"`
Custom any `json:"custom,omitempty"`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know the purpose of this field? Is it ever set in the JS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Custom field? I don't think it is ever set in the JS.

Index float64 `json:"index,omitempty"`
Index int `json:"index,omitempty"`
}

// GenerationCommonConfig holds configuration for generation.
Expand Down
34 changes: 25 additions & 9 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ import (
)

// A ModelAction is used to generate content from an AI model.
type ModelAction = core.Action[*GenerateRequest, *GenerateResponse, *Candidate]
type ModelAction = core.Action[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk]

// ModelStreamingCallback is the type for the streaming callback of a model.
type ModelStreamingCallback = func(context.Context, *Candidate) error
type ModelStreamingCallback = func(context.Context, *GenerateResponseChunk) error

// ModelCapabilities describes various capabilities of the model.
type ModelCapabilities struct {
Expand Down Expand Up @@ -72,7 +72,7 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
// LookupModel looks up a [ModelAction] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(provider, name string) *ModelAction {
return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *Candidate](atype.Model, provider, name)
return core.LookupActionFor[*GenerateRequest, *GenerateResponse, *GenerateResponseChunk](atype.Model, provider, name)
}

// Generate applies a [ModelAction] to some input, handling tool requests.
Expand Down Expand Up @@ -229,6 +229,23 @@ func (gr *GenerateResponse) Text() (string, error) {
return gr.Candidates[0].Text()
}

// Text concatenates the text of every [Part] in the
// [GenerateResponseChunk] and returns the result. It returns an
// error when the chunk carries no content at all.
func (c *GenerateResponseChunk) Text() (string, error) {
	switch len(c.Content) {
	case 0:
		// Nothing streamed in this chunk: surface that to the caller.
		return "", errors.New("response chunk has no content")
	case 1:
		// Fast path: a single part needs no builder.
		return c.Content[0].Text, nil
	}
	var b strings.Builder
	for _, part := range c.Content {
		b.WriteString(part.Text)
	}
	return b.String(), nil
}

// Text returns the contents of a [Candidate] as a string. It
// returns an error if the candidate has no message.
func (c *Candidate) Text() (string, error) {
Expand All @@ -241,11 +258,10 @@ func (c *Candidate) Text() (string, error) {
}
if len(msg.Content) == 1 {
return msg.Content[0].Text, nil
} else {
var sb strings.Builder
for _, p := range msg.Content {
sb.WriteString(p.Text)
}
return sb.String(), nil
}
var sb strings.Builder
for _, p := range msg.Content {
sb.WriteString(p.Text)
}
return sb.String(), nil
}
2 changes: 1 addition & 1 deletion go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type PromptRequest struct {
// Prompt is the interface used to execute a prompt template and
// pass the result to a [ModelAction].
type Prompt interface {
Generate(context.Context, *PromptRequest, func(context.Context, *Candidate) error) (*GenerateResponse, error)
Generate(context.Context, *PromptRequest, func(context.Context, *GenerateResponseChunk) error) (*GenerateResponse, error)
}

// RegisterPrompt registers a prompt in the global registry.
Expand Down
7 changes: 7 additions & 0 deletions go/core/schemas.config
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ GenerateRequest doc
A GenerateRequest is a request to generate completions from a model.
.

GenerateResponseChunk.index type int

GenerationCommonConfig.maxOutputTokens type int
GenerationCommonConfig.topK type int

Expand Down Expand Up @@ -184,6 +186,11 @@ GenerateResponse doc
A GenerateResponse is a model's response to a [GenerateRequest].
.

GenerateResponseChunk doc
A GenerateResponseChunk is the portion of the [GenerateResponse]
that is passed to a streaming callback.
.

GenerationCommonConfig doc
GenerationCommonConfig holds configuration for generation.
.
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (p *Prompt) Register() error {
// the prompt.
//
// This implements the [ai.Prompt] interface.
func (p *Prompt) Generate(ctx context.Context, pr *ai.PromptRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) {
func (p *Prompt) Generate(ctx context.Context, pr *ai.PromptRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
tracing.SetCustomMetadataAttr(ctx, "subtype", "prompt")

genReq, err := p.buildRequest(pr)
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"github.com/firebase/genkit/go/ai"
)

func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) {
func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
input := req.Messages[0].Content[0].Text
output := fmt.Sprintf("AI reply to %q", input)

Expand Down
8 changes: 6 additions & 2 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ type generator struct {
//session *genai.ChatSession // non-nil if we're in the middle of a chat
}

func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) {
func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
gm := g.client.GenerativeModel(g.model)

// Translate from a ai.GenerateRequest to a genai request.
Expand Down Expand Up @@ -228,7 +228,11 @@ func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb
}
// Send candidates to the callback.
for _, c := range chunk.Candidates {
err := cb(ctx, translateCandidate(c))
tc := translateCandidate(c)
err := cb(ctx, &ai.GenerateResponseChunk{
Content: tc.Message.Content,
Index: tc.Index,
})
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions go/plugins/googleai/googleai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ func TestLive(t *testing.T) {
out := ""
parts := 0
g := googleai.Model(generativeModel)
_, err = ai.Generate(ctx, g, req, func(ctx context.Context, c *ai.Candidate) error {
_, err = ai.Generate(ctx, g, req, func(ctx context.Context, c *ai.GenerateResponseChunk) error {
parts++
out += c.Message.Content[0].Text
out += c.Content[0].Text
return nil
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ type generator struct {
client *genai.Client
}

func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.Candidate) error) (*ai.GenerateResponse, error) {
func (g *generator) generate(ctx context.Context, input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) {
if cb != nil {
panic("streaming not supported yet") // TODO: streaming
}
Expand Down
8 changes: 4 additions & 4 deletions go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ func main() {
}

simpleGreetingFlow := genkit.DefineFlow("simpleGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) {
var callback func(context.Context, *ai.Candidate) error
var callback func(context.Context, *ai.GenerateResponseChunk) error
if cb != nil {
callback = func(ctx context.Context, c *ai.Candidate) error {
callback = func(ctx context.Context, c *ai.GenerateResponseChunk) error {
text, err := c.Text()
if err != nil {
return err
Expand Down Expand Up @@ -205,9 +205,9 @@ func main() {
}

genkit.DefineFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) {
var callback func(context.Context, *ai.Candidate) error
var callback func(context.Context, *ai.GenerateResponseChunk) error
if cb != nil {
callback = func(ctx context.Context, c *ai.Candidate) error {
callback = func(ctx context.Context, c *ai.GenerateResponseChunk) error {
text, err := c.Text()
if err != nil {
return err
Expand Down