31
31
},
32
32
{EnvironmentVariables .IDENTITY_ENDPOINT : "..." , EnvironmentVariables .IMDS_ENDPOINT : "..." }, # Arc
33
33
{ # token exchange
34
+ EnvironmentVariables .AZURE_AUTHORITY_HOST : "https://localhost" ,
34
35
EnvironmentVariables .AZURE_CLIENT_ID : "..." ,
35
36
EnvironmentVariables .AZURE_TENANT_ID : "..." ,
36
37
EnvironmentVariables .AZURE_FEDERATED_TOKEN_FILE : __file__ ,
@@ -73,24 +74,6 @@ def test_context_manager_incomplete_configuration():
73
74
pass
74
75
75
76
76
- ALL_ENVIRONMENTS = (
77
- {EnvironmentVariables .MSI_ENDPOINT : "..." , EnvironmentVariables .MSI_SECRET : "..." }, # App Service
78
- {EnvironmentVariables .MSI_ENDPOINT : "..." }, # Cloud Shell
79
- { # Service Fabric
80
- EnvironmentVariables .IDENTITY_ENDPOINT : "..." ,
81
- EnvironmentVariables .IDENTITY_HEADER : "..." ,
82
- EnvironmentVariables .IDENTITY_SERVER_THUMBPRINT : "..." ,
83
- },
84
- {EnvironmentVariables .IDENTITY_ENDPOINT : "..." , EnvironmentVariables .IMDS_ENDPOINT : "..." }, # Arc
85
- { # token exchange
86
- EnvironmentVariables .AZURE_CLIENT_ID : "..." ,
87
- EnvironmentVariables .AZURE_TENANT_ID : "..." ,
88
- EnvironmentVariables .AZURE_FEDERATED_TOKEN_FILE : __file__ ,
89
- },
90
- {}, # IMDS
91
- )
92
-
93
-
94
77
@pytest .mark .parametrize ("environ" , ALL_ENVIRONMENTS )
95
78
def test_custom_hooks (environ ):
96
79
"""The credential's pipeline should include azure-core's CustomHookPolicy"""
@@ -790,10 +773,21 @@ def test_token_exchange(tmpdir):
790
773
token_file .write (exchange_token )
791
774
access_token = "***"
792
775
authority = "https://localhost"
793
- client_id = "client_id "
776
+ default_client_id = "default_client_id "
794
777
tenant = "tenant_id"
795
778
scope = "scope"
796
779
780
+ success_response = mock_response (
781
+ json_payload = {
782
+ "access_token" : access_token ,
783
+ "expires_in" : 3600 ,
784
+ "ext_expires_in" : 3600 ,
785
+ "expires_on" : int (time .time ()) + 3600 ,
786
+ "not_before" : int (time .time ()),
787
+ "resource" : scope ,
788
+ "token_type" : "Bearer" ,
789
+ }
790
+ )
797
791
transport = validating_transport (
798
792
requests = [
799
793
Request (
@@ -802,38 +796,81 @@ def test_token_exchange(tmpdir):
802
796
required_data = {
803
797
"client_assertion" : exchange_token ,
804
798
"client_assertion_type" : "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ,
805
- "client_id" : client_id ,
799
+ "client_id" : default_client_id ,
806
800
"grant_type" : "client_credentials" ,
807
801
"scope" : scope ,
808
802
},
809
803
)
810
804
],
811
- responses = [
812
- mock_response (
813
- json_payload = {
814
- "access_token" : access_token ,
815
- "expires_in" : 3600 ,
816
- "ext_expires_in" : 3600 ,
817
- "expires_on" : int (time .time ()) + 3600 ,
818
- "not_before" : int (time .time ()),
819
- "resource" : scope ,
820
- "token_type" : "Bearer" ,
821
- }
805
+ responses = [success_response ],
806
+ )
807
+
808
+ mock_environ = {
809
+ EnvironmentVariables .AZURE_AUTHORITY_HOST : authority ,
810
+ EnvironmentVariables .AZURE_CLIENT_ID : default_client_id ,
811
+ EnvironmentVariables .AZURE_TENANT_ID : tenant ,
812
+ EnvironmentVariables .AZURE_FEDERATED_TOKEN_FILE : token_file .strpath ,
813
+ }
814
+ # credential should default to AZURE_CLIENT_ID
815
+ with mock .patch .dict ("os.environ" , mock_environ , clear = True ):
816
+ credential = ManagedIdentityCredential (transport = transport )
817
+ token = credential .get_token (scope )
818
+ assert token .token == access_token
819
+
820
+ # client_id kwarg should override AZURE_CLIENT_ID
821
+ nondefault_client_id = "non" + default_client_id
822
+ transport = validating_transport (
823
+ requests = [
824
+ Request (
825
+ base_url = authority ,
826
+ method = "POST" ,
827
+ required_data = {
828
+ "client_assertion" : exchange_token ,
829
+ "client_assertion_type" : "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ,
830
+ "client_id" : nondefault_client_id ,
831
+ "grant_type" : "client_credentials" ,
832
+ "scope" : scope ,
833
+ },
822
834
)
823
835
],
836
+ responses = [success_response ],
837
+ )
838
+
839
+ with mock .patch .dict ("os.environ" , mock_environ , clear = True ):
840
+ credential = ManagedIdentityCredential (client_id = nondefault_client_id , transport = transport )
841
+ token = credential .get_token (scope )
842
+ assert token .token == access_token
843
+
844
+ # AZURE_CLIENT_ID may not have a value, in which case client_id is required
845
+ transport = validating_transport (
846
+ requests = [
847
+ Request (
848
+ base_url = authority ,
849
+ method = "POST" ,
850
+ required_data = {
851
+ "client_assertion" : exchange_token ,
852
+ "client_assertion_type" : "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ,
853
+ "client_id" : nondefault_client_id ,
854
+ "grant_type" : "client_credentials" ,
855
+ "scope" : scope ,
856
+ },
857
+ )
858
+ ],
859
+ responses = [success_response ],
824
860
)
825
861
826
862
with mock .patch .dict (
827
863
"os.environ" ,
828
864
{
829
865
EnvironmentVariables .AZURE_AUTHORITY_HOST : authority ,
830
- EnvironmentVariables .AZURE_CLIENT_ID : client_id ,
831
866
EnvironmentVariables .AZURE_TENANT_ID : tenant ,
832
867
EnvironmentVariables .AZURE_FEDERATED_TOKEN_FILE : token_file .strpath ,
833
868
},
834
869
clear = True ,
835
870
):
836
- credential = ManagedIdentityCredential ( transport = transport )
837
- token = credential . get_token ( scope )
871
+ with pytest . raises ( ValueError ):
872
+ ManagedIdentityCredential ( )
838
873
874
+ credential = ManagedIdentityCredential (client_id = nondefault_client_id , transport = transport )
875
+ token = credential .get_token (scope )
839
876
assert token .token == access_token
0 commit comments