Skip to content

Commit dcbd6d9

Browse files
authored
Allow overriding client_id for token exchange (#20571)
1 parent 6ccb4ad commit dcbd6d9

File tree

5 files changed

+152
-53
lines changed

5 files changed

+152
-53
lines changed

sdk/identity/azure-identity/azure/identity/_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ class EnvironmentVariables:
4848
AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME"
4949

5050
AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE"
51-
TOKEN_EXCHANGE_VARS = (AZURE_CLIENT_ID, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)
51+
TOKEN_EXCHANGE_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)

sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,13 @@ def __init__(self, **kwargs):
6969
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
7070
from .token_exchange import TokenExchangeCredential
7171

72+
client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
73+
if not client_id:
74+
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')
75+
7276
self._credential = TokenExchangeCredential(
7377
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
74-
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
78+
client_id=client_id,
7579
token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
7680
**kwargs
7781
)

sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,13 @@ def __init__(self, **kwargs: "Any") -> None:
6666
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
6767
from .token_exchange import TokenExchangeCredential
6868

69+
client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
70+
if not client_id:
71+
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')
72+
6973
self._credential = TokenExchangeCredential(
7074
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
71-
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
75+
client_id=client_id,
7276
token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
7377
**kwargs
7478
)

sdk/identity/azure-identity/tests/test_managed_identity.py

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
},
3232
{EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc
3333
{ # token exchange
34+
EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://localhost",
3435
EnvironmentVariables.AZURE_CLIENT_ID: "...",
3536
EnvironmentVariables.AZURE_TENANT_ID: "...",
3637
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: __file__,
@@ -73,24 +74,6 @@ def test_context_manager_incomplete_configuration():
7374
pass
7475

7576

76-
ALL_ENVIRONMENTS = (
77-
{EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # App Service
78-
{EnvironmentVariables.MSI_ENDPOINT: "..."}, # Cloud Shell
79-
{ # Service Fabric
80-
EnvironmentVariables.IDENTITY_ENDPOINT: "...",
81-
EnvironmentVariables.IDENTITY_HEADER: "...",
82-
EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...",
83-
},
84-
{EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc
85-
{ # token exchange
86-
EnvironmentVariables.AZURE_CLIENT_ID: "...",
87-
EnvironmentVariables.AZURE_TENANT_ID: "...",
88-
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: __file__,
89-
},
90-
{}, # IMDS
91-
)
92-
93-
9477
@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS)
9578
def test_custom_hooks(environ):
9679
"""The credential's pipeline should include azure-core's CustomHookPolicy"""
@@ -790,10 +773,21 @@ def test_token_exchange(tmpdir):
790773
token_file.write(exchange_token)
791774
access_token = "***"
792775
authority = "https://localhost"
793-
client_id = "client_id"
776+
default_client_id = "default_client_id"
794777
tenant = "tenant_id"
795778
scope = "scope"
796779

780+
success_response = mock_response(
781+
json_payload={
782+
"access_token": access_token,
783+
"expires_in": 3600,
784+
"ext_expires_in": 3600,
785+
"expires_on": int(time.time()) + 3600,
786+
"not_before": int(time.time()),
787+
"resource": scope,
788+
"token_type": "Bearer",
789+
}
790+
)
797791
transport = validating_transport(
798792
requests=[
799793
Request(
@@ -802,38 +796,81 @@ def test_token_exchange(tmpdir):
802796
required_data={
803797
"client_assertion": exchange_token,
804798
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
805-
"client_id": client_id,
799+
"client_id": default_client_id,
806800
"grant_type": "client_credentials",
807801
"scope": scope,
808802
},
809803
)
810804
],
811-
responses=[
812-
mock_response(
813-
json_payload={
814-
"access_token": access_token,
815-
"expires_in": 3600,
816-
"ext_expires_in": 3600,
817-
"expires_on": int(time.time()) + 3600,
818-
"not_before": int(time.time()),
819-
"resource": scope,
820-
"token_type": "Bearer",
821-
}
805+
responses=[success_response],
806+
)
807+
808+
mock_environ = {
809+
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
810+
EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
811+
EnvironmentVariables.AZURE_TENANT_ID: tenant,
812+
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
813+
}
814+
# credential should default to AZURE_CLIENT_ID
815+
with mock.patch.dict("os.environ", mock_environ, clear=True):
816+
credential = ManagedIdentityCredential(transport=transport)
817+
token = credential.get_token(scope)
818+
assert token.token == access_token
819+
820+
# client_id kwarg should override AZURE_CLIENT_ID
821+
nondefault_client_id = "non" + default_client_id
822+
transport = validating_transport(
823+
requests=[
824+
Request(
825+
base_url=authority,
826+
method="POST",
827+
required_data={
828+
"client_assertion": exchange_token,
829+
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
830+
"client_id": nondefault_client_id,
831+
"grant_type": "client_credentials",
832+
"scope": scope,
833+
},
822834
)
823835
],
836+
responses=[success_response],
837+
)
838+
839+
with mock.patch.dict("os.environ", mock_environ, clear=True):
840+
credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
841+
token = credential.get_token(scope)
842+
assert token.token == access_token
843+
844+
# AZURE_CLIENT_ID may not have a value, in which case client_id is required
845+
transport = validating_transport(
846+
requests=[
847+
Request(
848+
base_url=authority,
849+
method="POST",
850+
required_data={
851+
"client_assertion": exchange_token,
852+
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
853+
"client_id": nondefault_client_id,
854+
"grant_type": "client_credentials",
855+
"scope": scope,
856+
},
857+
)
858+
],
859+
responses=[success_response],
824860
)
825861

826862
with mock.patch.dict(
827863
"os.environ",
828864
{
829865
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
830-
EnvironmentVariables.AZURE_CLIENT_ID: client_id,
831866
EnvironmentVariables.AZURE_TENANT_ID: tenant,
832867
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
833868
},
834869
clear=True,
835870
):
836-
credential = ManagedIdentityCredential(transport=transport)
837-
token = credential.get_token(scope)
871+
with pytest.raises(ValueError):
872+
ManagedIdentityCredential()
838873

874+
credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
875+
token = credential.get_token(scope)
839876
assert token.token == access_token

sdk/identity/azure-identity/tests/test_managed_identity_async.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -731,10 +731,21 @@ async def test_token_exchange(tmpdir):
731731
token_file.write(exchange_token)
732732
access_token = "***"
733733
authority = "https://localhost"
734-
client_id = "client_id"
734+
default_client_id = "default_client_id"
735735
tenant = "tenant_id"
736736
scope = "scope"
737737

738+
success_response = mock_response(
739+
json_payload={
740+
"access_token": access_token,
741+
"expires_in": 3600,
742+
"ext_expires_in": 3600,
743+
"expires_on": int(time.time()) + 3600,
744+
"not_before": int(time.time()),
745+
"resource": scope,
746+
"token_type": "Bearer",
747+
}
748+
)
738749
transport = async_validating_transport(
739750
requests=[
740751
Request(
@@ -743,38 +754,81 @@ async def test_token_exchange(tmpdir):
743754
required_data={
744755
"client_assertion": exchange_token,
745756
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
746-
"client_id": client_id,
757+
"client_id": default_client_id,
747758
"grant_type": "client_credentials",
748759
"scope": scope,
749760
},
750761
)
751762
],
752-
responses=[
753-
mock_response(
754-
json_payload={
755-
"access_token": access_token,
756-
"expires_in": 3600,
757-
"ext_expires_in": 3600,
758-
"expires_on": int(time.time()) + 3600,
759-
"not_before": int(time.time()),
760-
"resource": scope,
761-
"token_type": "Bearer",
762-
}
763+
responses=[success_response],
764+
)
765+
766+
mock_environ = {
767+
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
768+
EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
769+
EnvironmentVariables.AZURE_TENANT_ID: tenant,
770+
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
771+
}
772+
# credential should default to AZURE_CLIENT_ID
773+
with mock.patch.dict("os.environ", mock_environ, clear=True):
774+
credential = ManagedIdentityCredential(transport=transport)
775+
token = await credential.get_token(scope)
776+
assert token.token == access_token
777+
778+
# client_id kwarg should override AZURE_CLIENT_ID
779+
nondefault_client_id = "non" + default_client_id
780+
transport = async_validating_transport(
781+
requests=[
782+
Request(
783+
base_url=authority,
784+
method="POST",
785+
required_data={
786+
"client_assertion": exchange_token,
787+
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
788+
"client_id": nondefault_client_id,
789+
"grant_type": "client_credentials",
790+
"scope": scope,
791+
},
763792
)
764793
],
794+
responses=[success_response],
795+
)
796+
797+
with mock.patch.dict("os.environ", mock_environ, clear=True):
798+
credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
799+
token = await credential.get_token(scope)
800+
assert token.token == access_token
801+
802+
# AZURE_CLIENT_ID may not have a value, in which case client_id is required
803+
transport = async_validating_transport(
804+
requests=[
805+
Request(
806+
base_url=authority,
807+
method="POST",
808+
required_data={
809+
"client_assertion": exchange_token,
810+
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
811+
"client_id": nondefault_client_id,
812+
"grant_type": "client_credentials",
813+
"scope": scope,
814+
},
815+
)
816+
],
817+
responses=[success_response],
765818
)
766819

767820
with mock.patch.dict(
768821
"os.environ",
769822
{
770823
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
771-
EnvironmentVariables.AZURE_CLIENT_ID: client_id,
772824
EnvironmentVariables.AZURE_TENANT_ID: tenant,
773825
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
774826
},
775827
clear=True,
776828
):
777-
credential = ManagedIdentityCredential(transport=transport)
778-
token = await credential.get_token(scope)
829+
with pytest.raises(ValueError):
830+
ManagedIdentityCredential()
779831

832+
credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
833+
token = await credential.get_token(scope)
780834
assert token.token == access_token

0 commit comments

Comments
 (0)