Skip to content

Commit 5e28c13

Browse files
blktdmjb
andauthored
Add db-backed implementation of RegistryService (#161)
This change adds a db-backed implementation of the `RegistryService` interface. There are multiple public interfaces added in this change, specifically * `internal/service` which exposes the `RegistryService` interface with an `Option` type used to configure filtering and pagination of SQL statements. * `internal/service/db` which exposes means to create a `RegistryService` backed by a database. Testing is implemented as integration tests running containers in the same fashion as the `internal/db/sqlc` package. This might be revised in the future by using mocks, but (a) a mock based implementation is not going to be more succinct, and (b) does its job of covering most of the new code. Co-authored-by: Don Browne <[email protected]>
1 parent adca57b commit 5e28c13

File tree

5 files changed

+1333
-2
lines changed

5 files changed

+1333
-2
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ require (
8787
github.com/ianlancetaylor/demangle v0.0.0-20250417193237-f615e6bd150b // indirect
8888
github.com/jackc/pgpassfile v1.0.0 // indirect
8989
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
90+
github.com/jackc/puddle/v2 v2.2.2 // indirect
9091
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
9192
github.com/josharian/intern v1.0.0 // indirect
9293
github.com/json-iterator/go v1.1.12 // indirect

internal/api/extension/v0/routes.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ func Router(svc service.RegistryService) http.Handler {
3535
r.Put("/registries/{registryName}", routes.upsertRegistry)
3636
r.Delete("/registries/{registryName}", routes.deleteRegistry)
3737

38-
r.Put("/registries/{registryName}/servers/{serverName}/versions/{version}", routes.upsertVersion)
39-
r.Delete("/registries/{registryName}/servers/{serverName}/versions/{version}", routes.deleteVersion)
38+
r.Route("/registries/{registryName}/servers/{serverName}", func(r chi.Router) {
39+
r.Put("/versions/{version}", routes.upsertVersion)
40+
r.Delete("/versions/{version}", routes.deleteVersion)
41+
})
4042

4143
return r
4244
}

internal/service/db/impl.go

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
// Package database provides a database-backed implementation of the RegistryService interface
2+
package database
3+
4+
import (
5+
"context"
6+
"encoding/base64"
7+
"errors"
8+
"fmt"
9+
"time"
10+
11+
"github.com/google/uuid"
12+
"github.com/jackc/pgx/v5"
13+
"github.com/jackc/pgx/v5/pgxpool"
14+
upstreamv0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
15+
toolhivetypes "github.com/stacklok/toolhive/pkg/registry/registry"
16+
17+
"github.com/stacklok/toolhive-registry-server/internal/db/sqlc"
18+
"github.com/stacklok/toolhive-registry-server/internal/service"
19+
)
20+
21+
var (
22+
// ErrBug is returned when a server is not found
23+
ErrBug = errors.New("bug")
24+
)
25+
26+
// options holds configuration options for the database service
27+
type options struct {
28+
pool *pgxpool.Pool
29+
}
30+
31+
// Option is a functional option for configuring the database service
32+
type Option func(*options) error
33+
34+
// WithConnectionPool creates a new database-backed registry service with the
35+
// given pgx pool. The caller is responsible for closing the pool when it is
36+
// done.
37+
func WithConnectionPool(pool *pgxpool.Pool) Option {
38+
return func(o *options) error {
39+
if pool == nil {
40+
return fmt.Errorf("pgx pool is required")
41+
}
42+
o.pool = pool
43+
return nil
44+
}
45+
}
46+
47+
// dbService implements the RegistryService interface using a database backend
48+
type dbService struct {
49+
pool *pgxpool.Pool
50+
}
51+
52+
var _ service.RegistryService = (*dbService)(nil)
53+
54+
// New creates a new database-backed registry service with the given options
55+
func New(opts ...Option) (service.RegistryService, error) {
56+
o := &options{}
57+
58+
for _, opt := range opts {
59+
if err := opt(o); err != nil {
60+
return nil, err
61+
}
62+
}
63+
64+
return &dbService{
65+
pool: o.pool,
66+
}, nil
67+
}
68+
69+
// CheckReadiness checks if the service is ready to serve requests
70+
func (s *dbService) CheckReadiness(ctx context.Context) error {
71+
err := s.pool.Ping(ctx)
72+
if err != nil {
73+
return fmt.Errorf("failed to ping database: %w", err)
74+
}
75+
return nil
76+
}
77+
78+
// GetRegistry returns the registry data with metadata
79+
func (*dbService) GetRegistry(
80+
_ context.Context,
81+
) (*toolhivetypes.UpstreamRegistry, string, error) {
82+
return nil, "", service.ErrNotImplemented
83+
}
84+
85+
// ListServers returns all servers in the registry
86+
func (s *dbService) ListServers(
87+
ctx context.Context,
88+
opts ...service.Option[service.ListServersOptions],
89+
) ([]*upstreamv0.ServerJSON, error) {
90+
options := &service.ListServersOptions{}
91+
for _, opt := range opts {
92+
if err := opt(options); err != nil {
93+
return nil, err
94+
}
95+
}
96+
97+
decoded, err := base64.StdEncoding.DecodeString(options.Cursor)
98+
if err != nil {
99+
return nil, fmt.Errorf("invalid cursor format: %w", err)
100+
}
101+
nextTime, err := time.Parse(time.RFC3339, string(decoded))
102+
if err != nil {
103+
return nil, fmt.Errorf("invalid cursor format: %w", err)
104+
}
105+
106+
// Note: this function fetches a list of servers. In case no records are
107+
// found, the called function should return an empty slice as it's
108+
// customary in Go.
109+
querierFunc := func(ctx context.Context, querier sqlc.Querier) ([]helper, error) {
110+
servers, err := querier.ListServers(
111+
ctx,
112+
sqlc.ListServersParams{
113+
Next: &nextTime,
114+
Size: int64(options.Limit),
115+
},
116+
)
117+
if err != nil {
118+
return nil, err
119+
}
120+
121+
helpers := make([]helper, len(servers))
122+
for i, server := range servers {
123+
helpers[i] = listServersRowToHelper(server)
124+
}
125+
126+
return helpers, nil
127+
}
128+
129+
return s.sharedListServers(ctx, querierFunc)
130+
}
131+
132+
// ListServerVersions implements RegistryService.ListServerVersions
133+
func (s *dbService) ListServerVersions(
134+
ctx context.Context,
135+
opts ...service.Option[service.ListServerVersionsOptions],
136+
) ([]*upstreamv0.ServerJSON, error) {
137+
options := &service.ListServerVersionsOptions{}
138+
for _, opt := range opts {
139+
if err := opt(options); err != nil {
140+
return nil, err
141+
}
142+
}
143+
144+
// Note: this function fetches a list of server versions. In case no records are
145+
// found, the called function should return an empty slice as it's
146+
// customary in Go.
147+
querierFunc := func(ctx context.Context, querier sqlc.Querier) ([]helper, error) {
148+
servers, err := querier.ListServerVersions(
149+
ctx,
150+
sqlc.ListServerVersionsParams{
151+
Name: options.Name,
152+
Next: options.Next,
153+
Prev: options.Prev,
154+
Size: int64(options.Limit),
155+
},
156+
)
157+
if err != nil {
158+
return nil, err
159+
}
160+
161+
helpers := make([]helper, len(servers))
162+
for i, server := range servers {
163+
helpers[i] = listServerVersionsRowToHelper(server)
164+
}
165+
166+
return helpers, nil
167+
}
168+
169+
return s.sharedListServers(ctx, querierFunc)
170+
}
171+
172+
// GetServer returns a specific server by name
173+
func (s *dbService) GetServerVersion(
174+
ctx context.Context,
175+
opts ...service.Option[service.GetServerVersionOptions],
176+
) (*upstreamv0.ServerJSON, error) {
177+
options := &service.GetServerVersionOptions{}
178+
for _, opt := range opts {
179+
if err := opt(options); err != nil {
180+
return nil, err
181+
}
182+
}
183+
184+
// Note: this function fetches a single record given name and version.
185+
// In case no record is found, the called function should return an
186+
// `sql.ErrNoRows` error as it's customary in Go.
187+
querierFunc := func(ctx context.Context, querier sqlc.Querier) ([]helper, error) {
188+
server, err := querier.GetServerVersion(
189+
ctx,
190+
sqlc.GetServerVersionParams{
191+
Name: options.Name,
192+
Version: options.Version,
193+
},
194+
)
195+
if err != nil {
196+
return nil, err
197+
}
198+
199+
return []helper{getServerVersionRowToHelper(server)}, nil
200+
}
201+
202+
res, err := s.sharedListServers(ctx, querierFunc)
203+
if err != nil {
204+
return nil, err
205+
}
206+
207+
// Note: the `queryFunc` function is expected to return an error
208+
// sooner if no records are found, so getting this far with
209+
// a length result slice other than 1 means there's a bug.
210+
if len(res) != 1 {
211+
return nil, fmt.Errorf("%w: number of servers returned is not 1", ErrBug)
212+
}
213+
214+
return res[0], nil
215+
}
216+
217+
// querierFunction is a function that uses the given querier object to run the
218+
// main extraction. As of the time of this writing, its main use is accessing
219+
// the `mcp_server` table in a type-agnostic way. This is to overcome a
220+
// limitation of sqlc that prevents us from having the exact same go type
221+
// despite the fact that the underlying columns returned are the same.
222+
//
223+
// Note that the underlying table does not have to be the `mcp_server` table
224+
// as it is used now, as long as the result is a slice of helpers.
225+
type querierFunction func(ctx context.Context, querier sqlc.Querier) ([]helper, error)
226+
227+
// sharedListServers is a helper function to list servers and mapping them to
228+
// the API schema.
229+
//
230+
// Its responsibilities are:
231+
// * Begin a transaction
232+
// * Execute the querier function
233+
// * List packages and remotes using the server IDs
234+
// * Map the results to the API schema
235+
// * Return the results
236+
//
237+
// The argument `querierFunc` is a function that uses the given querier object
238+
// to run the main extraction. Note that the underlying table does not have
239+
// to be the `mcp_server` table as it is used now, as long as the result is a
240+
// slice of helpers.
241+
func (s *dbService) sharedListServers(
242+
ctx context.Context,
243+
querierFunc querierFunction,
244+
) ([]*upstreamv0.ServerJSON, error) {
245+
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{
246+
IsoLevel: pgx.ReadCommitted,
247+
AccessMode: pgx.ReadOnly,
248+
})
249+
if err != nil {
250+
return nil, fmt.Errorf("failed to begin transaction: %w", err)
251+
}
252+
defer func() {
253+
err := tx.Rollback(ctx)
254+
if err != nil && !errors.Is(err, pgx.ErrTxClosed) {
255+
// TODO: log the rollback error (add proper logging)
256+
_ = err
257+
}
258+
}()
259+
260+
querier := sqlc.New(tx)
261+
262+
servers, err := querierFunc(ctx, querier)
263+
if err != nil {
264+
return nil, err
265+
}
266+
267+
ids := make([]uuid.UUID, len(servers))
268+
for i, server := range servers {
269+
ids[i] = server.ID
270+
}
271+
272+
packages, err := querier.ListServerPackages(ctx, ids)
273+
if err != nil {
274+
return nil, err
275+
}
276+
packagesMap := make(map[uuid.UUID][]sqlc.McpServerPackage)
277+
for _, pkg := range packages {
278+
packagesMap[pkg.ServerID] = append(packagesMap[pkg.ServerID], pkg)
279+
}
280+
281+
remotes, err := querier.ListServerRemotes(ctx, ids)
282+
if err != nil {
283+
return nil, err
284+
}
285+
remotesMap := make(map[uuid.UUID][]sqlc.McpServerRemote)
286+
for _, remote := range remotes {
287+
remotesMap[remote.ServerID] = append(remotesMap[remote.ServerID], remote)
288+
}
289+
290+
result := make([]*upstreamv0.ServerJSON, 0, len(servers))
291+
for _, dbServer := range servers {
292+
server := helperToServer(
293+
dbServer,
294+
packagesMap[dbServer.ID],
295+
remotesMap[dbServer.ID],
296+
)
297+
result = append(result, &server)
298+
}
299+
300+
return result, nil
301+
}

0 commit comments

Comments
 (0)