diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java new file mode 100644 index 00000000..02f95754 --- /dev/null +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +class ManagedIdentityTestConstants { + // ID types + static final String CLIENT_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"; + static final String RESOURCE_ID = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; + static final String OBJECT_ID = "593b2662-5af7-4a90-a9cb-5a9de615b82f"; + + // Resources + static final String RESOURCE = "https://management.azure.com"; + static final String RESOURCE_DEFAULT_SUFFIX = "https://management.azure.com/.default"; + + // Endpoints + static final String APP_SERVICE_ENDPOINT = "http://127.0.0.1:41564/msi/token"; + static final String IMDS_ENDPOINT = "http://169.254.169.254/metadata/identity/oauth2/token"; + static final String AZURE_ARC_ENDPOINT = "http://localhost:40342/metadata/identity/oauth2/token"; + static final String CLOUDSHELL_ENDPOINT = "http://localhost:40342/metadata/identity/oauth2/token"; + static final String SERVICE_FABRIC_ENDPOINT = "http://localhost:40342/metadata/identity/oauth2/token"; + + // Example responses + static final String RESPONSE_MALFORMED_JSON = "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; + static final String MSI_ERROR_RESPONSE_500 = "{\"statusCode\":\"500\",\"message\":\"An unexpected error occurred while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; + static final String CLOUDSHELL_ERROR_RESPONSE = "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; + static final String MSI_ERROR_RESPONSE_NORETRY = "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; +} \ No newline at end of file diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java index d70a2555..4abb69c7 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java @@ -8,114 +8,111 @@ import java.util.stream.Stream; class ManagedIdentityTestDataProvider { - private static final String CLIENT_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"; - private static final String RESOURCE_ID = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; - private static final String OBJECT_ID = "593b2662-5af7-4a90-a9cb-5a9de615b82f"; - public static Stream createData() { + static Stream createData() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityTests.resourceDefaultSuffix), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityTests.resourceDefaultSuffix), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityTests.resourceDefaultSuffix), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, - ManagedIdentityTests.resourceDefaultSuffix), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityTests.resourceDefaultSuffix)); + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX)); } - public static Stream createDataUserAssigned() { + static Stream createDataUserAssigned() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityId.userAssignedObjectId(OBJECT_ID)), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityId.userAssignedObjectId(ManagedIdentityTestConstants.OBJECT_ID)), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityId.userAssignedObjectId(OBJECT_ID)), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityId.userAssignedResourceId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityId.userAssignedObjectId(OBJECT_ID))); + ManagedIdentityId.userAssignedObjectId(ManagedIdentityTestConstants.OBJECT_ID)), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityId.userAssignedObjectId(ManagedIdentityTestConstants.OBJECT_ID))); } - public static Stream createDataUserAssignedNotSupported() { + static Stream createDataUserAssignedNotSupported() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID))); + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID))); } - public static Stream createDataWrongScope() { + static Stream createDataWrongScope() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, "https://management.core.windows.net//user_impersonation")); } - public static Stream createDataError() { + static Stream createDataError() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint)); + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT)); } - public static Stream createDataGetSource() { + static Stream createDataGetSource() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, ManagedIdentitySourceType.AZURE_ARC), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, ManagedIdentitySourceType.APP_SERVICE), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, ManagedIdentitySourceType.CLOUD_SHELL), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, ManagedIdentitySourceType.AZURE_ARC), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentitySourceType.APP_SERVICE), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, ManagedIdentitySourceType.CLOUD_SHELL), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS), Arguments.of(ManagedIdentitySourceType.IMDS, "", ManagedIdentitySourceType.DEFAULT_TO_IMDS), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, ManagedIdentitySourceType.SERVICE_FABRIC)); + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, ManagedIdentitySourceType.SERVICE_FABRIC)); } - public static Stream createInvalidClaimsData() { + static Stream createInvalidClaimsData() { return Stream.of( Arguments.of("invalid json format"), Arguments.of("{\"access_token\": }") diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java index 8ea74e5a..a5093b3a 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java @@ -27,10 +27,11 @@ import java.util.concurrent.ExecutionException; import static com.microsoft.aad.msal4j.ManagedIdentitySourceType.*; -import static com.microsoft.aad.msal4j.MsalError.*; +import static com.microsoft.aad.msal4j.MsalError.MANAGED_IDENTITY_FILE_READ_ERROR; +import static com.microsoft.aad.msal4j.MsalError.MANAGED_IDENTITY_REQUEST_FAILED; import static com.microsoft.aad.msal4j.MsalErrorMessage.*; import static java.util.Collections.*; -import static org.apache.http.HttpStatus.*; +import static org.apache.http.HttpStatus.SC_UNAUTHORIZED; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; @@ -39,15 +40,6 @@ @TestInstance(TestInstance.Lifecycle.PER_METHOD) class ManagedIdentityTests { - static final String resource = "https://management.azure.com"; - final static String resourceDefaultSuffix = "https://management.azure.com/.default"; - final static String appServiceEndpoint = "http://127.0.0.1:41564/msi/token"; - final static String IMDS_ENDPOINT = "http://169.254.169.254/metadata/identity/oauth2/token"; - final static String azureArcEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - final static String cloudShellEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - final static String serviceFabricEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - private static ManagedIdentityApplication miApp; - private String getSuccessfulResponse(String resource) { long expiresOn = (System.currentTimeMillis() / 1000) + (24 * 3600);//A long-lived, 24 hour token return "{\"access_token\":\"accesstoken\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"" + resource + "\",\"token_type\":" + @@ -60,89 +52,93 @@ private String getSuccessfulResponseWithISOExpiresOn(String resource) { "\"Bearer\",\"client_id\":\"client_id\"}"; } - private String getSuccessfulResponseWithInvalidJson() { - return "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; - } - - private String getMsiErrorResponse() { - return "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; - } - - //Cloud Shell error responses follow a different style, the error info is in a second JSON - private String getMsiErrorResponseCloudShell() { - return "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; - } - - private String getMsiErrorResponseNoRetry() { - return "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; - } - - private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { - return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), hasClaims, hasCapabilities, expectedTokenHash); + private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource) { + return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), false, false, null); } private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, ManagedIdentityId id) { return expectedRequest(source, resource, id, false, false, null); } - private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource) { - return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), false, false, null); + private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, + boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { + return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), hasClaims, hasCapabilities, expectedTokenHash); } private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, - ManagedIdentityId id, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { - String endpoint = null; + ManagedIdentityId id, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { + // Create maps for headers and query parameters Map headers = new HashMap<>(); Map> queryParameters = new HashMap<>(); + // Add resource to query parameters (common for all sources) + queryParameters.put("resource", singletonList(resource)); + + // Handle claims and capabilities if supported if (Constants.TOKEN_REVOCATION_SUPPORTED_ENVIRONMENTS.contains(source)) { if (hasCapabilities) { - queryParameters.put(Constants.CLIENT_CAPABILITY_REQUEST_PARAM, Collections.singletonList("cp1")); + queryParameters.put(Constants.CLIENT_CAPABILITY_REQUEST_PARAM, singletonList("cp1")); } - if (hasClaims) { - queryParameters.put(Constants.TOKEN_HASH_CLAIM, Collections.singletonList(expectedTokenHash)); + queryParameters.put(Constants.TOKEN_HASH_CLAIM, singletonList(expectedTokenHash)); } } + // Configure source-specific parameters + String endpoint = configureSourceSpecificParameters(source, headers, queryParameters); + + // Configure idType-specific parameters + if (id.getIdType() != ManagedIdentityId.systemAssigned().getIdType()) { + configureIdentitySpecificParameters(id, queryParameters); + } + + if (!queryParameters.isEmpty()) { + endpoint = endpoint + "?" + URLUtils.serializeParameters(queryParameters); + } + + return new HttpRequest(HttpMethod.GET, endpoint, headers); + } + + private String configureSourceSpecificParameters(ManagedIdentitySourceType source, + Map headers, + Map> queryParameters) { switch (source) { case APP_SERVICE: - endpoint = appServiceEndpoint; - queryParameters.put("api-version", Collections.singletonList("2019-08-01")); - queryParameters.put("resource", Collections.singletonList(resource)); + queryParameters.put("api-version", singletonList("2019-08-01")); headers.put("X-IDENTITY-HEADER", "secret"); - break; + return ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT; + case CLOUD_SHELL: - endpoint = cloudShellEndpoint; headers.put("ContentType", "application/x-www-form-urlencoded"); headers.put("Metadata", "true"); - queryParameters.put("resource", Collections.singletonList(resource)); - break; + return ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT; + case AZURE_ARC: - endpoint = azureArcEndpoint; - queryParameters.put("api-version", Collections.singletonList("2019-11-01")); - queryParameters.put("resource", Collections.singletonList(resource)); + queryParameters.put("api-version", singletonList("2019-11-01")); headers.put("Metadata", "true"); - break; + return ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT; + case SERVICE_FABRIC: - endpoint = serviceFabricEndpoint; - queryParameters.put("api-version", Collections.singletonList("2019-07-01-preview")); - queryParameters.put("resource", Collections.singletonList(resource)); + queryParameters.put("api-version", singletonList("2019-07-01-preview")); headers.put("secret", "secret"); - break; + return ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT; + case IMDS: case NONE: case DEFAULT_TO_IMDS: - endpoint = IMDS_ENDPOINT; - queryParameters.put("api-version", Collections.singletonList("2018-02-01")); - queryParameters.put("resource", Collections.singletonList(resource)); + default: + queryParameters.put("api-version", singletonList("2018-02-01")); headers.put("Metadata", "true"); - break; + return ManagedIdentityTestConstants.IMDS_ENDPOINT; } + } + private void configureIdentitySpecificParameters(ManagedIdentityId id, Map> queryParameters) { switch (id.getIdType()) { + case SYSTEM_ASSIGNED: + break; case CLIENT_ID: - queryParameters.put("client_id", Collections.singletonList(id.getUserAssignedId())); + queryParameters.put("client_id", singletonList(id.getUserAssignedId())); break; case RESOURCE_ID: if (ManagedIdentityClient.getManagedIdentitySource() == ManagedIdentitySourceType.IMDS) { @@ -154,19 +150,9 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res case OBJECT_ID: queryParameters.put("object_id", singletonList(id.getUserAssignedId())); break; + default: + throw new IllegalStateException("Unexpected value: " + id.getIdType()); } - - return new HttpRequest(HttpMethod.GET, computeUri(endpoint, queryParameters), headers); - } - - private String computeUri(String endpoint, Map> queryParameters) { - if (queryParameters.isEmpty()) { - return endpoint; - } - - String queryString = URLUtils.serializeParameters(queryParameters); - - return endpoint + "?" + queryString; } private HttpResponse expectedResponse(int statusCode, String response) { @@ -177,711 +163,496 @@ private HttpResponse expectedResponse(int statusCode, String response) { return httpResponse; } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataGetSource") - void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + abstract class BaseManagedIdentityTest { + protected ManagedIdentityApplication miApp; + protected DefaultHttpClient httpClientMock; + protected IEnvironmentVariables environmentVariables; - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .build(); - - ManagedIdentitySourceType miClientSourceType = ManagedIdentityClient.getManagedIdentitySource(); - ManagedIdentitySourceType miAppSourceType = ManagedIdentityApplication.getManagedIdentitySource(); - assertEquals(expectedSource, miClientSourceType); - assertEquals(expectedSource, miAppSourceType); - } - - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") - void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + void setUpCommonTest(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId idType) { + initEnvironmentVariables(source, endpoint); + initHttpClientMock(source); + initManagedIdentityApplication(idType); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + void initEnvironmentVariables(ManagedIdentitySourceType source, String endpoint) { + environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + } - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + void initHttpClientMock(ManagedIdentitySourceType source) { + httpClientMock = mock(DefaultHttpClient.class); + if (source == SERVICE_FABRIC) { + ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + } + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); - verify(httpClientMock, times(1)).send(any()); - } + void initManagedIdentityApplication(ManagedIdentityId idType) { + miApp = ManagedIdentityApplication + .builder(idType) + .httpClient(httpClientMock) + .build(); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") - void managedIdentityTest_SuccessfulResponse_WithInvalidJson(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + // ManagedIdentityApplication uses a static token cache, avoid cross test pollution by clearing it + miApp.tokenCache().accessTokens.clear(); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponseWithInvalidJson())); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + void setUpTestWithoutHttpClientMock(ManagedIdentitySourceType source, String endpoint) { + initEnvironmentVariables(source, endpoint); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .build(); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - fail("MsalServiceException is expected but not thrown."); - } catch (ExecutionException exception) { - assert(exception.getCause() instanceof MsalJsonParsingException); + // ManagedIdentityApplication uses a static token cache, avoid cross test pollution by clearing it + miApp.tokenCache().accessTokens.clear(); + } - MsalJsonParsingException miException = (MsalJsonParsingException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE, miException.errorCode()); + void assertTokenFromIdentityProvider(IAuthenticationResult result) { + assertNotNull(result.accessToken()); + assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); } - } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssigned") - void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + void assertTokenFromCache(IAuthenticationResult result) { + assertNotNull(result.accessToken()); + assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); } - when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + void assertMsalServiceException(CompletableFuture future, + ManagedIdentitySourceType expectedSource, + String expectedErrorCode) { + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + assertInstanceOf(MsalServiceException.class, ex.getCause()); - miApp = ManagedIdentityApplication - .builder(id) - .httpClient(httpClientMock) - .build(); + MsalServiceException msalException = (MsalServiceException) ex.getCause(); + assertEquals(expectedSource.name(), msalException.managedIdentitySource()); + assertEquals(expectedErrorCode, msalException.errorCode()); + } - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + void assertMsalClientException(CompletableFuture future, + String expectedErrorCode) { + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + assertInstanceOf(MsalClientException.class, ex.getCause()); - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - verify(httpClientMock, times(1)).send(any()); - } - @Test - void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception { - //All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set, - // so any of the MI options should let us verify that it's being set correctly - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.APP_SERVICE, appServiceEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + MsalClientException msalException = (MsalClientException) ex.getCause(); + assertEquals(expectedErrorCode, msalException.errorCode()); + } - when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + CompletableFuture acquireTokenCommon(String resource) throws Exception { + return miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(resource) + .build()); + } + } - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + @Nested + class TokenAcquisitionAndCachingTests extends BaseManagedIdentityTest { - AuthenticationResult result = (AuthenticationResult) miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") + void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - long timestampSeconds = (System.currentTimeMillis() / 1000); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - assertEquals((result.expiresOn() - timestampSeconds)/2, result.refreshOn() - timestampSeconds); + IAuthenticationResult result = acquireTokenCommon(resource).get(); - verify(httpClientMock, times(1)).send(any()); - } + assertTokenFromIdentityProvider(result); - @Test - void managedIdentityTest_ISOExpiresOn() throws Exception { - //All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set, - // so any of the MI options should let us verify that it's being set correctly - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.APP_SERVICE, appServiceEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + result = acquireTokenCommon(resource).get(); - when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponseWithISOExpiresOn(resource))); + assertTokenFromCache(result); + verify(httpClientMock, times(1)).send(any()); + } - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssigned") + void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { + setUpCommonTest(source, endpoint, id); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, id))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - AuthenticationResult result = (AuthenticationResult) miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - // Calculate what the expected expiration time should be - long expectedExpiresOn = System.currentTimeMillis() / 1000 + (24 * 3600); // 24 hours from now, used in getSuccessfulResponseWithISOExpiresOn + assertTokenFromIdentityProvider(result); + verify(httpClientMock, times(1)).send(any()); + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - //Allow a few seconds of difference to account for execution time - assertTrue((result.expiresOn() - expectedExpiresOn) <= 5); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - verify(httpClientMock, times(1)).send(any()); - } + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported") - void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + ManagedIdentityApplication miApp2 = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .build(); - miApp = ManagedIdentityApplication - .builder(id) - .httpClient(httpClientMock) - .build(); + IAuthenticationResult resultMiApp1 = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + assertTokenFromIdentityProvider(resultMiApp1); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + IAuthenticationResult resultMiApp2 = miApp2.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) .build()).get(); - fail("MsalServiceException is expected but not thrown."); - } catch (Exception e) { - assertNotNull(e); - assertNotNull(e.getCause()); - assertInstanceOf(MsalServiceException.class, e.getCause()); - MsalServiceException msalMsiException = (MsalServiceException) e.getCause(); - assertEquals(source.name(), msalMsiException.managedIdentitySource()); - assertEquals(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED, msalMsiException.errorCode()); - } - } + assertTokenFromCache(resultMiApp2); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") - void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceType source, String endpoint) throws Exception { - String resource = "https://management.azure.com"; - String anotherResource = "https://graph.microsoft.com"; - - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + //acquireTokenForManagedIdentity does a cache lookup by default, and all ManagedIdentityApplication's share a cache, + // so calling acquireTokenForManagedIdentity with the same parameters in two different ManagedIdentityApplications + // should return the same token + assertEquals(resultMiApp1.accessToken(), resultMiApp2.accessToken()); + verify(httpClientMock, times(1)).send(any()); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - when(httpClientMock.send(expectedRequest(source, anotherResource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") + void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + String anotherResource = "https://graph.microsoft.com"; - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + when(httpClientMock.send(expectedRequest(source, anotherResource))).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(anotherResource) - .build()).get(); + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - verify(httpClientMock, times(2)).send(any()); - } + assertTokenFromIdentityProvider(result); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") - void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + result = acquireTokenCommon(anotherResource).get(); - if (environmentVariables.getEnvironmentVariable("SourceType").equals(ManagedIdentitySourceType.CLOUD_SHELL.toString())) { - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, getMsiErrorResponseCloudShell())); - } else { - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, getMsiErrorResponse())); + assertTokenFromIdentityProvider(result); + verify(httpClientMock, times(2)).send(any()); } - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); - - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(AuthenticationErrorCode.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode()); - return; - } - - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") - void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + @Nested + class ManagedIdentityBehaviorTests extends BaseManagedIdentityTest { + //Tests covering specific behavior/scenarios/use cases/etc. for Managed Identity flows - DefaultHttpClientManagedIdentity httpClientMock = mock(DefaultHttpClientManagedIdentity.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + String claimsJson = "{\"default\":\"claim\"}"; - //Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, getMsiErrorResponse())); + // First call, get the token from the identity provider. + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); + assertTokenFromIdentityProvider(result); - //There should be three retries for certain MSI error codes, so there will be four invocations of - // HttpClient's send method: the original call, and the three retries - verify(httpClientMock, times(4)).send(any()); - } + // Second call, get the token from the cache without passing the claims. + result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - clearInvocations(httpClientMock); - //Status codes that aren't on the list, such as 123, should not cause a retry - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(123, getMsiErrorResponseNoRetry())); + assertTokenFromCache(result); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, true, false, expectedTokenHash))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); + + // Third call, when claims are passed bypass the cache. + result = miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) + .claims(claimsJson) .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); - //Because there was no retry, there should only be one invocation of HttpClient's send method - verify(httpClientMock, times(1)).send(any()); + assertTokenFromIdentityProvider(result); - return; + verify(httpClientMock, times(2)).send(any()); } - fail("MsalServiceException is expected but not thrown."); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, String endpoint) throws Exception { + initEnvironmentVariables(source, endpoint); + initHttpClientMock(source); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, false, true, null))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, "")); + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .clientCapabilities(singletonList("cp1")) + .build(); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + miApp.tokenCache.accessTokens.clear(); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + // First call, get the token from the identity provider. + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); + assertTokenFromIdentityProvider(result); - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE, miException.errorCode()); - return; - } + // Second call, get the token from the cache without passing the claims. + result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); - } + assertTokenFromCache(result); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + verify(httpClientMock, times(1)).send(any()); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, "")); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, false, true, null))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(AuthenticationErrorCode.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode()); - return; - } + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .clientCapabilities(singletonList("cp1")) + .httpClient(httpClientMock) + .build(); - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); - } + String claimsJson = "{\"default\":\"claim\"}"; + // First call, get the token from the identity provider. + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + assertTokenFromIdentityProvider(result); - when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network.")); + // Second call, get the token from the cache without passing the claims. + result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + assertTokenFromCache(result); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, true, true, expectedTokenHash))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + // Third call, when claims are passed bypass the cache. + result = miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) + .claims(claimsJson) .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(MsalError.MANAGED_IDENTITY_UNREACHABLE_NETWORK, miException.errorCode()); - return; + assertTokenFromIdentityProvider(result); } - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataGetSource") + void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) { + setUpTestWithoutHttpClientMock(source, endpoint); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + ManagedIdentitySourceType miClientSourceType = ManagedIdentityClient.getManagedIdentitySource(); + ManagedIdentitySourceType miAppSourceType = ManagedIdentityApplication.getManagedIdentitySource(); + assertEquals(expectedSource, miClientSourceType); + assertEquals(expectedSource, miAppSourceType); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - - ManagedIdentityApplication miApp2 = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - IAuthenticationResult resultMiApp1 = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + @Test + void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception { + //All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set, + // so any of the MI options should let us verify that it's being set correctly + setUpCommonTest(APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentityId.systemAssigned()); - assertNotNull(resultMiApp1.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, resultMiApp1.metadata().tokenSource()); + when(httpClientMock.send(expectedRequest(APP_SERVICE, ManagedIdentityTestConstants.RESOURCE))).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - IAuthenticationResult resultMiApp2 = miApp2.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + AuthenticationResult result = (AuthenticationResult) acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); - assertNotNull(resultMiApp2.accessToken()); - assertEquals(TokenSource.CACHE, resultMiApp2.metadata().tokenSource()); + long timestampSeconds = (System.currentTimeMillis() / 1000); + long expectedRefreshIn = result.refreshOn() - timestampSeconds; + long actualRefreshIn = (result.expiresOn() - timestampSeconds)/2; - //acquireTokenForManagedIdentity does a cache lookup by default, and all ManagedIdentityApplication's share a cache, - // so calling acquireTokenForManagedIdentity with the same parameters in two different ManagedIdentityApplications - // should return the same token - assertEquals(resultMiApp1.accessToken(), resultMiApp2.accessToken()); - verify(httpClientMock, times(1)).send(any()); - } + assertTokenFromIdentityProvider(result); + //Allow a few seconds of difference to account for execution time + assertTrue((actualRefreshIn - expectedRefreshIn) <= 5); - // managedIdentityTest_WithClaims: Tests that acquiring a token with claims works correctly - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + verify(httpClientMock, times(1)).send(any()); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + @Test + void managedIdentityTest_ISOExpiresOn() throws Exception { + //All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set, + // so any of the MI options should let us verify that it's being set correctly + setUpCommonTest(APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentityId.systemAssigned()); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.RESOURCE))).thenReturn(expectedResponse(200, getSuccessfulResponseWithISOExpiresOn(ManagedIdentityTestConstants.RESOURCE))); - String claimsJson = "{\"default\":\"claim\"}"; + AuthenticationResult result = (AuthenticationResult) miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) + .build()).get(); - // First call, get the token from the identity provider. - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + // Calculate what the expected expiration time should be + long expectedExpiresOn = System.currentTimeMillis() / 1000 + (24 * 3600); // 24 hours from now, used in getSuccessfulResponseWithISOExpiresOn - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + assertTokenFromIdentityProvider(result); + //Allow a few seconds of difference to account for execution time + assertTrue((result.expiresOn() - expectedExpiresOn) <= 5); - // Second call, get the token from the cache without passing the claims. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + verify(httpClientMock, times(1)).send(any()); + } + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); + @Nested + class ErrorHandlingTests extends BaseManagedIdentityTest { - String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); - when(httpClientMock.send(expectedRequest(source, resource, true, false, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") + void managedIdentityTest_SuccessfulResponse_WithInvalidJson(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - // Third call, when claims are passed bypass the cache. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(claimsJson) - .build()).get(); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, ManagedIdentityTestConstants.RESPONSE_MALFORMED_JSON)); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); + } - verify(httpClientMock, times(2)).send(any()); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported") + void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { + setUpCommonTest(source, endpoint, id); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED); } - when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") + void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); + + if (environmentVariables.getEnvironmentVariable("SourceType").equals(CLOUD_SHELL.toString())) { + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ManagedIdentityTestConstants.CLOUDSHELL_ERROR_RESPONSE)); + } else { + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_500)); + } - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .clientCapabilities(singletonList("cp1")) - .build(); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_REQUEST_FAILED); + } - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") + void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - // First call, get the token from the identity provider. - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + DefaultHttpClientManagedIdentity httpClientMock = mock(DefaultHttpClientManagedIdentity.class); + if (source == SERVICE_FABRIC) { + ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .build(); - // Second call, get the token from the cache without passing the claims. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + //Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_500)); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); + try { + acquireTokenCommon(resource).get(); + } catch (Exception exception) { + assert(exception.getCause() instanceof MsalServiceException); - verify(httpClientMock, times(1)).send(any()); - } + //There should be three retries for certain MSI error codes, so there will be four invocations of + // HttpClient's send method: the original call, and the three retries + verify(httpClientMock, times(4)).send(any()); + } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + clearInvocations(httpClientMock); + //Status codes that aren't on the list, such as 123, should not cause a retry + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(123, ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_NORETRY)); - when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + try { + acquireTokenCommon(resource).get(); + } catch (Exception exception) { + assert(exception.getCause() instanceof MsalServiceException); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .clientCapabilities(singletonList("cp1")) - .httpClient(httpClientMock) - .build(); + //Because there was no retry, there should only be one invocation of HttpClient's send method + verify(httpClientMock, times(1)).send(any()); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + return; + } - String claimsJson = "{\"default\":\"claim\"}"; - // First call, get the token from the identity provider. - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + fail("MsalServiceException is expected but not thrown."); + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - // Second call, get the token from the cache without passing the claims. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE))).thenReturn(expectedResponse(500, "")); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); + } - String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); - when(httpClientMock.send(expectedRequest(source, resource, true, true, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - // Third call, when claims are passed bypass the cache. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(claimsJson) - .build()).get(); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE))).thenReturn(expectedResponse(200, "")); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - } + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.MANAGED_IDENTITY_REQUEST_FAILED); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData") - void managedIdentity_InvalidClaims(String claimsJson) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(APP_SERVICE, appServiceEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + verify(httpClientMock, times(1)).send(any()); + } - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - CompletableFuture future = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(claimsJson) - .build()); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network.")); - ExecutionException ex = assertThrows(ExecutionException.class, future::get); - assertInstanceOf(MsalClientException.class, ex.getCause()); + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.MANAGED_IDENTITY_UNREACHABLE_NETWORK); - MsalClientException msalException = (MsalClientException) ex.getCause(); - assertEquals(AuthenticationErrorCode.INVALID_JSON, msalException.errorCode()); + verify(httpClientMock, times(1)).send(any()); + } - // Verify no HTTP requests were made for invalid claims - verify(httpClientMock, never()).send(any()); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData") + void managedIdentity_InvalidClaims(String claimsJson) throws Exception { + setUpCommonTest(APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentityId.systemAssigned()); - @Test - void managedIdentityTest_WithEmptyClaims() throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(APP_SERVICE, appServiceEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + CompletableFuture future = miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) + .claims(claimsJson) + .build()); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + assertMsalClientException(future, AuthenticationErrorCode.INVALID_JSON); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims("") - .build()); - } catch (Exception exception) { - assert(exception instanceof IllegalArgumentException); + // Verify no HTTP requests were made for invalid claims + verify(httpClientMock, never()).send(any()); } - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(null) - .build()); - } catch (Exception exception) { - assert(exception instanceof IllegalArgumentException); - } + @Test + void managedIdentityTest_WithEmptyClaims() throws Exception { + setUpCommonTest(APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentityId.systemAssigned()); + + try { + miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) + .claims("") + .build()); + } catch (Exception exception) { + assert(exception instanceof IllegalArgumentException); + } + + try { + miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) + .claims(null) + .build()); + } catch (Exception exception) { + assert(exception instanceof IllegalArgumentException); + } - // Verify no HTTP requests were made for invalid claims - verify(httpClientMock, never()).send(any()); + // Verify no HTTP requests were made for invalid claims + verify(httpClientMock, never()).send(any()); + } } @Nested - class AzureArc { + class AzureArc extends BaseManagedIdentityTest{ @Test void missingAuthHeader() throws Exception { @@ -926,31 +697,19 @@ void invalidPathWithRealFile(String authHeaderKey) } private void mockHttpResponse(Map> responseHeaders) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(AZURE_ARC, azureArcEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + setUpCommonTest(AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, ManagedIdentityId.systemAssigned()); HttpResponse response = new HttpResponse(); response.statusCode(SC_UNAUTHORIZED); response.headers().putAll(responseHeaders); when(httpClientMock.send( - expectedRequest(AZURE_ARC, resource))).thenReturn( + expectedRequest(AZURE_ARC, ManagedIdentityTestConstants.RESOURCE))).thenReturn( response); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); } private void assertMsalServiceException(String errorCode, String message) throws Exception { - CompletableFuture future = - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource).build()); + CompletableFuture future = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE); ExecutionException ex = assertThrows(ExecutionException.class, future::get); assertInstanceOf(MsalServiceException.class, ex.getCause());