Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions cmd/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ import (
)

type generateCommandHandler struct {
ctx context.Context
cfg *command.Config
client azuremodels.Client
options *PromptPexOptions
promptFile string
org string
sessionFile *string
ctx context.Context
cfg *command.Config
client azuremodels.Client
options *PromptPexOptions
promptFile string
org string
sessionFile *string
templateVars map[string]string
}

// NewGenerateCommand returns a new command to generate tests using PromptPex.
Expand All @@ -37,6 +38,7 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command {
gh models generate prompt.yml
gh models generate --org my-org --groundtruth-model "openai/gpt-4.1" prompt.yml
gh models generate --session-file prompt.session.json prompt.yml
gh models generate --var name=Alice --var topic="machine learning" prompt.yml
`),
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
Expand All @@ -50,6 +52,17 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command {
return fmt.Errorf("failed to parse flags: %w", err)
}

// Parse template variables from flags
templateVars, err := util.ParseTemplateVariables(cmd.Flags())
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure that "input" is not a template bars key

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added validation to prevent "input" from being used as a template variable key. The util.ParseTemplateVariables function now returns a clear error: "'input' is a reserved variable name and cannot be used with --var". This prevents conflicts with the built-in "input" variable used for backward compatibility. Commit: b01b3b7

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input is allowed for "run" command but not for "generate" command. Move the check to the generate command code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the "input" validation from the shared utility function to the generate command code. Now the run command allows "input" as a template variable while the generate command still prevents it to avoid conflicts with the built-in variable. Commit: 7a60a81

return err
}

// Check for reserved keys specific to generate command
if _, exists := templateVars["input"]; exists {
return fmt.Errorf("'input' is a reserved variable name and cannot be used with --var")
}

// Get organization
org, _ := cmd.Flags().GetString("org")

Expand All @@ -67,13 +80,14 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command {

// Create the command handler
handler := &generateCommandHandler{
ctx: ctx,
cfg: cfg,
client: cfg.Client,
options: options,
promptFile: promptFile,
org: org,
sessionFile: util.Ptr(sessionFile),
ctx: ctx,
cfg: cfg,
client: cfg.Client,
options: options,
promptFile: promptFile,
org: org,
sessionFile: util.Ptr(sessionFile),
templateVars: templateVars,
}

// Create context
Expand Down Expand Up @@ -105,6 +119,7 @@ func AddCommandLineFlags(cmd *cobra.Command) {
flags.String("effort", "", "Effort level (low, medium, high)")
flags.String("groundtruth-model", "", "Model to use for generating groundtruth outputs. Defaults to openai/gpt-4o. Use 'none' to disable groundtruth generation.")
flags.String("session-file", "", "Session file to load existing context from")
flags.StringSlice("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)")

// Custom instruction flags for each phase
flags.String("instruction-intent", "", "Custom system instruction for intent generation phase")
Expand Down
127 changes: 127 additions & 0 deletions cmd/generate/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"testing"

"github.com/github/gh-models/internal/azuremodels"
"github.com/github/gh-models/internal/sse"
"github.com/github/gh-models/pkg/command"
"github.com/github/gh-models/pkg/util"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -393,3 +395,128 @@ messages:
require.Contains(t, err.Error(), "failed to load prompt file")
})
}

func TestGenerateCommandWithTemplateVariables(t *testing.T) {
t.Run("parse template variables in command handler", func(t *testing.T) {
client := azuremodels.NewMockClient()
cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100)

cmd := NewGenerateCommand(cfg)
args := []string{
"--var", "name=Bob",
"--var", "location=Seattle",
"dummy.yml",
}

// Parse flags without executing
err := cmd.ParseFlags(args[:len(args)-1]) // Exclude positional arg
require.NoError(t, err)

// Test that the util.ParseTemplateVariables function works correctly
templateVars, err := util.ParseTemplateVariables(cmd.Flags())
require.NoError(t, err)
require.Equal(t, map[string]string{
"name": "Bob",
"location": "Seattle",
}, templateVars)
})

t.Run("runSingleTestWithContext applies template variables", func(t *testing.T) {
// Create test prompt file with template variables
const yamlBody = `
name: Template Variable Test
description: Test prompt with template variables
model: openai/gpt-4o-mini
messages:
- role: system
content: "You are a helpful assistant for {{name}}."
- role: user
content: "Tell me about {{topic}} in {{style}} style."
`

tmpDir := t.TempDir()
promptFile := filepath.Join(tmpDir, "test.prompt.yml")
err := os.WriteFile(promptFile, []byte(yamlBody), 0644)
require.NoError(t, err)

// Setup mock client to capture template-rendered messages
var capturedOptions azuremodels.ChatCompletionOptions
client := azuremodels.NewMockClient()
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) {
capturedOptions = opt

// Create a proper mock response with reader
mockResponse := "test response"
mockCompletion := azuremodels.ChatCompletion{
Choices: []azuremodels.ChatChoice{
{
Message: &azuremodels.ChatChoiceMessage{
Content: &mockResponse,
},
},
},
}

return &azuremodels.ChatCompletionResponse{
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{mockCompletion}),
}, nil
}

out := new(bytes.Buffer)
cfg := command.NewConfig(out, out, client, true, 100)

// Create handler with template variables
templateVars := map[string]string{
"name": "Alice",
"topic": "machine learning",
"style": "academic",
}

handler := &generateCommandHandler{
ctx: context.Background(),
cfg: cfg,
client: client,
options: GetDefaultOptions(),
promptFile: promptFile,
org: "",
templateVars: templateVars,
}

// Create context from prompt
promptCtx, err := handler.CreateContextFromPrompt()
require.NoError(t, err)

// Call runSingleTestWithContext directly
_, err = handler.runSingleTestWithContext("test input", "openai/gpt-4o-mini", promptCtx)
require.NoError(t, err)

// Verify that template variables were applied correctly
require.NotNil(t, capturedOptions.Messages)
require.Len(t, capturedOptions.Messages, 2)

// Check system message
systemMsg := capturedOptions.Messages[0]
require.Equal(t, azuremodels.ChatMessageRoleSystem, systemMsg.Role)
require.NotNil(t, systemMsg.Content)
require.Contains(t, *systemMsg.Content, "helpful assistant for Alice")

// Check user message
userMsg := capturedOptions.Messages[1]
require.Equal(t, azuremodels.ChatMessageRoleUser, userMsg.Role)
require.NotNil(t, userMsg.Content)
require.Contains(t, *userMsg.Content, "about machine learning")
require.Contains(t, *userMsg.Content, "academic style")
})

t.Run("rejects input as template variable", func(t *testing.T) {
client := azuremodels.NewMockClient()
cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100)

cmd := NewGenerateCommand(cfg)
cmd.SetArgs([]string{"--var", "input=test", "dummy.yml"})

err := cmd.Execute()
require.Error(t, err)
require.Contains(t, err.Error(), "'input' is a reserved variable name and cannot be used with --var")
})
}
8 changes: 8 additions & 0 deletions cmd/generate/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,15 @@ func (h *generateCommandHandler) runSingleTestWithContext(input string, modelNam
openaiMessages := []azuremodels.ChatMessage{}
for _, msg := range messages {
templateData := make(map[string]interface{})

// Add the input variable (backward compatibility)
templateData["input"] = input

// Add custom variables
for key, value := range h.templateVars {
templateData[key] = value
}

// Replace template variables in content
content, err := prompt.TemplateString(msg.Content, templateData)
if err != nil {
Expand Down
39 changes: 1 addition & 38 deletions cmd/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
}

// Parse template variables from flags
templateVars, err := parseTemplateVariables(cmd.Flags())
templateVars, err := util.ParseTemplateVariables(cmd.Flags())
if err != nil {
return err
}
Expand Down Expand Up @@ -427,43 +427,6 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
return cmd
}

// parseTemplateVariables parses template variables from the --var flags
func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) {
varFlags, err := flags.GetStringSlice("var")
if err != nil {
return nil, err
}

templateVars := make(map[string]string)
for _, varFlag := range varFlags {
// Handle empty strings
if strings.TrimSpace(varFlag) == "" {
continue
}

parts := strings.SplitN(varFlag, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag)
}

key := strings.TrimSpace(parts[0])
value := parts[1] // Don't trim value to preserve intentional whitespace

if key == "" {
return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag)
}

// Check for duplicate keys
if _, exists := templateVars[key]; exists {
return nil, fmt.Errorf("duplicate variable key '%s'", key)
}

templateVars[key] = value
}

return templateVars, nil
}

type runCommandHandler struct {
ctx context.Context
cfg *command.Config
Expand Down
7 changes: 6 additions & 1 deletion cmd/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,14 +470,19 @@ func TestParseTemplateVariables(t *testing.T) {
varFlags: []string{"name=John", "name=Jane"},
expectErr: true,
},
{
name: "input variable is allowed in run command",
varFlags: []string{"input=test value"},
expected: map[string]string{"input": "test value"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
flags.StringSlice("var", tt.varFlags, "test flag")

result, err := parseTemplateVariables(flags)
result, err := util.ParseTemplateVariables(flags)

if tt.expectErr {
require.Error(t, err)
Expand Down
40 changes: 40 additions & 0 deletions pkg/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ package util
import (
"fmt"
"io"
"strings"

"github.com/spf13/pflag"
)

// WriteToOut writes a message to the given io.Writer.
Expand All @@ -18,3 +21,40 @@ func WriteToOut(out io.Writer, message string) {
func Ptr[T any](value T) *T {
return &value
}

// ParseTemplateVariables parses template variables from the --var flags
func ParseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) {
varFlags, err := flags.GetStringSlice("var")
if err != nil {
return nil, err
}

templateVars := make(map[string]string)
for _, varFlag := range varFlags {
// Handle empty strings
if strings.TrimSpace(varFlag) == "" {
continue
}

parts := strings.SplitN(varFlag, "=", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag)
}

key := strings.TrimSpace(parts[0])
value := parts[1] // Don't trim value to preserve intentional whitespace

if key == "" {
return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag)
}

// Check for duplicate keys
if _, exists := templateVars[key]; exists {
return nil, fmt.Errorf("duplicate variable key '%s'", key)
}

templateVars[key] = value
}

return templateVars, nil
}
Loading
Loading