Skip to content

Commit d9fe0df

Browse files
committed
PR comments
1 parent aee2b76 commit d9fe0df

File tree

3 files changed

+106
-128
lines changed

3 files changed

+106
-128
lines changed

internal/service/db/impl.go

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package database
44
import (
55
"context"
66
"encoding/base64"
7+
"errors"
78
"fmt"
89
"time"
910

@@ -17,6 +18,11 @@ import (
1718
"github.com/stacklok/toolhive-registry-server/internal/service"
1819
)
1920

21+
var (
22+
// ErrBug is returned when a server is not found
23+
ErrBug = errors.New("bug")
24+
)
25+
2026
// options holds configuration options for the database service
2127
type options struct {
2228
pool *pgxpool.Pool
@@ -26,7 +32,8 @@ type options struct {
2632
type Option func(*options) error
2733

2834
// WithConnectionPool creates a new database-backed registry service with the
29-
// given pgx pool.
35+
// given pgx pool. The caller is responsible for closing the pool when it is
36+
// done.
3037
func WithConnectionPool(pool *pgxpool.Pool) Option {
3138
return func(o *options) error {
3239
if pool == nil {
@@ -37,24 +44,6 @@ func WithConnectionPool(pool *pgxpool.Pool) Option {
3744
}
3845
}
3946

40-
// WithConnectionString creates a new database-backed registry service with
41-
// the given connection string.
42-
func WithConnectionString(connString string) Option {
43-
return func(o *options) error {
44-
if connString == "" {
45-
return fmt.Errorf("connection string is required")
46-
}
47-
48-
pool, err := pgxpool.New(context.Background(), connString)
49-
if err != nil {
50-
return fmt.Errorf("failed to create pgx pool: %w", err)
51-
}
52-
53-
o.pool = pool
54-
return nil
55-
}
56-
}
57-
5847
// dbService implements the RegistryService interface using a database backend
5948
type dbService struct {
6049
pool *pgxpool.Pool
@@ -98,7 +87,6 @@ func (s *dbService) ListServers(
9887
ctx context.Context,
9988
opts ...service.Option[service.ListServersOptions],
10089
) ([]*upstreamv0.ServerJSON, error) {
101-
// TODO: implement
10290
options := &service.ListServersOptions{}
10391
for _, opt := range opts {
10492
if err := opt(options); err != nil {
@@ -108,22 +96,24 @@ func (s *dbService) ListServers(
10896

10997
decoded, err := base64.StdEncoding.DecodeString(options.Cursor)
11098
if err != nil {
111-
return nil, err
99+
return nil, fmt.Errorf("invalid cursor format: %w", err)
112100
}
113101
nextTime, err := time.Parse(time.RFC3339, string(decoded))
114102
if err != nil {
115-
return nil, err
103+
return nil, fmt.Errorf("invalid cursor format: %w", err)
116104
}
117105

118-
querierFunc := func(querier sqlc.Querier) ([]helper, error) {
106+
// Note: this function fetches a list of servers. In case no records are
107+
// found, the called function should return an empty slice as it's
108+
// customary in Go.
109+
querierFunc := func(ctx context.Context, querier sqlc.Querier) ([]helper, error) {
119110
servers, err := querier.ListServers(
120111
ctx,
121112
sqlc.ListServersParams{
122113
Next: &nextTime,
123114
Size: int64(options.Limit),
124115
},
125116
)
126-
127117
if err != nil {
128118
return nil, err
129119
}
@@ -151,7 +141,10 @@ func (s *dbService) ListServerVersions(
151141
}
152142
}
153143

154-
querierFunc := func(querier sqlc.Querier) ([]helper, error) {
144+
// Note: this function fetches a list of server versions. In case no records are
145+
// found, the called function should return an empty slice as it's
146+
// customary in Go.
147+
querierFunc := func(ctx context.Context, querier sqlc.Querier) ([]helper, error) {
155148
servers, err := querier.ListServerVersions(
156149
ctx,
157150
sqlc.ListServerVersionsParams{
@@ -188,7 +181,10 @@ func (s *dbService) GetServerVersion(
188181
}
189182
}
190183

191-
querierFunc := func(querier sqlc.Querier) ([]helper, error) {
184+
// Note: this function fetches a single record given name and version.
185+
// In case no record is found, the called function should return an
186+
// `sql.ErrNoRows` error as it's customary in Go.
187+
querierFunc := func(ctx context.Context, querier sqlc.Querier) ([]helper, error) {
192188
server, err := querier.GetServerVersion(
193189
ctx,
194190
sqlc.GetServerVersionParams{
@@ -208,25 +204,62 @@ func (s *dbService) GetServerVersion(
208204
return nil, err
209205
}
210206

207+
// Note: the `queryFunc` function is expected to return an error
208+
// sooner if no records are found, so getting this far with
209+
// a length result slice other than 1 means there's a bug.
210+
if len(res) != 1 {
211+
return nil, fmt.Errorf("%w: number of servers returned is not 1", ErrBug)
212+
}
213+
211214
return res[0], nil
212215
}
213216

217+
// querierFunction is a function that uses the given querier object to run the
218+
// main extraction. As of the time of this writing, its main use is accessing
219+
// the `mcp_server` table in a type-agnostic way. This is to overcome a
220+
// limitation of sqlc that prevents us from having the exact same go type
221+
// despite the fact that the underlying columns returned are the same.
222+
//
223+
// Note that the underlying table does not have to be the `mcp_server` table
224+
// as it is used now, as long as the result is a slice of helpers.
225+
type querierFunction func(ctx context.Context, querier sqlc.Querier) ([]helper, error)
226+
227+
// sharedListServers is a helper function to list servers and mapping them to
228+
// the API schema.
229+
//
230+
// Its responsibilities are:
231+
// * Begin a transaction
232+
// * Execute the querier function
233+
// * List packages and remotes using the server IDs
234+
// * Map the results to the API schema
235+
// * Return the results
236+
//
237+
// The argument `querierFunc` is a function that uses the given querier object
238+
// to run the main extraction. Note that the underlying table does not have
239+
// to be the `mcp_server` table as it is used now, as long as the result is a
240+
// slice of helpers.
214241
func (s *dbService) sharedListServers(
215242
ctx context.Context,
216-
querierFunc func(querier sqlc.Querier) ([]helper, error),
243+
querierFunc querierFunction,
217244
) ([]*upstreamv0.ServerJSON, error) {
218245
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{
219246
IsoLevel: pgx.ReadCommitted,
220247
AccessMode: pgx.ReadOnly,
221248
})
222249
if err != nil {
223-
return nil, err
250+
return nil, fmt.Errorf("failed to begin transaction: %w", err)
224251
}
225-
defer tx.Rollback(ctx)
252+
defer func() {
253+
err := tx.Rollback(ctx)
254+
if err != nil && !errors.Is(err, pgx.ErrTxClosed) {
255+
// TODO: log the rollback error (add proper logging)
256+
_ = err
257+
}
258+
}()
226259

227260
querier := sqlc.New(tx)
228261

229-
servers, err := querierFunc(querier)
262+
servers, err := querierFunc(ctx, querier)
230263
if err != nil {
231264
return nil, err
232265
}

internal/service/db/impl_test.go

Lines changed: 2 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -730,77 +730,6 @@ func TestWithConnectionPool(t *testing.T) {
730730
}
731731
}
732732

733-
func TestWithConnectionString(t *testing.T) {
734-
t.Parallel()
735-
736-
tests := []struct {
737-
name string
738-
setupFunc func(*testing.T) string
739-
validateFunc func(*testing.T, error, *options)
740-
}{
741-
{
742-
name: "success with valid connection string",
743-
//nolint:thelper // We want to see these lines in the test output
744-
setupFunc: func(t *testing.T) string {
745-
ctx := context.Background()
746-
db, cleanupFunc := database.SetupTestDB(t)
747-
t.Cleanup(cleanupFunc)
748-
749-
connStr := db.Config().ConnString()
750-
pool, err := pgxpool.New(ctx, connStr)
751-
require.NoError(t, err)
752-
t.Cleanup(func() { pool.Close() })
753-
754-
return connStr
755-
},
756-
//nolint:thelper // We want to see these lines in the test output
757-
validateFunc: func(t *testing.T, err error, o *options) {
758-
require.NoError(t, err)
759-
require.NotNil(t, o.pool)
760-
// Clean up the pool created by the option
761-
if o.pool != nil {
762-
o.pool.Close()
763-
}
764-
},
765-
},
766-
{
767-
name: "failure with empty connection string",
768-
//nolint:thelper // We want to see these lines in the test output
769-
setupFunc: func(_ *testing.T) string {
770-
return ""
771-
},
772-
//nolint:thelper // We want to see these lines in the test output
773-
validateFunc: func(t *testing.T, err error, _ *options) {
774-
require.Error(t, err)
775-
},
776-
},
777-
{
778-
name: "failure with invalid connection string",
779-
//nolint:thelper // We want to see these lines in the test output
780-
setupFunc: func(_ *testing.T) string {
781-
return "invalid-connection-string"
782-
},
783-
//nolint:thelper // We want to see these lines in the test output
784-
validateFunc: func(t *testing.T, err error, _ *options) {
785-
require.Error(t, err)
786-
},
787-
},
788-
}
789-
790-
for _, tt := range tests {
791-
t.Run(tt.name, func(t *testing.T) {
792-
t.Parallel()
793-
794-
connString := tt.setupFunc(t)
795-
opt := WithConnectionString(connString)
796-
o := &options{}
797-
err := opt(o)
798-
799-
tt.validateFunc(t, err, o)
800-
})
801-
}
802-
}
803-
804733
func TestNew(t *testing.T) {
805734
t.Parallel()
806735

@@ -843,7 +772,7 @@ func TestNew(t *testing.T) {
843772
require.NoError(t, err)
844773
t.Cleanup(func() { pool.Close() })
845774

846-
return []Option{WithConnectionString(connStr)}
775+
return []Option{WithConnectionPool(pool)}
847776
},
848777
//nolint:thelper // We want to see these lines in the test output
849778
validateFunc: func(t *testing.T, svc service.RegistryService, err error) {
@@ -855,7 +784,7 @@ func TestNew(t *testing.T) {
855784
name: "failure with invalid connection string",
856785
//nolint:thelper // We want to see these lines in the test output
857786
setupFunc: func(_ *testing.T) []Option {
858-
return []Option{WithConnectionString("invalid-connection-string")}
787+
return []Option{WithConnectionPool(nil)}
859788
},
860789
//nolint:thelper // We want to see these lines in the test output
861790
validateFunc: func(t *testing.T, svc service.RegistryService, err error) {

internal/service/db/types.go

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package database
33
import (
44
"time"
55

6+
"github.com/aws/smithy-go/ptr"
67
"github.com/google/uuid"
78
upstreamv0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
89
model "github.com/modelcontextprotocol/registry/pkg/model"
@@ -113,29 +114,44 @@ func helperToServer(
113114
server := upstreamv0.ServerJSON{
114115
Schema: "https://static.modelcontextprotocol.io/schemas/2025-10-17/server.schema.json",
115116
Name: dbServer.Name,
116-
Description: *dbServer.Description,
117-
Title: *dbServer.Title,
118-
Repository: &model.Repository{
119-
URL: *dbServer.RepositoryUrl,
120-
Source: *dbServer.RepositoryType,
121-
ID: *dbServer.RepositoryID,
122-
Subfolder: *dbServer.RepositorySubfolder,
123-
},
124-
Version: dbServer.Version,
125-
WebsiteURL: *dbServer.Website,
126-
Packages: toPackages(packages),
127-
Remotes: toRemotes(remotes),
117+
Description: ptr.ToString(dbServer.Description),
118+
Title: ptr.ToString(dbServer.Title),
119+
Version: dbServer.Version,
120+
WebsiteURL: ptr.ToString(dbServer.Website),
121+
Packages: toPackages(packages),
122+
Remotes: toRemotes(remotes),
123+
}
124+
125+
if dbServer.RepositoryUrl != nil {
126+
server.Repository = &model.Repository{
127+
URL: ptr.ToString(dbServer.RepositoryUrl),
128+
Source: ptr.ToString(dbServer.RepositoryType),
129+
ID: ptr.ToString(dbServer.RepositoryID),
130+
Subfolder: ptr.ToString(dbServer.RepositorySubfolder),
131+
}
128132
}
129133

130134
server.Meta = &upstreamv0.ServerMeta{
131135
PublisherProvided: make(map[string]any),
132136
}
133-
server.Meta.PublisherProvided["upstream_meta"] = dbServer.UpstreamMeta
134-
server.Meta.PublisherProvided["server_meta"] = dbServer.ServerMeta
135-
server.Meta.PublisherProvided["repository_url"] = dbServer.RepositoryUrl
136-
server.Meta.PublisherProvided["repository_id"] = dbServer.RepositoryID
137-
server.Meta.PublisherProvided["repository_subfolder"] = dbServer.RepositorySubfolder
138-
server.Meta.PublisherProvided["repository_type"] = dbServer.RepositoryType
137+
if len(dbServer.UpstreamMeta) > 0 {
138+
server.Meta.PublisherProvided["upstream_meta"] = dbServer.UpstreamMeta
139+
}
140+
if len(dbServer.ServerMeta) > 0 {
141+
server.Meta.PublisherProvided["server_meta"] = dbServer.ServerMeta
142+
}
143+
if dbServer.RepositoryUrl != nil {
144+
server.Meta.PublisherProvided["repository_url"] = ptr.ToString(dbServer.RepositoryUrl)
145+
}
146+
if dbServer.RepositoryID != nil {
147+
server.Meta.PublisherProvided["repository_id"] = ptr.ToString(dbServer.RepositoryID)
148+
}
149+
if dbServer.RepositorySubfolder != nil {
150+
server.Meta.PublisherProvided["repository_subfolder"] = ptr.ToString(dbServer.RepositorySubfolder)
151+
}
152+
if dbServer.RepositoryType != nil {
153+
server.Meta.PublisherProvided["repository_type"] = ptr.ToString(dbServer.RepositoryType)
154+
}
139155

140156
return server
141157
}
@@ -150,11 +166,11 @@ func toPackages(
150166
RegistryBaseURL: dbPackage.PkgRegistryUrl,
151167
Identifier: dbPackage.PkgIdentifier,
152168
Version: dbPackage.PkgVersion,
153-
FileSHA256: *dbPackage.Sha256Hash,
154-
RunTimeHint: *dbPackage.RuntimeHint,
169+
FileSHA256: ptr.ToString(dbPackage.Sha256Hash),
170+
RunTimeHint: ptr.ToString(dbPackage.RuntimeHint),
155171
Transport: model.Transport{
156172
Type: dbPackage.Transport,
157-
URL: *dbPackage.TransportUrl,
173+
URL: ptr.ToString(dbPackage.TransportUrl),
158174
Headers: toKeyValueInputs(dbPackage.TransportHeaders),
159175
},
160176
RuntimeArguments: toArguments(dbPackage.RuntimeArguments),
@@ -183,9 +199,9 @@ func toKeyValueInputs(
183199
strings []string,
184200
) []model.KeyValueInput {
185201
result := make([]model.KeyValueInput, len(strings))
186-
for i, string := range strings {
202+
for i, str := range strings {
187203
result[i] = model.KeyValueInput{
188-
Name: string,
204+
Name: str,
189205
}
190206
}
191207
return result
@@ -195,9 +211,9 @@ func toArguments(
195211
strings []string,
196212
) []model.Argument {
197213
result := make([]model.Argument, len(strings))
198-
for i, string := range strings {
214+
for i, str := range strings {
199215
result[i] = model.Argument{
200-
Name: string,
216+
Name: str,
201217
}
202218
}
203219
return result

0 commit comments

Comments
 (0)