Skip to content

Commit f1676d2

Browse files
committed
Arc test cases
1 parent 56eb958 commit f1676d2

File tree

2 files changed

+38
-19
lines changed

2 files changed

+38
-19
lines changed

msal/imds.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,18 @@ def _obtain_token_on_arc(http_client, endpoint, resource):
284284
params={"api-version": "2020-06-01", "resource": resource},
285285
headers={"Metadata": "true", "Authorization": "Basic {}".format(secret)},
286286
)
287-
payload = json.loads(response.text)
288-
if payload.get("access_token") and payload.get("expires_in"):
289-
# Example: https://learn.microsoft.com/en-us/azure/azure-arc/servers/media/managed-identity-authentication/bash-token-output-example.png
290-
return {
291-
"access_token": payload["access_token"],
292-
"expires_in": int(payload["expires_in"]),
293-
"token_type": payload.get("token_type", "Bearer"),
294-
"resource": payload.get("resource"),
295-
}
287+
try:
288+
payload = json.loads(response.text)
289+
if payload.get("access_token") and payload.get("expires_in"):
290+
# Example: https://learn.microsoft.com/en-us/azure/azure-arc/servers/media/managed-identity-authentication/bash-token-output-example.png
291+
return {
292+
"access_token": payload["access_token"],
293+
"expires_in": int(payload["expires_in"]),
294+
"token_type": payload.get("token_type", "Bearer"),
295+
"resource": payload.get("resource"),
296+
}
297+
except ValueError: # Typically json.decoder.JSONDecodeError
298+
pass
296299
return {
297300
"error": "invalid_request",
298301
"error_description": response.text,

tests/test_mi.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
import sys
34
import time
45
import unittest
56
try:
@@ -141,7 +142,7 @@ def test_unified_api_service_should_ignore_unnecessary_client_id(self):
141142
{"ManagedIdentityIdType": "ClientId", "Id": "foo"},
142143
token_cache=TokenCache()))
143144

144-
def test_sf_service_error_should_be_normalized(self):
145+
def test_sf_error_should_be_normalized(self):
145146
raw_error = '''
146147
{"error": {
147148
"correlationId": "foo",
@@ -163,18 +164,33 @@ def test_sf_service_error_should_be_normalized(self):
163164
"IDENTITY_ENDPOINT": "http://localhost/token",
164165
"IMDS_ENDPOINT": "http://localhost",
165166
})
167+
@patch(
168+
"builtins.open" if sys.version_info.major >= 3 else "__builtin__.open",
169+
mock_open(read_data="secret")
170+
)
166171
class ArcTestCase(ClientTestCase):
172+
challenge = MinimalResponse(status_code=401, text="", headers={
173+
"WWW-Authenticate": "Basic realm=/tmp/foo",
174+
})
167175

168-
@patch("builtins.open", mock_open(read_data="secret"))
169176
def test_happy_path(self):
170177
with patch.object(self.app._http_client, "get", side_effect=[
171-
MinimalResponse(status_code=401, text="", headers={
172-
"WWW-Authenticate": "Basic realm=/tmp/foo",
173-
}),
174-
MinimalResponse(
175-
status_code=200,
176-
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
177-
),
178-
]) as mocked_method:
178+
self.challenge,
179+
MinimalResponse(
180+
status_code=200,
181+
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
182+
),
183+
]) as mocked_method:
179184
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
180185

186+
def test_arc_error_should_be_normalized(self):
187+
with patch.object(self.app._http_client, "get", side_effect=[
188+
self.challenge,
189+
MinimalResponse(status_code=400, text="undefined"),
190+
]) as mocked_method:
191+
self.assertEqual({
192+
"error": "invalid_request",
193+
"error_description": "undefined",
194+
}, self.app.acquire_token(resource="R"))
195+
self.assertEqual({}, self.app._token_cache._cache)
196+

0 commit comments

Comments
 (0)