@@ -2,7 +2,10 @@ package s3
22
33import (
44 "context"
5+ "crypto/hmac"
6+ "crypto/sha256"
57 "errors"
8+ "fmt"
69 "sync"
710 "time"
811
@@ -17,18 +20,49 @@ const s3ExpressCacheCap = 100
1720
1821const s3ExpressRefreshWindow = 1 * time .Minute
1922
23+ type cacheKey struct {
24+ CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
25+ Bucket string
26+ }
27+
28+ func (c cacheKey ) Slug () string {
29+ return fmt .Sprintf ("%s%s" , c .CredentialsHash , c .Bucket )
30+ }
31+
32+ type sessionCredsCache struct {
33+ mu sync.Mutex
34+ cache cache.Cache
35+ }
36+
37+ func (c * sessionCredsCache ) Get (key cacheKey ) (* aws.Credentials , bool ) {
38+ c .mu .Lock ()
39+ defer c .mu .Unlock ()
40+
41+ if v , ok := c .cache .Get (key ); ok {
42+ return v .(* aws.Credentials ), true
43+ }
44+ return nil , false
45+ }
46+
47+ func (c * sessionCredsCache ) Put (key cacheKey , creds * aws.Credentials ) {
48+ c .mu .Lock ()
49+ defer c .mu .Unlock ()
50+
51+ c .cache .Put (key , creds )
52+ }
53+
2054// The default S3Express provider uses an LRU cache with a capacity of 100.
2155//
2256// Credentials will be refreshed asynchronously when a Retrieve() call is made
2357// for cached credentials within an expiry window (1 minute, currently
2458// non-configurable).
2559type defaultS3ExpressCredentialsProvider struct {
26- mu sync.Mutex
2760 sf singleflight.Group
2861
2962 client createSessionAPIClient
30- credsCache cache. Cache
63+ cache * sessionCredsCache
3164 refreshWindow time.Duration
65+ v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
3266}
3367
3468type createSessionAPIClient interface {
@@ -37,35 +71,54 @@ type createSessionAPIClient interface {
3771
3872func newDefaultS3ExpressCredentialsProvider () * defaultS3ExpressCredentialsProvider {
3973 return & defaultS3ExpressCredentialsProvider {
40- credsCache : lru .New (s3ExpressCacheCap ),
74+ cache : & sessionCredsCache {
75+ cache : lru .New (s3ExpressCacheCap ),
76+ },
4177 refreshWindow : s3ExpressRefreshWindow ,
4278 }
4379}
4480
81+ // returns a cloned provider using new base credentials, used when per-op
82+ // config mutations change the credentials provider
83+ func (p * defaultS3ExpressCredentialsProvider ) CloneWithBaseCredentials (v4creds aws.CredentialsProvider ) * defaultS3ExpressCredentialsProvider {
84+ return & defaultS3ExpressCredentialsProvider {
85+ client : p .client ,
86+ cache : p .cache ,
87+ refreshWindow : p .refreshWindow ,
88+ v4creds : v4creds ,
89+ }
90+ }
91+
4592func (p * defaultS3ExpressCredentialsProvider ) Retrieve (ctx context.Context , bucket string ) (aws.Credentials , error ) {
46- p .mu .Lock ()
47- defer p .mu .Unlock ()
93+ v4creds , err := p .v4creds .Retrieve (ctx )
94+ if err != nil {
95+ return aws.Credentials {}, fmt .Errorf ("get sigv4 creds: %w" , err )
96+ }
4897
49- creds , ok := p .getCacheCredentials (bucket )
98+ key := cacheKey {
99+ CredentialsHash : gethmac (v4creds .AccessKeyID , v4creds .SecretAccessKey ),
100+ Bucket : bucket ,
101+ }
102+ creds , ok := p .cache .Get (key )
50103 if ! ok || creds .Expired () {
51- return p .awaitDoChanRetrieve (ctx , bucket )
104+ return p .awaitDoChanRetrieve (ctx , key )
52105 }
53106
54107 if creds .Expires .Sub (sdk .NowTime ()) <= p .refreshWindow {
55- p .doChanRetrieve (ctx , bucket )
108+ p .doChanRetrieve (ctx , key )
56109 }
57110
58111 return * creds , nil
59112}
60113
61- func (p * defaultS3ExpressCredentialsProvider ) doChanRetrieve (ctx context.Context , bucket string ) <- chan singleflight.Result {
62- return p .sf .DoChan (bucket , func () (interface {}, error ) {
63- return p .retrieve (ctx , bucket )
114+ func (p * defaultS3ExpressCredentialsProvider ) doChanRetrieve (ctx context.Context , key cacheKey ) <- chan singleflight.Result {
115+ return p .sf .DoChan (key . Slug () , func () (interface {}, error ) {
116+ return p .retrieve (ctx , key )
64117 })
65118}
66119
67- func (p * defaultS3ExpressCredentialsProvider ) awaitDoChanRetrieve (ctx context.Context , bucket string ) (aws.Credentials , error ) {
68- ch := p .doChanRetrieve (ctx , bucket )
120+ func (p * defaultS3ExpressCredentialsProvider ) awaitDoChanRetrieve (ctx context.Context , key cacheKey ) (aws.Credentials , error ) {
121+ ch := p .doChanRetrieve (ctx , key )
69122
70123 select {
71124 case r := <- ch :
@@ -75,9 +128,9 @@ func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Co
75128 }
76129}
77130
78- func (p * defaultS3ExpressCredentialsProvider ) retrieve (ctx context.Context , bucket string ) (aws.Credentials , error ) {
131+ func (p * defaultS3ExpressCredentialsProvider ) retrieve (ctx context.Context , key cacheKey ) (aws.Credentials , error ) {
79132 resp , err := p .client .CreateSession (ctx , & CreateSessionInput {
80- Bucket : aws .String (bucket ),
133+ Bucket : aws .String (key . Bucket ),
81134 })
82135 if err != nil {
83136 return aws.Credentials {}, err
@@ -88,22 +141,10 @@ func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, buck
88141 return aws.Credentials {}, err
89142 }
90143
91- p .putCacheCredentials ( bucket , creds )
144+ p .cache . Put ( key , creds )
92145 return * creds , nil
93146}
94147
95- func (p * defaultS3ExpressCredentialsProvider ) getCacheCredentials (bucket string ) (* aws.Credentials , bool ) {
96- if v , ok := p .credsCache .Get (bucket ); ok {
97- return v .(* aws.Credentials ), true
98- }
99-
100- return nil , false
101- }
102-
103- func (p * defaultS3ExpressCredentialsProvider ) putCacheCredentials (bucket string , creds * aws.Credentials ) {
104- p .credsCache .Put (bucket , creds )
105- }
106-
107148func credentialsFromResponse (o * CreateSessionOutput ) (* aws.Credentials , error ) {
108149 if o .Credentials == nil {
109150 return nil , errors .New ("s3express session credentials unset" )
@@ -121,3 +162,9 @@ func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
121162 Expires : * o .Credentials .Expiration ,
122163 }, nil
123164}
165+
166+ func gethmac (p , key string ) string {
167+ hash := hmac .New (sha256 .New , []byte (key ))
168+ hash .Write ([]byte (p ))
169+ return string (hash .Sum (nil ))
170+ }
0 commit comments