diff --git a/cmd/eval/eval.go b/cmd/eval/eval.go index 7374ba69..c9c7487d 100644 --- a/cmd/eval/eval.go +++ b/cmd/eval/eval.go @@ -356,7 +356,7 @@ func (h *evalCommandHandler) runEvaluators(ctx context.Context, testCase map[str func (h *evalCommandHandler) runSingleEvaluator(ctx context.Context, evaluator prompt.Evaluator, testCase map[string]interface{}, response string) (EvaluationResult, error) { switch { case evaluator.String != nil: - return h.runStringEvaluator(evaluator.Name, *evaluator.String, response) + return h.runStringEvaluator(evaluator.Name, *evaluator.String, testCase, response) case evaluator.LLM != nil: return h.runLLMEvaluator(ctx, evaluator.Name, *evaluator.LLM, testCase, response) case evaluator.Uses != "": @@ -366,23 +366,39 @@ func (h *evalCommandHandler) runSingleEvaluator(ctx context.Context, evaluator p } } -func (h *evalCommandHandler) runStringEvaluator(name string, eval prompt.StringEvaluator, response string) (EvaluationResult, error) { +func (h *evalCommandHandler) runStringEvaluator(name string, eval prompt.StringEvaluator, testCase map[string]interface{}, response string) (EvaluationResult, error) { var passed bool var details string switch { case eval.Equals != "": - passed = response == eval.Equals - details = fmt.Sprintf("Expected exact match: '%s'", eval.Equals) + equals, err := h.templateString(eval.Equals, testCase) + if err != nil { + return EvaluationResult{}, fmt.Errorf("failed to template message content: %w", err) + } + passed = response == equals + details = fmt.Sprintf("Expected exact match: '%s'", equals) case eval.Contains != "": - passed = strings.Contains(strings.ToLower(response), strings.ToLower(eval.Contains)) - details = fmt.Sprintf("Expected to contain: '%s'", eval.Contains) + contains, err := h.templateString(eval.Contains, testCase) + if err != nil { + return EvaluationResult{}, fmt.Errorf("failed to template message content: %w", err) + } + passed = strings.Contains(strings.ToLower(response), strings.ToLower(contains)) + details = fmt.Sprintf("Expected to contain: '%s'", contains) case eval.StartsWith != "": - passed = strings.HasPrefix(strings.ToLower(response), strings.ToLower(eval.StartsWith)) - details = fmt.Sprintf("Expected to start with: '%s'", eval.StartsWith) + startsWith, err := h.templateString(eval.StartsWith, testCase) + if err != nil { + return EvaluationResult{}, fmt.Errorf("failed to template message content: %w", err) + } + passed = strings.HasPrefix(strings.ToLower(response), strings.ToLower(startsWith)) + details = fmt.Sprintf("Expected to start with: '%s'", startsWith) case eval.EndsWith != "": - passed = strings.HasSuffix(strings.ToLower(response), strings.ToLower(eval.EndsWith)) - details = fmt.Sprintf("Expected to end with: '%s'", eval.EndsWith) + endsWith, err := h.templateString(eval.EndsWith, testCase) + if err != nil { + return EvaluationResult{}, fmt.Errorf("failed to template message content: %w", err) + } + passed = strings.HasSuffix(strings.ToLower(response), strings.ToLower(endsWith)) + details = fmt.Sprintf("Expected to end with: '%s'", endsWith) default: return EvaluationResult{}, errors.New("no string evaluation criteria specified") } diff --git a/cmd/eval/eval_test.go b/cmd/eval/eval_test.go index ed831705..1548d798 100644 --- a/cmd/eval/eval_test.go +++ b/cmd/eval/eval_test.go @@ -88,6 +88,7 @@ evaluators: evaluator prompt.StringEvaluator response string expected bool + variables map[string]interface{} }{ { name: "contains match", @@ -125,11 +126,25 @@ evaluators: response: "hello world", expected: true, }, + { + name: "contains with variable", + evaluator: prompt.StringEvaluator{Contains: "{{expected}}"}, + response: "hello world", + expected: true, + variables: map[string]interface{}{"expected": "world"}, + }, + { + name: "fails with variable not match", + evaluator: prompt.StringEvaluator{Contains: "{{expected}}"}, + response: "hello world", + expected: false, + variables: map[string]interface{}{"expected": "goodbye"}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := handler.runStringEvaluator("test", tt.evaluator, tt.response) + result, err := handler.runStringEvaluator("test", tt.evaluator, tt.variables, tt.response) require.NoError(t, err) require.Equal(t, tt.expected, result.Passed) if tt.expected { diff --git a/examples/sample_prompt.yml b/examples/sample_prompt.yml index 342b4c81..ddf00a28 100644 --- a/examples/sample_prompt.yml +++ b/examples/sample_prompt.yml @@ -6,8 +6,10 @@ modelParameters: maxTokens: 50 testData: - input: 'hello world' + string: hello expected: 'greeting response' - input: 'goodbye world' + string: goodbye expected: 'farewell response' messages: - role: system @@ -17,6 +19,6 @@ messages: evaluators: - name: string evaluator string: - contains: world + contains: '{{string}}' - name: similarity check uses: github/similarity