@@ -2,6 +2,7 @@ package clients
2
2
3
3
import (
4
4
"bytes"
5
+ "context"
5
6
"encoding/base64"
6
7
"encoding/json"
7
8
"errors"
@@ -11,30 +12,33 @@ import (
11
12
"strings"
12
13
"sync"
13
14
15
+ "github.com/Azure/azure-sdk-for-go/sdk/azcore"
14
16
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
15
17
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
16
18
)
17
19
20
+ type ObjectIDProvider func (ctx context.Context ) (string , error )
21
+
18
22
type ResourceManagerAccount struct {
19
- tenantId * string
20
- subscriptionId * string
21
- objectId * string
22
- mutex * sync.Mutex
23
- client * Client
23
+ tenantId * string
24
+ subscriptionId * string
25
+ objectId * string
26
+ mutex * sync.Mutex
27
+ objectIDProvider ObjectIDProvider
24
28
}
25
29
26
- func NewResourceManagerAccount (client * Client ) ResourceManagerAccount {
30
+ func NewResourceManagerAccount (tenantId , subscriptionId string , provider ObjectIDProvider ) ResourceManagerAccount {
27
31
out := ResourceManagerAccount {
28
32
mutex : & sync.Mutex {},
29
33
}
30
- if client != nil && client . Account . tenantId != nil && * client . Account . tenantId != "" {
31
- out .tenantId = client . Account . tenantId
34
+ if tenantId != "" {
35
+ out .tenantId = & tenantId
32
36
}
33
- if client != nil && client . Account . subscriptionId != nil && * client . Account . subscriptionId != "" {
34
- out .subscriptionId = client . Account . subscriptionId
37
+ if subscriptionId != "" {
38
+ out .subscriptionId = & subscriptionId
35
39
}
36
40
// We lazy load object ID because it's not always needed and could cause a performance hit
37
- out .client = client
41
+ out .objectIDProvider = provider
38
42
return out
39
43
}
40
44
@@ -80,36 +84,29 @@ func (account *ResourceManagerAccount) GetSubscriptionId() string {
80
84
return * account .subscriptionId
81
85
}
82
86
83
- func (account * ResourceManagerAccount ) GetObjectId () string {
87
+ func (account * ResourceManagerAccount ) GetObjectId (ctx context. Context ) string {
84
88
account .mutex .Lock ()
85
89
defer account .mutex .Unlock ()
86
90
87
91
if account .objectId != nil {
88
92
return * account .objectId
89
93
}
90
94
91
- tok , err := account .client .Option .Cred .GetToken (account .client .StopContext , policy.TokenRequestOptions {
92
- TenantID : account .client .Option .TenantId ,
93
- Scopes : []string {account .client .Option .CloudCfg .Services [cloud .ResourceManager ].Endpoint + "/.default" }})
94
- if err != nil {
95
- log .Printf ("[DEBUG] Error getting requesting token from credentials: %s" , err )
96
- }
97
-
98
- if tok .Token == "" {
99
- err = account .loadSignedInUserFromAzCmd ()
100
- if err != nil {
101
- log .Printf ("[DEBUG] Error getting user object ID from az cli: %s" , err )
102
- }
103
- } else {
104
- cl , err := parseTokenClaims (tok .Token )
95
+ if account .objectIDProvider != nil {
96
+ objectId , err := account .objectIDProvider (ctx )
105
97
if err != nil {
106
- log .Printf ("[DEBUG] Error getting object id from token : %s" , err )
98
+ log .Printf ("[DEBUG] Error getting object ID : %s" , err )
107
99
}
108
- if cl != nil && cl .ObjectId != "" {
109
- account .objectId = & cl .ObjectId
100
+ if objectId != "" {
101
+ account .objectId = & objectId
102
+ return * account .objectId
110
103
}
111
104
}
112
105
106
+ err := account .loadSignedInUserFromAzCmd ()
107
+ if err != nil {
108
+ log .Printf ("[DEBUG] Error getting user object ID from az cli: %s" , err )
109
+ }
113
110
if account .objectId == nil {
114
111
log .Printf ("[DEBUG] No object ID found" )
115
112
return ""
@@ -215,3 +212,25 @@ type tokenClaims struct {
215
212
AppId string `json:"appid,omitempty"`
216
213
IdType string `json:"idtyp,omitempty"`
217
214
}
215
+
216
+ func ParsedTokenClaimsObjectIDProvider (cred azcore.TokenCredential , cloudCfg cloud.Configuration ) ObjectIDProvider {
217
+ return func (ctx context.Context ) (string , error ) {
218
+ tok , err := cred .GetToken (context .Background (), policy.TokenRequestOptions {
219
+ EnableCAE : true ,
220
+ Scopes : []string {cloudCfg .Services [cloud .ResourceManager ].Audience + "/.default" }})
221
+ if err != nil {
222
+ return "" , fmt .Errorf ("getting requesting token from credentials: %w" , err )
223
+ }
224
+ if tok .Token == "" {
225
+ return "" , errors .New ("token is empty" )
226
+ }
227
+ cl , err := parseTokenClaims (tok .Token )
228
+ if err != nil {
229
+ return "" , fmt .Errorf ("getting object id from token: %w" , err )
230
+ }
231
+ if cl == nil || cl .ObjectId == "" {
232
+ return "" , errors .New ("object id is empty" )
233
+ }
234
+ return cl .ObjectId , nil
235
+ }
236
+ }
0 commit comments