Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 53 additions & 44 deletions cmd/eval/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"strings"
"time"

"github.com/MakeNowJust/heredoc"
"github.com/github/gh-models/internal/azuremodels"
Expand Down Expand Up @@ -80,6 +81,8 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {

By default, results are displayed in a human-readable format. Use the --json flag
to output structured JSON data for programmatic use or integration with CI/CD pipelines.
This command will automatically retry on rate limiting errors, waiting for the specified
duration before retrying the request.

See https://docs.github.com/github-models/use-github-models/storing-prompts-in-github-repositories#supported-file-format for more information.
`),
Expand Down Expand Up @@ -327,36 +330,65 @@ func (h *evalCommandHandler) templateString(templateStr string, data map[string]
return prompt.TemplateString(templateStr, data)
}

func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) {
req := h.evalFile.BuildChatCompletionOptions(messages)

resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
if err != nil {
return "", err
}
// callModelWithRetry makes an API call with automatic retry on rate limiting
func (h *evalCommandHandler) callModelWithRetry(ctx context.Context, req azuremodels.ChatCompletionOptions) (string, error) {
const maxRetries = 3

// For non-streaming requests, we should get a single response
var content strings.Builder
for {
completion, err := resp.Reader.Read()
for attempt := 0; attempt <= maxRetries; attempt++ {
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
if err != nil {
if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") {
break
var rateLimitErr *azuremodels.RateLimitError
if errors.As(err, &rateLimitErr) {
if attempt < maxRetries {
if !h.jsonOutput {
h.cfg.WriteToOut(fmt.Sprintf(" Rate limited, waiting %v before retry (attempt %d/%d)...\n",
rateLimitErr.RetryAfter, attempt+1, maxRetries+1))
}

// Wait for the specified duration
select {
case <-ctx.Done():
return "", ctx.Err()
case <-time.After(rateLimitErr.RetryAfter):
continue
}
}
return "", fmt.Errorf("rate limit exceeded after %d attempts: %w", attempt+1, err)
}
// For non-rate-limit errors, return immediately
return "", err
}

for _, choice := range completion.Choices {
if choice.Delta != nil && choice.Delta.Content != nil {
content.WriteString(*choice.Delta.Content)
var content strings.Builder
for {
completion, err := resp.Reader.Read()
if err != nil {
if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") {
break
}
return "", err
}
if choice.Message != nil && choice.Message.Content != nil {
content.WriteString(*choice.Message.Content)

for _, choice := range completion.Choices {
if choice.Delta != nil && choice.Delta.Content != nil {
content.WriteString(*choice.Delta.Content)
}
if choice.Message != nil && choice.Message.Content != nil {
content.WriteString(*choice.Message.Content)
}
}
}

return strings.TrimSpace(content.String()), nil
}

return strings.TrimSpace(content.String()), nil
// This should never be reached, but just in case
return "", errors.New("unexpected error calling model")
}

// callModel sends the given messages using the options declared in the eval
// file, transparently retrying when the backend reports rate limiting.
func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) {
	return h.callModelWithRetry(ctx, h.evalFile.BuildChatCompletionOptions(messages))
}

func (h *evalCommandHandler) runEvaluators(ctx context.Context, testCase map[string]interface{}, response string) ([]EvaluationResult, error) {
Expand Down Expand Up @@ -437,7 +469,6 @@ func (h *evalCommandHandler) runStringEvaluator(name string, eval prompt.StringE
}

func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, eval prompt.LLMEvaluator, testCase map[string]interface{}, response string) (EvaluationResult, error) {
// Template the evaluation prompt
evalData := make(map[string]interface{})
for k, v := range testCase {
evalData[k] = v
Expand All @@ -449,7 +480,6 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e
return EvaluationResult{}, fmt.Errorf("failed to template evaluation prompt: %w", err)
}

// Prepare messages for evaluation
var messages []azuremodels.ChatMessage
if eval.SystemPrompt != "" {
messages = append(messages, azuremodels.ChatMessage{
Expand All @@ -462,40 +492,19 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e
Content: util.Ptr(promptContent),
})

// Call the evaluation model
req := azuremodels.ChatCompletionOptions{
Messages: messages,
Model: eval.ModelID,
Stream: false,
}

resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
evalResponseText, err := h.callModelWithRetry(ctx, req)
if err != nil {
return EvaluationResult{}, fmt.Errorf("failed to call evaluation model: %w", err)
}

var evalResponse strings.Builder
for {
completion, err := resp.Reader.Read()
if err != nil {
if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") {
break
}
return EvaluationResult{}, err
}

for _, choice := range completion.Choices {
if choice.Delta != nil && choice.Delta.Content != nil {
evalResponse.WriteString(*choice.Delta.Content)
}
if choice.Message != nil && choice.Message.Content != nil {
evalResponse.WriteString(*choice.Message.Content)
}
}
}

// Match response to choices
evalResponseText := strings.TrimSpace(strings.ToLower(evalResponse.String()))
evalResponseText = strings.TrimSpace(strings.ToLower(evalResponseText))
for _, choice := range eval.Choices {
if strings.Contains(evalResponseText, strings.ToLower(choice.Choice)) {
return EvaluationResult{
Expand Down
48 changes: 48 additions & 0 deletions internal/azuremodels/azure_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"io"
"net/http"
"slices"
"strconv"
"strings"
"time"

"github.com/cli/go-gh/v2/pkg/api"
"github.com/github/gh-models/internal/modelkey"
Expand Down Expand Up @@ -259,6 +261,42 @@ func (c *AzureClient) handleHTTPError(resp *http.Response) error {
return err
}

case http.StatusTooManyRequests:
// Handle rate limiting
retryAfter := time.Duration(0)

// Check for x-ratelimit-timeremaining header (in seconds)
if timeRemainingStr := resp.Header.Get("x-ratelimit-timeremaining"); timeRemainingStr != "" {
if seconds, parseErr := strconv.Atoi(timeRemainingStr); parseErr == nil {
retryAfter = time.Duration(seconds) * time.Second
}
}

// Fall back to standard Retry-After header if x-ratelimit-timeremaining is not available
if retryAfter == 0 {
if retryAfterStr := resp.Header.Get("Retry-After"); retryAfterStr != "" {
if seconds, parseErr := strconv.Atoi(retryAfterStr); parseErr == nil {
retryAfter = time.Duration(seconds) * time.Second
}
}
}

// Default to 60 seconds if no retry-after information is provided
if retryAfter == 0 {
retryAfter = 60 * time.Second
}

body, _ := io.ReadAll(resp.Body)
message := "rate limit exceeded"
if len(body) > 0 {
message = string(body)
}

return &RateLimitError{
RetryAfter: retryAfter,
Message: strings.TrimSpace(message),
}

default:
_, err = sb.WriteString("unexpected response from the server: " + resp.Status)
if err != nil {
Expand Down Expand Up @@ -286,3 +324,13 @@ func (c *AzureClient) handleHTTPError(resp *http.Response) error {

return errors.New(sb.String())
}

// RateLimitError represents a rate limiting error from the API
type RateLimitError struct {
RetryAfter time.Duration
Message string
}

func (e *RateLimitError) Error() string {
return fmt.Sprintf("rate limited: %s (retry after %v)", e.Message, e.RetryAfter)
}
109 changes: 109 additions & 0 deletions internal/azuremodels/rate_limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package azuremodels

import (
	"errors"
	"net/http"
	"strings"
	"testing"
	"time"
)

// TestRateLimitError verifies the formatted message produced by Error().
func TestRateLimitError(t *testing.T) {
	rle := &RateLimitError{
		RetryAfter: 30 * time.Second,
		Message:    "Too many requests",
	}

	want := "rate limited: Too many requests (retry after 30s)"
	if got := rle.Error(); got != want {
		t.Errorf("Expected error message %q, got %q", want, got)
	}
}

// TestHandleHTTPError_RateLimit checks that a 429 response is converted into
// a *RateLimitError whose RetryAfter honors the precedence of the
// x-ratelimit-timeremaining header, then Retry-After, then the 60s default.
func TestHandleHTTPError_RateLimit(t *testing.T) {
	client := &AzureClient{}

	cases := []struct {
		label      string
		status     int
		headers    map[string]string
		wantsDelay time.Duration
	}{
		{
			label:      "Rate limit with x-ratelimit-timeremaining header",
			status:     http.StatusTooManyRequests,
			headers:    map[string]string{"x-ratelimit-timeremaining": "45"},
			wantsDelay: 45 * time.Second,
		},
		{
			label:      "Rate limit with Retry-After header",
			status:     http.StatusTooManyRequests,
			headers:    map[string]string{"Retry-After": "60"},
			wantsDelay: 60 * time.Second,
		},
		{
			label:  "Rate limit with both headers - x-ratelimit-timeremaining takes precedence",
			status: http.StatusTooManyRequests,
			headers: map[string]string{
				"x-ratelimit-timeremaining": "30",
				"Retry-After":               "90",
			},
			wantsDelay: 30 * time.Second,
		},
		{
			label:      "Rate limit with no headers - default to 60s",
			status:     http.StatusTooManyRequests,
			headers:    map[string]string{},
			wantsDelay: 60 * time.Second,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			// Build a minimal 429 response carrying this case's headers.
			resp := &http.Response{
				StatusCode: tc.status,
				Header:     make(http.Header),
				Body:       &mockReadCloser{reader: strings.NewReader("rate limit exceeded")},
			}
			for name, val := range tc.headers {
				resp.Header.Set(name, val)
			}

			err := client.handleHTTPError(resp)

			var rateLimitErr *RateLimitError
			if !isRateLimitError(err, &rateLimitErr) {
				t.Fatalf("Expected RateLimitError, got %T: %v", err, err)
			}
			if rateLimitErr.RetryAfter != tc.wantsDelay {
				t.Errorf("Expected RetryAfter %v, got %v", tc.wantsDelay, rateLimitErr.RetryAfter)
			}
		})
	}
}

// isRateLimitError reports whether err is, or wraps, a *RateLimitError,
// storing the match in target when found (same contract as errors.As).
func isRateLimitError(err error, target **RateLimitError) bool {
	// Delegate to errors.As: the previous direct type assertion only matched
	// a bare *RateLimitError and missed errors wrapped with fmt.Errorf("%w").
	return errors.As(err, target)
}

type mockReadCloser struct {
reader *strings.Reader
}

func (m *mockReadCloser) Read(p []byte) (n int, err error) {
return m.reader.Read(p)
}

func (m *mockReadCloser) Close() error {
return nil
}
Loading