diff --git a/cmd/generate/generate.go b/cmd/generate/generate.go index be2cf91f..f5864227 100644 --- a/cmd/generate/generate.go +++ b/cmd/generate/generate.go @@ -13,13 +13,14 @@ import ( ) type generateCommandHandler struct { - ctx context.Context - cfg *command.Config - client azuremodels.Client - options *PromptPexOptions - promptFile string - org string - sessionFile *string + ctx context.Context + cfg *command.Config + client azuremodels.Client + options *PromptPexOptions + promptFile string + org string + sessionFile *string + templateVars map[string]string } // NewGenerateCommand returns a new command to generate tests using PromptPex. @@ -37,6 +38,7 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { gh models generate prompt.yml gh models generate --org my-org --groundtruth-model "openai/gpt-4.1" prompt.yml gh models generate --session-file prompt.session.json prompt.yml + gh models generate --var name=Alice --var topic="machine learning" prompt.yml `), Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { @@ -50,6 +52,17 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { return fmt.Errorf("failed to parse flags: %w", err) } + // Parse template variables from flags + templateVars, err := util.ParseTemplateVariables(cmd.Flags()) + if err != nil { + return err + } + + // Check for reserved keys specific to generate command + if _, exists := templateVars["input"]; exists { + return fmt.Errorf("'input' is a reserved variable name and cannot be used with --var") + } + // Get organization org, _ := cmd.Flags().GetString("org") @@ -67,13 +80,14 @@ func NewGenerateCommand(cfg *command.Config) *cobra.Command { // Create the command handler handler := &generateCommandHandler{ - ctx: ctx, - cfg: cfg, - client: cfg.Client, - options: options, - promptFile: promptFile, - org: org, - sessionFile: util.Ptr(sessionFile), + ctx: ctx, + cfg: cfg, + client: cfg.Client, + options: options, + promptFile: promptFile, + org: org, + sessionFile: util.Ptr(sessionFile), + templateVars: templateVars, } // Create context @@ -105,6 +119,7 @@ func AddCommandLineFlags(cmd *cobra.Command) { flags.String("effort", "", "Effort level (low, medium, high)") flags.String("groundtruth-model", "", "Model to use for generating groundtruth outputs. Defaults to openai/gpt-4o. Use 'none' to disable groundtruth generation.") flags.String("session-file", "", "Session file to load existing context from") + flags.StringSlice("var", []string{}, "Template variables for prompt files (can be used multiple times: --var name=value)") // Custom instruction flags for each phase flags.String("instruction-intent", "", "Custom system instruction for intent generation phase") diff --git a/cmd/generate/generate_test.go b/cmd/generate/generate_test.go index 05e05cbd..b0f81d47 100644 --- a/cmd/generate/generate_test.go +++ b/cmd/generate/generate_test.go @@ -11,7 +11,9 @@ import ( "testing" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/util" "github.com/stretchr/testify/require" ) @@ -393,3 +395,128 @@ messages: require.Contains(t, err.Error(), "failed to load prompt file") }) } + +func TestGenerateCommandWithTemplateVariables(t *testing.T) { + t.Run("parse template variables in command handler", func(t *testing.T) { + client := azuremodels.NewMockClient() + cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100) + + cmd := NewGenerateCommand(cfg) + args := []string{ + "--var", "name=Bob", + "--var", "location=Seattle", + "dummy.yml", + } + + // Parse flags without executing + err := cmd.ParseFlags(args[:len(args)-1]) // Exclude positional arg + require.NoError(t, err) + + // Test that the util.ParseTemplateVariables function works correctly + templateVars, err := util.ParseTemplateVariables(cmd.Flags()) + require.NoError(t, err) + require.Equal(t, map[string]string{ + "name": "Bob", + "location": "Seattle", + }, templateVars) + }) + + t.Run("runSingleTestWithContext applies template variables", func(t *testing.T) { + // Create test prompt file with template variables + const yamlBody = ` +name: Template Variable Test +description: Test prompt with template variables +model: openai/gpt-4o-mini +messages: + - role: system + content: "You are a helpful assistant for {{name}}." + - role: user + content: "Tell me about {{topic}} in {{style}} style." +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + // Setup mock client to capture template-rendered messages + var capturedOptions azuremodels.ChatCompletionOptions + client := azuremodels.NewMockClient() + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions, org string) (*azuremodels.ChatCompletionResponse, error) { + capturedOptions = opt + + // Create a proper mock response with reader + mockResponse := "test response" + mockCompletion := azuremodels.ChatCompletion{ + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: &mockResponse, + }, + }, + }, + } + + return &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{mockCompletion}), + }, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + + // Create handler with template variables + templateVars := map[string]string{ + "name": "Alice", + "topic": "machine learning", + "style": "academic", + } + + handler := &generateCommandHandler{ + ctx: context.Background(), + cfg: cfg, + client: client, + options: GetDefaultOptions(), + promptFile: promptFile, + org: "", + templateVars: templateVars, + } + + // Create context from prompt + promptCtx, err := handler.CreateContextFromPrompt() + require.NoError(t, err) + + // Call runSingleTestWithContext directly + _, err = handler.runSingleTestWithContext("test input", "openai/gpt-4o-mini", promptCtx) + require.NoError(t, err) + + // Verify that template variables were applied correctly + require.NotNil(t, capturedOptions.Messages) + require.Len(t, capturedOptions.Messages, 2) + + // Check system message + systemMsg := capturedOptions.Messages[0] + require.Equal(t, azuremodels.ChatMessageRoleSystem, systemMsg.Role) + require.NotNil(t, systemMsg.Content) + require.Contains(t, *systemMsg.Content, "helpful assistant for Alice") + + // Check user message + userMsg := capturedOptions.Messages[1] + require.Equal(t, azuremodels.ChatMessageRoleUser, userMsg.Role) + require.NotNil(t, userMsg.Content) + require.Contains(t, *userMsg.Content, "about machine learning") + require.Contains(t, *userMsg.Content, "academic style") + }) + + t.Run("rejects input as template variable", func(t *testing.T) { + client := azuremodels.NewMockClient() + cfg := command.NewConfig(new(bytes.Buffer), new(bytes.Buffer), client, true, 100) + + cmd := NewGenerateCommand(cfg) + cmd.SetArgs([]string{"--var", "input=test", "dummy.yml"}) + + err := cmd.Execute() + require.Error(t, err) + require.Contains(t, err.Error(), "'input' is a reserved variable name and cannot be used with --var") + }) +} diff --git a/cmd/generate/pipeline.go b/cmd/generate/pipeline.go index 1a6615cd..673782f9 100644 --- a/cmd/generate/pipeline.go +++ b/cmd/generate/pipeline.go @@ -460,7 +460,15 @@ func (h *generateCommandHandler) runSingleTestWithContext(input string, modelNam openaiMessages := []azuremodels.ChatMessage{} for _, msg := range messages { templateData := make(map[string]interface{}) + + // Add the input variable (backward compatibility) templateData["input"] = input + + // Add custom variables + for key, value := range h.templateVars { + templateData[key] = value + } + // Replace template variables in content content, err := prompt.TemplateString(msg.Content, templateData) if err != nil { diff --git a/cmd/run/run.go b/cmd/run/run.go index 2d90da4f..6a33218f 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -236,7 +236,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { } // Parse template variables from flags - templateVars, err := parseTemplateVariables(cmd.Flags()) + templateVars, err := util.ParseTemplateVariables(cmd.Flags()) if err != nil { return err } @@ -427,43 +427,6 @@ func NewRunCommand(cfg *command.Config) *cobra.Command { return cmd } -// parseTemplateVariables parses template variables from the --var flags -func parseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { - varFlags, err := flags.GetStringSlice("var") - if err != nil { - return nil, err - } - - templateVars := make(map[string]string) - for _, varFlag := range varFlags { - // Handle empty strings - if strings.TrimSpace(varFlag) == "" { - continue - } - - parts := strings.SplitN(varFlag, "=", 2) - if len(parts) != 2 { - return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag) - } - - key := strings.TrimSpace(parts[0]) - value := parts[1] // Don't trim value to preserve intentional whitespace - - if key == "" { - return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) - } - - // Check for duplicate keys - if _, exists := templateVars[key]; exists { - return nil, fmt.Errorf("duplicate variable key '%s'", key) - } - - templateVars[key] = value - } - - return templateVars, nil -} - type runCommandHandler struct { ctx context.Context cfg *command.Config diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 94db2b63..7b21a06c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -470,6 +470,11 @@ func TestParseTemplateVariables(t *testing.T) { varFlags: []string{"name=John", "name=Jane"}, expectErr: true, }, + { + name: "input variable is allowed in run command", + varFlags: []string{"input=test value"}, + expected: map[string]string{"input": "test value"}, + }, } for _, tt := range tests { @@ -477,7 +482,7 @@ func TestParseTemplateVariables(t *testing.T) { flags := pflag.NewFlagSet("test", pflag.ContinueOnError) flags.StringSlice("var", tt.varFlags, "test flag") - result, err := parseTemplateVariables(flags) + result, err := util.ParseTemplateVariables(flags) if tt.expectErr { require.Error(t, err) diff --git a/pkg/util/util.go b/pkg/util/util.go index 1856f20b..c0005f21 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -4,6 +4,9 @@ package util import ( "fmt" "io" + "strings" + + "github.com/spf13/pflag" ) // WriteToOut writes a message to the given io.Writer. @@ -18,3 +21,40 @@ func WriteToOut(out io.Writer, message string) { func Ptr[T any](value T) *T { return &value } + +// ParseTemplateVariables parses template variables from the --var flags +func ParseTemplateVariables(flags *pflag.FlagSet) (map[string]string, error) { + varFlags, err := flags.GetStringSlice("var") + if err != nil { + return nil, err + } + + templateVars := make(map[string]string) + for _, varFlag := range varFlags { + // Handle empty strings + if strings.TrimSpace(varFlag) == "" { + continue + } + + parts := strings.SplitN(varFlag, "=", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid variable format '%s', expected 'key=value'", varFlag) + } + + key := strings.TrimSpace(parts[0]) + value := parts[1] // Don't trim value to preserve intentional whitespace + + if key == "" { + return nil, fmt.Errorf("variable key cannot be empty in '%s'", varFlag) + } + + // Check for duplicate keys + if _, exists := templateVars[key]; exists { + return nil, fmt.Errorf("duplicate variable key '%s'", key) + } + + templateVars[key] = value + } + + return templateVars, nil +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go new file mode 100644 index 00000000..eef7cd88 --- /dev/null +++ b/pkg/util/util_test.go @@ -0,0 +1,111 @@ +package util + +import ( + "testing" + + "github.com/spf13/pflag" + "github.com/stretchr/testify/require" +) + +func TestParseTemplateVariables(t *testing.T) { + tests := []struct { + name string + varFlags []string + expected map[string]string + expectErr bool + }{ + { + name: "empty flags", + varFlags: []string{}, + expected: map[string]string{}, + }, + { + name: "single variable", + varFlags: []string{"name=Alice"}, + expected: map[string]string{"name": "Alice"}, + }, + { + name: "multiple variables", + varFlags: []string{"name=Alice", "age=30", "city=Boston"}, + expected: map[string]string{"name": "Alice", "age": "30", "city": "Boston"}, + }, + { + name: "variable with spaces in value", + varFlags: []string{"description=Hello World"}, + expected: map[string]string{"description": "Hello World"}, + }, + { + name: "variable with equals in value", + varFlags: []string{"equation=x=y+1"}, + expected: map[string]string{"equation": "x=y+1"}, + }, + { + name: "variable with empty value", + varFlags: []string{"empty="}, + expected: map[string]string{"empty": ""}, + }, + { + name: "variable with whitespace around key", + varFlags: []string{" name =Alice"}, + expected: map[string]string{"name": "Alice"}, + }, + { + name: "preserve whitespace in value", + varFlags: []string{"message= Hello World "}, + expected: map[string]string{"message": " Hello World "}, + }, + { + name: "empty string flag is ignored", + varFlags: []string{"", "name=Alice"}, + expected: map[string]string{"name": "Alice"}, + expectErr: false, + }, + { + name: "whitespace only flag is ignored", + varFlags: []string{" ", "name=Alice"}, + expected: map[string]string{"name": "Alice"}, + expectErr: false, + }, + { + name: "missing equals sign", + varFlags: []string{"name"}, + expectErr: true, + }, + { + name: "missing equals sign with multiple vars", + varFlags: []string{"name=Alice", "age"}, + expectErr: true, + }, + { + name: "empty key", + varFlags: []string{"=value"}, + expectErr: true, + }, + { + name: "whitespace only key", + varFlags: []string{" =value"}, + expectErr: true, + }, + { + name: "duplicate keys", + varFlags: []string{"name=Alice", "name=Bob"}, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flags := pflag.NewFlagSet("test", pflag.ContinueOnError) + flags.StringSlice("var", tt.varFlags, "test flag") + + result, err := ParseTemplateVariables(flags) + + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +}