39 changes: 24 additions & 15 deletions .vscode/launch.json
@@ -1,16 +1,25 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Run models list",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/main.go",
"args": ["list"]
}
]
}
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Run models list",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/main.go",
"args": ["list"]
},
{
"name": "Run models view",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/main.go",
"args": ["view"],
"console": "integratedTerminal"
}
]
}
2 changes: 1 addition & 1 deletion cmd/list/list.go
@@ -53,7 +53,7 @@ func NewListCommand(cfg *command.Config) *cobra.Command {
printer.EndRow()

for _, model := range models {
printer.AddField(azuremodels.FormatIdentifier(model.Publisher, model.Name))
printer.AddField(model.ID)
printer.AddField(model.FriendlyName)
printer.EndRow()
}
5 changes: 2 additions & 3 deletions cmd/list/list_test.go
@@ -14,14 +14,13 @@ func TestList(t *testing.T) {
t.Run("NewListCommand happy path", func(t *testing.T) {
client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
ID: "test-id-1",
ID: "openai/test-id-1",
Name: "test-model-1",
FriendlyName: "Test Model 1",
Task: "chat-completion",
Publisher: "OpenAI",
Summary: "This is a test model",
Version: "1.0",
RegistryName: "azure-openai",
}
listModelsCallCount := 0
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -41,7 +40,7 @@ func TestList(t *testing.T) {
require.Contains(t, output, "DISPLAY NAME")
require.Contains(t, output, "ID")
require.Contains(t, output, modelSummary.FriendlyName)
require.Contains(t, output, azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name))
require.Contains(t, output, modelSummary.ID)
})

t.Run("--help prints usage info", func(t *testing.T) {
5 changes: 3 additions & 2 deletions cmd/run/run.go
@@ -500,7 +500,8 @@ func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSumm
if !model.IsChatModel() {
continue
}
prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name))

prompt.Options = append(prompt.Options, model.ID)
Member comment: It's nice being able to use a simple ID property and not have to format an identifier from other bits.
}

err := survey.AskOne(prompt, &modelName, survey.WithPageSize(10))
@@ -533,7 +534,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st
}

// For non-custom providers, validate the model exists
expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName)
expectedModelID := parsedModel.String()
foundMatch := false
for _, model := range models {
if model.HasName(expectedModelID) {
15 changes: 10 additions & 5 deletions cmd/run/run_test.go
@@ -19,14 +19,13 @@ func TestRun(t *testing.T) {
t.Run("NewRunCommand happy path", func(t *testing.T) {
client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
ID: "test-id-1",
ID: "openai/test-model-1",
Name: "test-model-1",
FriendlyName: "Test Model 1",
Task: "chat-completion",
Publisher: "OpenAI",
Summary: "This is a test model",
Version: "1.0",
RegistryName: "azure-openai",
}
listModelsCallCount := 0
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -52,7 +51,7 @@ func TestRun(t *testing.T) {
buf := new(bytes.Buffer)
cfg := command.NewConfig(buf, buf, client, true, 80)
runCmd := NewRunCommand(cfg)
runCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name), "this is my prompt"})
runCmd.SetArgs([]string{modelSummary.ID, "this is my prompt"})

_, err := runCmd.ExecuteC()

@@ -104,6 +103,7 @@ messages:

client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
ID: "openai/test-model",
Name: "test-model",
Publisher: "openai",
Task: "chat-completion",
@@ -134,7 +134,7 @@ messages:
runCmd := NewRunCommand(cfg)
runCmd.SetArgs([]string{
"--file", tmp.Name(),
azuremodels.FormatIdentifier("openai", "test-model"),
"openai/test-model",
})

_, err = runCmd.ExecuteC()
@@ -170,6 +170,7 @@ messages:

client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
ID: "openai/test-model",
Name: "test-model",
Publisher: "openai",
Task: "chat-completion",
@@ -214,7 +215,7 @@ messages:
runCmd := NewRunCommand(cfg)
runCmd.SetArgs([]string{
"--file", tmp.Name(),
azuremodels.FormatIdentifier("openai", "test-model"),
"openai/test-model",
initialPrompt,
})

@@ -252,11 +253,13 @@ messages:

client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
ID: "openai/example-model",
Name: "example-model",
Publisher: "openai",
Task: "chat-completion",
}
modelSummary2 := &azuremodels.ModelSummary{
ID: "openai/example-model-4o-mini-plus",
Name: "example-model-4o-mini-plus",
Publisher: "openai",
Task: "chat-completion",
@@ -369,6 +372,7 @@ messages:

client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
ID: "openai/test-model",
Name: "test-model",
Publisher: "openai",
Task: "chat-completion",
@@ -533,6 +537,7 @@ func TestValidateModelName(t *testing.T) {

// Create a mock model for testing
mockModel := &azuremodels.ModelSummary{
ID: "openai/gpt-4",
Name: "gpt-4",
Publisher: "openai",
Task: "chat-completion",
5 changes: 2 additions & 3 deletions cmd/view/view.go
@@ -50,7 +50,7 @@ func NewViewCommand(cfg *command.Config) *cobra.Command {
if !model.IsChatModel() {
continue
}
prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name))
prompt.Options = append(prompt.Options, model.ID)
}

err = survey.AskOne(prompt, &modelName, survey.WithPageSize(10))
@@ -61,13 +61,12 @@
case len(args) >= 1:
modelName = args[0]
}

modelSummary, err := getModelByName(modelName, models)
if err != nil {
return err
}

modelDetails, err := client.GetModelDetails(ctx, modelSummary.RegistryName, modelSummary.Name, modelSummary.Version)
modelDetails, err := client.GetModelDetails(ctx, modelSummary.Registry, modelSummary.Name, modelSummary.Version)
if err != nil {
return err
}
5 changes: 2 additions & 3 deletions cmd/view/view_test.go
@@ -14,14 +14,13 @@ func TestView(t *testing.T) {
t.Run("NewViewCommand happy path", func(t *testing.T) {
client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
ID: "test-id-1",
ID: "openai/test-model-1",
Member comment: I thought I remembered talk recently of moving away from requiring the publisher name in model identifiers, something from @sgoedecke. Did that happen? Should we also have a test that just has the model name for its ID if so?
Name: "test-model-1",
FriendlyName: "Test Model 1",
Task: "chat-completion",
Publisher: "OpenAI",
Summary: "This is a test model",
Version: "1.0",
RegistryName: "azure-openai",
}
listModelsCallCount := 0
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
@@ -49,7 +48,7 @@ func TestView(t *testing.T) {
buf := new(bytes.Buffer)
cfg := command.NewConfig(buf, buf, client, true, 80)
viewCmd := NewViewCommand(cfg)
viewCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name)})
viewCmd.SetArgs([]string{modelSummary.ID})

_, err := viewCmd.ExecuteC()

48 changes: 22 additions & 26 deletions internal/azuremodels/azure_client.go
@@ -9,9 +9,11 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strings"

"github.com/cli/go-gh/v2/pkg/api"
"github.com/github/gh-models/internal/modelkey"
"github.com/github/gh-models/internal/sse"
"golang.org/x/text/language"
"golang.org/x/text/language/display"
@@ -185,19 +187,7 @@ func lowercaseStrings(input []string) []string {

// ListModels returns a list of available models.
func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) {
body := bytes.NewReader([]byte(`
{
"filters": [
{ "field": "freePlayground", "values": ["true"], "operator": "eq"},
{ "field": "labels", "values": ["latest"], "operator": "eq"}
],
"order": [
{ "field": "displayName", "direction": "asc" }
]
}
`))

httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.ModelsURL, body)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.ModelsURL, nil)
if err != nil {
return nil, err
}
@@ -218,28 +208,34 @@ func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) {
decoder := json.NewDecoder(resp.Body)
decoder.UseNumber()

var searchResponse modelCatalogSearchResponse
err = decoder.Decode(&searchResponse)
var catalog githubModelCatalogResponse
err = decoder.Decode(&catalog)
if err != nil {
return nil, err
}

models := make([]*ModelSummary, 0, len(searchResponse.Summaries))
for _, summary := range searchResponse.Summaries {
models := make([]*ModelSummary, 0, len(catalog))
for _, catalogModel := range catalog {
// Determine task from supported modalities - if it supports text input/output, it's likely a chat model
inferenceTask := ""
if len(summary.InferenceTasks) > 0 {
inferenceTask = summary.InferenceTasks[0]
if slices.Contains(catalogModel.SupportedInputModalities, "text") && slices.Contains(catalogModel.SupportedOutputModalities, "text") {
inferenceTask = "chat-completion"
}

modelKey, err := modelkey.ParseModelKey(catalogModel.ID)
if err != nil {
return nil, fmt.Errorf("parsing model key %q: %w", catalogModel.ID, err)
Member comment: Ahh, you calling it a "model key" here when printing out the ID makes me think that we will indeed have the publisher + model name combined in the field, so probably don't need such a test as I suggested above. 👍🏻
}

models = append(models, &ModelSummary{
ID: summary.AssetID,
Name: summary.Name,
FriendlyName: summary.DisplayName,
ID: catalogModel.ID,
Name: modelKey.ModelName,
Registry: catalogModel.Registry,
FriendlyName: catalogModel.Name,
Task: inferenceTask,
Publisher: summary.Publisher,
Summary: summary.Summary,
Version: summary.Version,
RegistryName: summary.RegistryName,
Publisher: catalogModel.Publisher,
Summary: catalogModel.Summary,
Version: catalogModel.Version,
})
}

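Aside for readers following the identifier change: the diff calls modelkey.ParseModelKey and parsedModel.String() without showing their bodies. Below is a minimal, hypothetical Go sketch of that kind of "publisher/name" parsing, written only to illustrate the "<publisher>/<name>" ID format seen in the tests; the struct, function bodies, and error messages are assumptions, not the actual internal/modelkey implementation.

package main

import (
	"fmt"
	"strings"
)

// modelKey is a hypothetical stand-in for the value returned by
// modelkey.ParseModelKey; the real type and fields may differ.
type modelKey struct {
	Publisher string
	ModelName string
}

// parseModelKey splits an identifier like "openai/gpt-4" into publisher and
// model name, mirroring the "<publisher>/<name>" IDs used in the tests above.
func parseModelKey(id string) (*modelKey, error) {
	parts := strings.Split(id, "/")
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		return nil, fmt.Errorf("invalid model key %q: expected <publisher>/<name>", id)
	}
	return &modelKey{Publisher: parts[0], ModelName: parts[1]}, nil
}

// String reassembles the canonical ID form stored in ModelSummary.ID.
func (k *modelKey) String() string {
	return k.Publisher + "/" + k.ModelName
}

func main() {
	key, err := parseModelKey("openai/gpt-4")
	if err != nil {
		panic(err)
	}
	fmt.Println(key.Publisher, key.ModelName, key.String()) // openai gpt-4 openai/gpt-4
}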
2 changes: 1 addition & 1 deletion internal/azuremodels/azure_client_config.go
@@ -4,7 +4,7 @@ const (
defaultInferenceRoot = "https://models.github.ai"
defaultInferencePath = "inference/chat/completions"
defaultAzureAiStudioURL = "https://api.catalog.azureml.ms"
defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models"
defaultModelsURL = "https://models.github.ai/catalog/models"
)

// AzureClientConfig represents configurable settings for the Azure client.