diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 6f85571d..78389e39 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -2,9 +2,12 @@ # All rights reserved. # # This code is licensed under the MIT License. +import calendar +import datetime import json import logging import os +import re import socket import sys import time @@ -432,6 +435,35 @@ def _obtain_token(http_client, managed_identity, resource): return _obtain_token_on_azure_vm(http_client, managed_identity, resource) +def _parse_expires_on(raw: str) -> int: + try: + return int(raw) # It is typically an epoch time + except ValueError: + pass + try: + # '2024-10-18T19:51:37.0000000+00:00' was observed in + # https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/4963 + if sys.version_info < (3, 11): # Does not support 7-digit microseconds + raw = re.sub( # Strip microseconds portion using regex + r'(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d+)([+-]\d{2}:\d{2})', + r'\1\3', + raw) + return int(datetime.datetime.fromisoformat(raw).timestamp()) + except ValueError: + pass + for format in ( + "%m/%d/%Y %H:%M:%S %z", # Support "06/20/2019 02:57:58 +00:00" + # Derived from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.21.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py#L52 + "%m/%d/%Y %I:%M:%S %p %z", # Support "1/16/2020 12:0:12 AM +00:00" + # Derived from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.21.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py#L51 + ): + try: + return calendar.timegm(time.strptime(raw, format)) + except ValueError: + pass + raise ManagedIdentityError(f"Cannot parse expires_on: {raw}") + + def _adjust_param(params, managed_identity, types_mapping=None): # Modify the params dict in place id_name = (types_mapping or ManagedIdentity._types_mapping).get( @@ -504,7 +536,7 @@ def _obtain_token_on_app_service( if payload.get("access_token") and payload.get("expires_on"): return { # Normalizing the payload into OAuth2 format "access_token": payload["access_token"], - "expires_in": int(payload["expires_on"]) - int(time.time()), + "expires_in": _parse_expires_on(payload["expires_on"]) - int(time.time()), "resource": payload.get("resource"), "token_type": payload.get("token_type", "Bearer"), } @@ -538,7 +570,7 @@ def _obtain_token_on_machine_learning( if payload.get("access_token") and payload.get("expires_on"): return { # Normalizing the payload into OAuth2 format "access_token": payload["access_token"], - "expires_in": int(payload["expires_on"]) - int(time.time()), + "expires_in": _parse_expires_on(payload["expires_on"]) - int(time.time()), "resource": payload.get("resource"), "token_type": payload.get("token_type", "Bearer"), } diff --git a/tests/test_mi.py b/tests/test_mi.py index a7c2cb6c..be22a5ef 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -25,6 +25,7 @@ MACHINE_LEARNING, SERVICE_FABRIC, DEFAULT_TO_VM, + _parse_expires_on, ) from msal.token_cache import is_subdict_of @@ -49,6 +50,21 @@ def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_f {"ManagedIdentityIdType": "SystemAssigned", "Id": None}) +class ExpiresOnTestCase(unittest.TestCase): + def test_expires_on_parsing(self): + for input, epoch in { + "1234567890": 1234567890, + "1970-01-01T00:00:12.0000000+00:00": 12, + "2024-10-18T19:51:37.0000000+00:00": 1729281097, # Copied from https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/4963 + "01/01/1970 00:00:12 +00:00": 12, + "06/20/2019 02:57:58 +00:00": 1560999478, # Derived from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.21.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py#L52 + "1/1/1970 12:0:12 AM +00:00": 12, + "1/1/1970 12:0:12 PM +00:00": 43212, + "1/16/2020 5:24:12 AM +00:00": 1579152252, # Derived from https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.21.0/sdk/identity/azure-identity/azure/identity/_credentials/azure_ml.py#L51 + }.items(): + self.assertEqual(_parse_expires_on(input), epoch, f'Should parse "{input}" to {epoch}') + + class ClientTestCase(unittest.TestCase): maxDiff = None