@@ -4,6 +4,7 @@ package database
44import (
55 "context"
66 "encoding/base64"
7+ "errors"
78 "fmt"
89 "time"
910
@@ -17,6 +18,11 @@ import (
1718 "github.com/stacklok/toolhive-registry-server/internal/service"
1819)
1920
21+ var (
22+ // ErrBug is returned when a server is not found
23+ ErrBug = errors .New ("bug" )
24+ )
25+
2026// options holds configuration options for the database service
2127type options struct {
2228 pool * pgxpool.Pool
@@ -26,7 +32,8 @@ type options struct {
2632type Option func (* options ) error
2733
2834// WithConnectionPool creates a new database-backed registry service with the
29- // given pgx pool.
35+ // given pgx pool. The caller is responsible for closing the pool when it is
36+ // done.
3037func WithConnectionPool (pool * pgxpool.Pool ) Option {
3138 return func (o * options ) error {
3239 if pool == nil {
@@ -37,24 +44,6 @@ func WithConnectionPool(pool *pgxpool.Pool) Option {
3744 }
3845}
3946
40- // WithConnectionString creates a new database-backed registry service with
41- // the given connection string.
42- func WithConnectionString (connString string ) Option {
43- return func (o * options ) error {
44- if connString == "" {
45- return fmt .Errorf ("connection string is required" )
46- }
47-
48- pool , err := pgxpool .New (context .Background (), connString )
49- if err != nil {
50- return fmt .Errorf ("failed to create pgx pool: %w" , err )
51- }
52-
53- o .pool = pool
54- return nil
55- }
56- }
57-
5847// dbService implements the RegistryService interface using a database backend
5948type dbService struct {
6049 pool * pgxpool.Pool
@@ -98,7 +87,6 @@ func (s *dbService) ListServers(
9887 ctx context.Context ,
9988 opts ... service.Option [service.ListServersOptions ],
10089) ([]* upstreamv0.ServerJSON , error ) {
101- // TODO: implement
10290 options := & service.ListServersOptions {}
10391 for _ , opt := range opts {
10492 if err := opt (options ); err != nil {
@@ -108,22 +96,24 @@ func (s *dbService) ListServers(
10896
10997 decoded , err := base64 .StdEncoding .DecodeString (options .Cursor )
11098 if err != nil {
111- return nil , err
99+ return nil , fmt . Errorf ( "invalid cursor format: %w" , err )
112100 }
113101 nextTime , err := time .Parse (time .RFC3339 , string (decoded ))
114102 if err != nil {
115- return nil , err
103+ return nil , fmt . Errorf ( "invalid cursor format: %w" , err )
116104 }
117105
118- querierFunc := func (querier sqlc.Querier ) ([]helper , error ) {
106+ // Note: this function fetches a list of servers. In case no records are
107+ // found, the called function should return an empty slice as it's
108+ // customary in Go.
109+ querierFunc := func (ctx context.Context , querier sqlc.Querier ) ([]helper , error ) {
119110 servers , err := querier .ListServers (
120111 ctx ,
121112 sqlc.ListServersParams {
122113 Next : & nextTime ,
123114 Size : int64 (options .Limit ),
124115 },
125116 )
126-
127117 if err != nil {
128118 return nil , err
129119 }
@@ -151,7 +141,10 @@ func (s *dbService) ListServerVersions(
151141 }
152142 }
153143
154- querierFunc := func (querier sqlc.Querier ) ([]helper , error ) {
144+ // Note: this function fetches a list of server versions. In case no records are
145+ // found, the called function should return an empty slice as it's
146+ // customary in Go.
147+ querierFunc := func (ctx context.Context , querier sqlc.Querier ) ([]helper , error ) {
155148 servers , err := querier .ListServerVersions (
156149 ctx ,
157150 sqlc.ListServerVersionsParams {
@@ -188,7 +181,10 @@ func (s *dbService) GetServerVersion(
188181 }
189182 }
190183
191- querierFunc := func (querier sqlc.Querier ) ([]helper , error ) {
184+ // Note: this function fetches a single record given name and version.
185+ // In case no record is found, the called function should return an
186+ // `sql.ErrNoRows` error as it's customary in Go.
187+ querierFunc := func (ctx context.Context , querier sqlc.Querier ) ([]helper , error ) {
192188 server , err := querier .GetServerVersion (
193189 ctx ,
194190 sqlc.GetServerVersionParams {
@@ -208,25 +204,62 @@ func (s *dbService) GetServerVersion(
208204 return nil , err
209205 }
210206
207+ // Note: the `queryFunc` function is expected to return an error
208+ // sooner if no records are found, so getting this far with
209+ // a length result slice other than 1 means there's a bug.
210+ if len (res ) != 1 {
211+ return nil , fmt .Errorf ("%w: number of servers returned is not 1" , ErrBug )
212+ }
213+
211214 return res [0 ], nil
212215}
213216
217+ // querierFunction is a function that uses the given querier object to run the
218+ // main extraction. As of the time of this writing, its main use is accessing
219+ // the `mcp_server` table in a type-agnostic way. This is to overcome a
220+ // limitation of sqlc that prevents us from having the exact same go type
221+ // despite the fact that the underlying columns returned are the same.
222+ //
223+ // Note that the underlying table does not have to be the `mcp_server` table
224+ // as it is used now, as long as the result is a slice of helpers.
225+ type querierFunction func (ctx context.Context , querier sqlc.Querier ) ([]helper , error )
226+
227+ // sharedListServers is a helper function to list servers and mapping them to
228+ // the API schema.
229+ //
230+ // Its responsibilities are:
231+ // * Begin a transaction
232+ // * Execute the querier function
233+ // * List packages and remotes using the server IDs
234+ // * Map the results to the API schema
235+ // * Return the results
236+ //
237+ // The argument `querierFunc` is a function that uses the given querier object
238+ // to run the main extraction. Note that the underlying table does not have
239+ // to be the `mcp_server` table as it is used now, as long as the result is a
240+ // slice of helpers.
214241func (s * dbService ) sharedListServers (
215242 ctx context.Context ,
216- querierFunc func ( querier sqlc. Querier ) ([] helper , error ) ,
243+ querierFunc querierFunction ,
217244) ([]* upstreamv0.ServerJSON , error ) {
218245 tx , err := s .pool .BeginTx (ctx , pgx.TxOptions {
219246 IsoLevel : pgx .ReadCommitted ,
220247 AccessMode : pgx .ReadOnly ,
221248 })
222249 if err != nil {
223- return nil , err
250+ return nil , fmt . Errorf ( "failed to begin transaction: %w" , err )
224251 }
225- defer tx .Rollback (ctx )
252+ defer func () {
253+ err := tx .Rollback (ctx )
254+ if err != nil && ! errors .Is (err , pgx .ErrTxClosed ) {
255+ // TODO: log the rollback error (add proper logging)
256+ _ = err
257+ }
258+ }()
226259
227260 querier := sqlc .New (tx )
228261
229- servers , err := querierFunc (querier )
262+ servers , err := querierFunc (ctx , querier )
230263 if err != nil {
231264 return nil , err
232265 }
0 commit comments