Skip to content

Commit a398b44

Browse files
committed
Add db-backed implementation of RegistryService
This change adds a db-backed implementation of the `RegistryService` interface. There are multiple public interfaces added in this change, specifically * `internal/service` which exposes the `RegistryService` interface with an `Option` type used to configure filtering and pagination of SQL statements. * `internal/service/db` which exposes means to create a `RegistryService` backed by a database. Testing is implemented as integration tests running containers in the same fashion as the `internal/db/sqlc` package. This might be revised in the future by using mocks, but (a) a mock based implementation is not going to be more succinct, and (b) does its job of covering most of the new code. The change set is pretty big, but more than 65% of it are tests.
1 parent 486398b commit a398b44

File tree

5 files changed

+1355
-2
lines changed

5 files changed

+1355
-2
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ require (
8787
github.com/ianlancetaylor/demangle v0.0.0-20250417193237-f615e6bd150b // indirect
8888
github.com/jackc/pgpassfile v1.0.0 // indirect
8989
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
90+
github.com/jackc/puddle/v2 v2.2.2 // indirect
9091
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
9192
github.com/josharian/intern v1.0.0 // indirect
9293
github.com/json-iterator/go v1.1.12 // indirect

internal/api/extension/v0/routes.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ func Router(svc service.RegistryService) http.Handler {
3535
r.Put("/registries/{registryName}", routes.upsertRegistry)
3636
r.Delete("/registries/{registryName}", routes.deleteRegistry)
3737

38-
r.Put("/registries/{registryName}/servers/{serverName}/versions/{version}", routes.upsertVersion)
39-
r.Delete("/registries/{registryName}/servers/{serverName}/versions/{version}", routes.deleteVersion)
38+
r.Route("/registries/{registryName}/servers/{serverName}", func(r chi.Router) {
39+
r.Put("/versions/{version}", routes.upsertVersion)
40+
r.Delete("/versions/{version}", routes.deleteVersion)
41+
})
4042

4143
return r
4244
}

internal/service/db/impl.go

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
// Package database provides a database-backed implementation of the RegistryService interface
2+
package database
3+
4+
import (
5+
"context"
6+
"encoding/base64"
7+
"fmt"
8+
"time"
9+
10+
"github.com/google/uuid"
11+
"github.com/jackc/pgx/v5"
12+
"github.com/jackc/pgx/v5/pgxpool"
13+
upstreamv0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
14+
toolhivetypes "github.com/stacklok/toolhive/pkg/registry/registry"
15+
16+
"github.com/stacklok/toolhive-registry-server/internal/db/sqlc"
17+
"github.com/stacklok/toolhive-registry-server/internal/service"
18+
)
19+
20+
// options holds configuration options for the database service
21+
type options struct {
22+
pool *pgxpool.Pool
23+
}
24+
25+
// Option is a functional option for configuring the database service
26+
type Option func(*options) error
27+
28+
// WithConnectionPool creates a new database-backed registry service with the
29+
// given pgx pool.
30+
func WithConnectionPool(pool *pgxpool.Pool) Option {
31+
return func(o *options) error {
32+
if pool == nil {
33+
return fmt.Errorf("pgx pool is required")
34+
}
35+
o.pool = pool
36+
return nil
37+
}
38+
}
39+
40+
// WithConnectionString creates a new database-backed registry service with
41+
// the given connection string.
42+
func WithConnectionString(connString string) Option {
43+
return func(o *options) error {
44+
if connString == "" {
45+
return fmt.Errorf("connection string is required")
46+
}
47+
48+
pool, err := pgxpool.New(context.Background(), connString)
49+
if err != nil {
50+
return fmt.Errorf("failed to create pgx pool: %w", err)
51+
}
52+
53+
o.pool = pool
54+
return nil
55+
}
56+
}
57+
58+
// dbService implements the RegistryService interface using a database backend
59+
type dbService struct {
60+
pool *pgxpool.Pool
61+
}
62+
63+
var _ service.RegistryService = (*dbService)(nil)
64+
65+
// New creates a new database-backed registry service with the given options
66+
func New(opts ...Option) (service.RegistryService, error) {
67+
o := &options{}
68+
69+
for _, opt := range opts {
70+
if err := opt(o); err != nil {
71+
return nil, err
72+
}
73+
}
74+
75+
return &dbService{
76+
pool: o.pool,
77+
}, nil
78+
}
79+
80+
// CheckReadiness checks if the service is ready to serve requests
81+
func (s *dbService) CheckReadiness(ctx context.Context) error {
82+
err := s.pool.Ping(ctx)
83+
if err != nil {
84+
return fmt.Errorf("failed to ping database: %w", err)
85+
}
86+
return nil
87+
}
88+
89+
// GetRegistry returns the registry data with metadata
90+
func (*dbService) GetRegistry(
91+
_ context.Context,
92+
) (*toolhivetypes.UpstreamRegistry, string, error) {
93+
return nil, "", service.ErrNotImplemented
94+
}
95+
96+
// ListServers returns all servers in the registry
97+
func (s *dbService) ListServers(
98+
ctx context.Context,
99+
opts ...service.Option[service.ListServersOptions],
100+
) ([]*upstreamv0.ServerJSON, error) {
101+
// TODO: implement
102+
options := &service.ListServersOptions{}
103+
for _, opt := range opts {
104+
if err := opt(options); err != nil {
105+
return nil, err
106+
}
107+
}
108+
109+
decoded, err := base64.StdEncoding.DecodeString(options.Cursor)
110+
if err != nil {
111+
return nil, err
112+
}
113+
nextTime, err := time.Parse(time.RFC3339, string(decoded))
114+
if err != nil {
115+
return nil, err
116+
}
117+
118+
querierFunc := func(querier sqlc.Querier) ([]helper, error) {
119+
servers, err := querier.ListServers(
120+
ctx,
121+
sqlc.ListServersParams{
122+
Next: &nextTime,
123+
Size: int64(options.Limit),
124+
},
125+
)
126+
127+
if err != nil {
128+
return nil, err
129+
}
130+
131+
helpers := make([]helper, len(servers))
132+
for i, server := range servers {
133+
helpers[i] = listServersRowToHelper(server)
134+
}
135+
136+
return helpers, nil
137+
}
138+
139+
return s.sharedListServers(ctx, querierFunc)
140+
}
141+
142+
// ListServerVersions implements RegistryService.ListServerVersions
143+
func (s *dbService) ListServerVersions(
144+
ctx context.Context,
145+
opts ...service.Option[service.ListServerVersionsOptions],
146+
) ([]*upstreamv0.ServerJSON, error) {
147+
options := &service.ListServerVersionsOptions{}
148+
for _, opt := range opts {
149+
if err := opt(options); err != nil {
150+
return nil, err
151+
}
152+
}
153+
154+
querierFunc := func(querier sqlc.Querier) ([]helper, error) {
155+
servers, err := querier.ListServerVersions(
156+
ctx,
157+
sqlc.ListServerVersionsParams{
158+
Name: options.Name,
159+
Next: options.Next,
160+
Prev: options.Prev,
161+
Size: int64(options.Limit),
162+
},
163+
)
164+
if err != nil {
165+
return nil, err
166+
}
167+
168+
helpers := make([]helper, len(servers))
169+
for i, server := range servers {
170+
helpers[i] = listServerVersionsRowToHelper(server)
171+
}
172+
173+
return helpers, nil
174+
}
175+
176+
return s.sharedListServers(ctx, querierFunc)
177+
}
178+
179+
// GetServer returns a specific server by name
180+
func (s *dbService) GetServerVersion(
181+
ctx context.Context,
182+
opts ...service.Option[service.GetServerVersionOptions],
183+
) (*upstreamv0.ServerJSON, error) {
184+
options := &service.GetServerVersionOptions{}
185+
for _, opt := range opts {
186+
if err := opt(options); err != nil {
187+
return nil, err
188+
}
189+
}
190+
191+
querierFunc := func(querier sqlc.Querier) ([]helper, error) {
192+
server, err := querier.GetServerVersion(
193+
ctx,
194+
sqlc.GetServerVersionParams{
195+
Name: options.Name,
196+
Version: options.Version,
197+
},
198+
)
199+
if err != nil {
200+
return nil, err
201+
}
202+
203+
return []helper{getServerVersionRowToHelper(server)}, nil
204+
}
205+
206+
res, err := s.sharedListServers(ctx, querierFunc)
207+
if err != nil {
208+
return nil, err
209+
}
210+
211+
return res[0], nil
212+
}
213+
214+
func (s *dbService) sharedListServers(
215+
ctx context.Context,
216+
querierFunc func(querier sqlc.Querier) ([]helper, error),
217+
) ([]*upstreamv0.ServerJSON, error) {
218+
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{
219+
IsoLevel: pgx.ReadCommitted,
220+
AccessMode: pgx.ReadOnly,
221+
})
222+
if err != nil {
223+
return nil, err
224+
}
225+
defer tx.Rollback(ctx)
226+
227+
querier := sqlc.New(tx)
228+
229+
servers, err := querierFunc(querier)
230+
if err != nil {
231+
return nil, err
232+
}
233+
234+
ids := make([]uuid.UUID, len(servers))
235+
for i, server := range servers {
236+
ids[i] = server.ID
237+
}
238+
239+
packages, err := querier.ListServerPackages(ctx, ids)
240+
if err != nil {
241+
return nil, err
242+
}
243+
packagesMap := make(map[uuid.UUID][]sqlc.McpServerPackage)
244+
for _, pkg := range packages {
245+
packagesMap[pkg.ServerID] = append(packagesMap[pkg.ServerID], pkg)
246+
}
247+
248+
remotes, err := querier.ListServerRemotes(ctx, ids)
249+
if err != nil {
250+
return nil, err
251+
}
252+
remotesMap := make(map[uuid.UUID][]sqlc.McpServerRemote)
253+
for _, remote := range remotes {
254+
remotesMap[remote.ServerID] = append(remotesMap[remote.ServerID], remote)
255+
}
256+
257+
result := make([]*upstreamv0.ServerJSON, 0, len(servers))
258+
for _, dbServer := range servers {
259+
server := helperToServer(
260+
dbServer,
261+
packagesMap[dbServer.ID],
262+
remotesMap[dbServer.ID],
263+
)
264+
result = append(result, &server)
265+
}
266+
267+
return result, nil
268+
}

0 commit comments

Comments
 (0)