@@ -4,6 +4,7 @@ package database
44import (
55 "context"
66 "encoding/base64"
7+ "errors"
78 "fmt"
89 "time"
910
@@ -17,6 +18,11 @@ import (
1718 "github.com/stacklok/toolhive-registry-server/internal/service"
1819)
1920
21+ var (
22+ // ErrBug is returned when a server is not found
23+ ErrBug = errors .New ("bug" )
24+ )
25+
2026// options holds configuration options for the database service
2127type options struct {
2228 pool * pgxpool.Pool
@@ -37,24 +43,6 @@ func WithConnectionPool(pool *pgxpool.Pool) Option {
3743 }
3844}
3945
40- // WithConnectionString creates a new database-backed registry service with
41- // the given connection string.
42- func WithConnectionString (connString string ) Option {
43- return func (o * options ) error {
44- if connString == "" {
45- return fmt .Errorf ("connection string is required" )
46- }
47-
48- pool , err := pgxpool .New (context .Background (), connString )
49- if err != nil {
50- return fmt .Errorf ("failed to create pgx pool: %w" , err )
51- }
52-
53- o .pool = pool
54- return nil
55- }
56- }
57-
5846// dbService implements the RegistryService interface using a database backend
5947type dbService struct {
6048 pool * pgxpool.Pool
@@ -98,7 +86,6 @@ func (s *dbService) ListServers(
9886 ctx context.Context ,
9987 opts ... service.Option [service.ListServersOptions ],
10088) ([]* upstreamv0.ServerJSON , error ) {
101- // TODO: implement
10289 options := & service.ListServersOptions {}
10390 for _ , opt := range opts {
10491 if err := opt (options ); err != nil {
@@ -108,22 +95,24 @@ func (s *dbService) ListServers(
10895
10996 decoded , err := base64 .StdEncoding .DecodeString (options .Cursor )
11097 if err != nil {
111- return nil , err
98+ return nil , fmt . Errorf ( "invalid cursor format: %w" , err )
11299 }
113100 nextTime , err := time .Parse (time .RFC3339 , string (decoded ))
114101 if err != nil {
115- return nil , err
102+ return nil , fmt . Errorf ( "invalid cursor format: %w" , err )
116103 }
117104
118- querierFunc := func (querier sqlc.Querier ) ([]helper , error ) {
105+ // Note: this function fetches a list of servers. In case no records are
106+ // found, the called function should return an empty slice as it's
107+ // customary in Go.
108+ querierFunc := func (ctx context.Context , querier sqlc.Querier ) ([]helper , error ) {
119109 servers , err := querier .ListServers (
120110 ctx ,
121111 sqlc.ListServersParams {
122112 Next : & nextTime ,
123113 Size : int64 (options .Limit ),
124114 },
125115 )
126-
127116 if err != nil {
128117 return nil , err
129118 }
@@ -151,7 +140,10 @@ func (s *dbService) ListServerVersions(
151140 }
152141 }
153142
154- querierFunc := func (querier sqlc.Querier ) ([]helper , error ) {
143+ // Note: this function fetches a list of server versions. In case no records are
144+ // found, the called function should return an empty slice as it's
145+ // customary in Go.
146+ querierFunc := func (ctx context.Context , querier sqlc.Querier ) ([]helper , error ) {
155147 servers , err := querier .ListServerVersions (
156148 ctx ,
157149 sqlc.ListServerVersionsParams {
@@ -188,7 +180,10 @@ func (s *dbService) GetServerVersion(
188180 }
189181 }
190182
191- querierFunc := func (querier sqlc.Querier ) ([]helper , error ) {
183+ // Note: this function fetches a single record given name and version.
184+ // In case no record is found, the called function should return an
185+ // `sql.ErrNoRows` error as it's customary in Go.
186+ querierFunc := func (ctx context.Context , querier sqlc.Querier ) ([]helper , error ) {
192187 server , err := querier .GetServerVersion (
193188 ctx ,
194189 sqlc.GetServerVersionParams {
@@ -208,25 +203,62 @@ func (s *dbService) GetServerVersion(
208203 return nil , err
209204 }
210205
206+ // Note: the `queryFunc` function is expected to return an error
207+ // sooner if no records are found, so getting this far with
208+ // a length result slice other than 1 means there's a bug.
209+ if len (res ) != 1 {
210+ return nil , fmt .Errorf ("%w: number of servers returned is not 1" , ErrBug )
211+ }
212+
211213 return res [0 ], nil
212214}
213215
216+ // querierFunction is a function that uses the given querier object to run the
217+ // main extraction. As of the time of this writing, its main use is accessing
218+ // the `mcp_server` table in a type-agnostic way. This is to overcome a
219+ // limitation of sqlc that prevents us from having the exact same go type
220+ // despite the fact that the underlying columns returned are the same.
221+ //
222+ // Note that the underlying table does not have to be the `mcp_server` table
223+ // as it is used now, as long as the result is a slice of helpers.
224+ type querierFunction func (ctx context.Context , querier sqlc.Querier ) ([]helper , error )
225+
226+ // sharedListServers is a helper function to list servers and mapping them to
227+ // the API schema.
228+ //
229+ // Its responsibilities are:
230+ // * Begin a transaction
231+ // * Execute the querier function
232+ // * List packages and remotes using the server IDs
233+ // * Map the results to the API schema
234+ // * Return the results
235+ //
236+ // The argument `querierFunc` is a function that uses the given querier object
237+ // to run the main extraction. Note that the underlying table does not have
238+ // to be the `mcp_server` table as it is used now, as long as the result is a
239+ // slice of helpers.
214240func (s * dbService ) sharedListServers (
215241 ctx context.Context ,
216- querierFunc func ( querier sqlc. Querier ) ([] helper , error ) ,
242+ querierFunc querierFunction ,
217243) ([]* upstreamv0.ServerJSON , error ) {
218244 tx , err := s .pool .BeginTx (ctx , pgx.TxOptions {
219245 IsoLevel : pgx .ReadCommitted ,
220246 AccessMode : pgx .ReadOnly ,
221247 })
222248 if err != nil {
223- return nil , err
249+ return nil , fmt . Errorf ( "failed to begin transaction: %w" , err )
224250 }
225- defer tx .Rollback (ctx )
251+ defer func () {
252+ err := tx .Rollback (ctx )
253+ if err != nil && ! errors .Is (err , pgx .ErrTxClosed ) {
254+ // TODO: log the rollback error (add proper logging)
255+ _ = err
256+ }
257+ }()
226258
227259 querier := sqlc .New (tx )
228260
229- servers , err := querierFunc (querier )
261+ servers , err := querierFunc (ctx , querier )
230262 if err != nil {
231263 return nil , err
232264 }
0 commit comments