From 578253832935e14cc6a5d6305b592b9653f8867b Mon Sep 17 00:00:00 2001 From: avdunn Date: Thu, 15 May 2025 17:33:19 -0700 Subject: [PATCH 1/5] Refactor ManagedIdentityTests --- .../aad/msal4j/ManagedIdentityTests.java | 1006 +++++++---------- 1 file changed, 393 insertions(+), 613 deletions(-) diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java index 54225f0c..6a1cf4f2 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java @@ -4,7 +4,6 @@ package com.microsoft.aad.msal4j; import com.nimbusds.oauth2.sdk.util.URLUtils; -import labapi.App; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -12,7 +11,6 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.ArgumentCaptor; import org.mockito.junit.jupiter.MockitoExtension; import java.net.SocketException; @@ -26,10 +24,11 @@ import java.util.concurrent.ExecutionException; import static com.microsoft.aad.msal4j.ManagedIdentitySourceType.*; -import static com.microsoft.aad.msal4j.MsalError.*; +import static com.microsoft.aad.msal4j.MsalError.MANAGED_IDENTITY_FILE_READ_ERROR; +import static com.microsoft.aad.msal4j.MsalError.MANAGED_IDENTITY_REQUEST_FAILED; import static com.microsoft.aad.msal4j.MsalErrorMessage.*; import static java.util.Collections.*; -import static org.apache.http.HttpStatus.*; +import static org.apache.http.HttpStatus.SC_UNAUTHORIZED; import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; @@ -39,103 +38,115 @@ class ManagedIdentityTests { static final String resource = "https://management.azure.com"; - final static String resourceDefaultSuffix = "https://management.azure.com/.default"; - final static String appServiceEndpoint = "http://127.0.0.1:41564/msi/token"; - final static String IMDS_ENDPOINT = "http://169.254.169.254/metadata/identity/oauth2/token"; - final static String azureArcEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - final static String cloudShellEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - final static String serviceFabricEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; + static final String resourceDefaultSuffix = "https://management.azure.com/.default"; + static final String appServiceEndpoint = "http://127.0.0.1:41564/msi/token"; + static final String IMDS_ENDPOINT = "http://169.254.169.254/metadata/identity/oauth2/token"; + static final String azureArcEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; + static final String cloudShellEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; + static final String serviceFabricEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; private static ManagedIdentityApplication miApp; + static final String SUCCESSFUL_RESPONSE_INVALID_JSON = "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; + //Realistic error response that should trigger a retry + static final String MSI_ERROR_RESPONSE_500 = "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; + //Cloud Shell error responses follow a different style, the error info is in a second JSON + static final String CLOUDSHELL_ERROR_RESPONSE = "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; + //Response with an error code that should not trigger a retry + static final String MSI_ERROR_RESPONSE_NORETRY = "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; + private String getSuccessfulResponse(String resource) { long expiresOn = (System.currentTimeMillis() / 1000) + (24 * 3600);//A long-lived, 24 hour token return "{\"access_token\":\"accesstoken\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"" + resource + "\",\"token_type\":" + "\"Bearer\",\"client_id\":\"client_id\"}"; } - private String getSuccessfulResponseWithInvalidJson() { - return "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; - } - - private String getMsiErrorResponse() { - return "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; - } - - //Cloud Shell error responses follow a different style, the error info is in a second JSON - private String getMsiErrorResponseCloudShell() { - return "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; - } - - private String getMsiErrorResponseNoRetry() { - return "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; - } - - private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { - return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), hasClaims, hasCapabilities, expectedTokenHash); + private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource) { + return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), false, false, null); } private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, ManagedIdentityId id) { return expectedRequest(source, resource, id, false, false, null); } - private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource) { - return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), false, false, null); + private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, + boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { + return expectedRequest(source, resource, ManagedIdentityId.systemAssigned(), hasClaims, hasCapabilities, expectedTokenHash); } private HttpRequest expectedRequest(ManagedIdentitySourceType source, String resource, - ManagedIdentityId id, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { - String endpoint = null; + ManagedIdentityId id, boolean hasClaims, boolean hasCapabilities, String expectedTokenHash) { + // Create maps for headers and query parameters Map headers = new HashMap<>(); Map> queryParameters = new HashMap<>(); + // Add resource to query parameters (common for all sources) + queryParameters.put("resource", singletonList(resource)); + + // Handle claims and capabilities if supported if (Constants.TOKEN_REVOCATION_SUPPORTED_ENVIRONMENTS.contains(source)) { if (hasCapabilities) { - queryParameters.put(Constants.CLIENT_CAPABILITY_REQUEST_PARAM, Collections.singletonList("cp1")); + queryParameters.put(Constants.CLIENT_CAPABILITY_REQUEST_PARAM, singletonList("cp1")); } - if (hasClaims) { - queryParameters.put(Constants.TOKEN_HASH_CLAIM, Collections.singletonList(expectedTokenHash)); + queryParameters.put(Constants.TOKEN_HASH_CLAIM, singletonList(expectedTokenHash)); } } + // Configure source-specific parameters + String endpoint = configureSourceSpecificParameters(source, headers, queryParameters); + + // Configure idType-specific parameters + if (id.getIdType() != ManagedIdentityId.systemAssigned().getIdType()) { + configureIdentitySpecificParameters(id, queryParameters); + } + + if (!queryParameters.isEmpty()) { + endpoint = endpoint + "?" + URLUtils.serializeParameters(queryParameters); + } + + return new HttpRequest(HttpMethod.GET, endpoint, headers); + } + + private String configureSourceSpecificParameters(ManagedIdentitySourceType source, + Map headers, + Map> queryParameters) { switch (source) { case APP_SERVICE: - endpoint = appServiceEndpoint; - queryParameters.put("api-version", Collections.singletonList("2019-08-01")); - queryParameters.put("resource", Collections.singletonList(resource)); + queryParameters.put("api-version", singletonList("2019-08-01")); headers.put("X-IDENTITY-HEADER", "secret"); - break; + return appServiceEndpoint; + case CLOUD_SHELL: - endpoint = cloudShellEndpoint; headers.put("ContentType", "application/x-www-form-urlencoded"); headers.put("Metadata", "true"); - queryParameters.put("resource", Collections.singletonList(resource)); - break; + return cloudShellEndpoint; + case AZURE_ARC: - endpoint = azureArcEndpoint; - queryParameters.put("api-version", Collections.singletonList("2019-11-01")); - queryParameters.put("resource", Collections.singletonList(resource)); + queryParameters.put("api-version", singletonList("2019-11-01")); headers.put("Metadata", "true"); - break; + return azureArcEndpoint; + case SERVICE_FABRIC: - endpoint = serviceFabricEndpoint; - queryParameters.put("api-version", Collections.singletonList("2019-07-01-preview")); - queryParameters.put("resource", Collections.singletonList(resource)); + queryParameters.put("api-version", singletonList("2019-07-01-preview")); headers.put("secret", "secret"); - break; + return serviceFabricEndpoint; + case IMDS: case NONE: case DEFAULT_TO_IMDS: - endpoint = IMDS_ENDPOINT; - queryParameters.put("api-version", Collections.singletonList("2018-02-01")); - queryParameters.put("resource", Collections.singletonList(resource)); + default: + queryParameters.put("api-version", singletonList("2018-02-01")); headers.put("Metadata", "true"); - break; + return IMDS_ENDPOINT; } + } + private void configureIdentitySpecificParameters(ManagedIdentityId id, Map> queryParameters) { switch (id.getIdType()) { + case SYSTEM_ASSIGNED: + break; case CLIENT_ID: - queryParameters.put("client_id", Collections.singletonList(id.getUserAssignedId())); + queryParameters.put("client_id", singletonList(id.getUserAssignedId())); break; case RESOURCE_ID: if (ManagedIdentityClient.getManagedIdentitySource() == ManagedIdentitySourceType.IMDS) { @@ -147,19 +158,9 @@ private HttpRequest expectedRequest(ManagedIdentitySourceType source, String res case OBJECT_ID: queryParameters.put("object_id", singletonList(id.getUserAssignedId())); break; + default: + throw new IllegalStateException("Unexpected value: " + id.getIdType()); } - - return new HttpRequest(HttpMethod.GET, computeUri(endpoint, queryParameters), headers); - } - - private String computeUri(String endpoint, Map> queryParameters) { - if (queryParameters.isEmpty()) { - return endpoint; - } - - String queryString = URLUtils.serializeParameters(queryParameters); - - return endpoint + "?" + queryString; } private HttpResponse expectedResponse(int statusCode, String response) { @@ -170,678 +171,469 @@ private HttpResponse expectedResponse(int statusCode, String response) { return httpResponse; } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataGetSource") - void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + abstract class BaseManagedIdentityTest { + protected ManagedIdentityApplication miApp; + protected DefaultHttpClient httpClientMock; + protected IEnvironmentVariables environmentVariables; - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .build(); + void setUpCommonTest(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId idType) { + initEnvironmentVariables(source, endpoint); + initHttpClientMock(source); + initManagedIdentityApplication(idType); + } - ManagedIdentitySourceType miClientSourceType = ManagedIdentityClient.getManagedIdentitySource(); - ManagedIdentitySourceType miAppSourceType = ManagedIdentityApplication.getManagedIdentitySource(); - assertEquals(expectedSource, miClientSourceType); - assertEquals(expectedSource, miAppSourceType); - } + void initEnvironmentVariables(ManagedIdentitySourceType source, String endpoint) { + environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") - void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + void initHttpClientMock(ManagedIdentitySourceType source) { + httpClientMock = mock(DefaultHttpClient.class); + if (source == SERVICE_FABRIC) { + ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + } } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + void initManagedIdentityApplication(ManagedIdentityId idType) { + miApp = ManagedIdentityApplication + .builder(idType) + .httpClient(httpClientMock) + .build(); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + // ManagedIdentityApplication uses a static token cache, avoid cross test pollution by clearing it + miApp.tokenCache().accessTokens.clear(); + } - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + void setUpTestWithoutHttpClientMock(ManagedIdentitySourceType source, String endpoint) { + initEnvironmentVariables(source, endpoint); - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .build(); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + // ManagedIdentityApplication uses a static token cache, avoid cross test pollution by clearing it + miApp.tokenCache().accessTokens.clear(); + } - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + void assertTokenFromIdentityProvider(IAuthenticationResult result) { + assertNotNull(result.accessToken()); + assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); - verify(httpClientMock, times(1)).send(any()); - } + void assertTokenFromCache(IAuthenticationResult result) { + assertNotNull(result.accessToken()); + assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); + } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") - void managedIdentityTest_SuccessfulResponse_WithInvalidJson(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + void assertMsalServiceException(CompletableFuture future, + ManagedIdentitySourceType expectedSource, + String expectedErrorCode) { + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + assertInstanceOf(MsalServiceException.class, ex.getCause()); + + MsalServiceException msalException = (MsalServiceException) ex.getCause(); + assertEquals(expectedSource.name(), msalException.managedIdentitySource()); + assertEquals(expectedErrorCode, msalException.errorCode()); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponseWithInvalidJson())); + void assertMsalClientException(CompletableFuture future, + String expectedErrorCode) { + ExecutionException ex = assertThrows(ExecutionException.class, future::get); + assertInstanceOf(MsalClientException.class, ex.getCause()); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - fail("MsalServiceException is expected but not thrown."); - } catch (ExecutionException exception) { - assert(exception.getCause() instanceof MsalJsonParsingException); + MsalClientException msalException = (MsalClientException) ex.getCause(); + assertEquals(expectedErrorCode, msalException.errorCode()); + } - MsalJsonParsingException miException = (MsalJsonParsingException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE, miException.errorCode()); + CompletableFuture acquireTokenCommon(String resource) throws Exception { + return miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(resource) + .build()); } } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssigned") - void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + @Nested + class TokenAcquisitionAndCachingTests extends BaseManagedIdentityTest { - when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") + void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - miApp = ManagedIdentityApplication - .builder(id) - .httpClient(httpClientMock) - .build(); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + IAuthenticationResult result = acquireTokenCommon(resource).get(); - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + assertTokenFromIdentityProvider(result); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - verify(httpClientMock, times(1)).send(any()); - } + result = acquireTokenCommon(resource).get(); - @Test - void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception { - //All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set, - // so any of the MI options should let us verify that it's being set correctly - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(ManagedIdentitySourceType.APP_SERVICE, appServiceEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + assertTokenFromCache(result); + verify(httpClientMock, times(1)).send(any()); + } - when(httpClientMock.send(expectedRequest(ManagedIdentitySourceType.APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssigned") + void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { + setUpCommonTest(source, endpoint, id); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - AuthenticationResult result = (AuthenticationResult) miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + IAuthenticationResult result = acquireTokenCommon(resource).get(); - long timestampSeconds = (System.currentTimeMillis() / 1000); + assertTokenFromIdentityProvider(result); + verify(httpClientMock, times(1)).send(any()); + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - assertEquals((result.expiresOn() - timestampSeconds)/2, result.refreshOn() - timestampSeconds); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - verify(httpClientMock, times(1)).send(any()); - } + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported") - void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + ManagedIdentityApplication miApp2 = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .build(); - miApp = ManagedIdentityApplication - .builder(id) - .httpClient(httpClientMock) - .build(); + IAuthenticationResult resultMiApp1 = acquireTokenCommon(resource).get(); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + assertTokenFromIdentityProvider(resultMiApp1); - try { - miApp.acquireTokenForManagedIdentity( + IAuthenticationResult resultMiApp2 = miApp2.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) .build()).get(); - fail("MsalServiceException is expected but not thrown."); - } catch (Exception e) { - assertNotNull(e); - assertNotNull(e.getCause()); - assertInstanceOf(MsalServiceException.class, e.getCause()); - MsalServiceException msalMsiException = (MsalServiceException) e.getCause(); - assertEquals(source.name(), msalMsiException.managedIdentitySource()); - assertEquals(MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED, msalMsiException.errorCode()); - } - } - - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") - void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceType source, String endpoint) throws Exception { - String resource = "https://management.azure.com"; - String anotherResource = "https://graph.microsoft.com"; + assertTokenFromCache(resultMiApp2); - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + //acquireTokenForManagedIdentity does a cache lookup by default, and all ManagedIdentityApplication's share a cache, + // so calling acquireTokenForManagedIdentity with the same parameters in two different ManagedIdentityApplications + // should return the same token + assertEquals(resultMiApp1.accessToken(), resultMiApp2.accessToken()); + verify(httpClientMock, times(1)).send(any()); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - when(httpClientMock.send(expectedRequest(source, anotherResource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(anotherResource) - .build()).get(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") + void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - verify(httpClientMock, times(2)).send(any()); - } + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") - void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + String anotherResource = "https://graph.microsoft.com"; - if (environmentVariables.getEnvironmentVariable("SourceType").equals(ManagedIdentitySourceType.CLOUD_SHELL.toString())) { - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, getMsiErrorResponseCloudShell())); - } else { - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, getMsiErrorResponse())); - } + when(httpClientMock.send(expectedRequest(source, anotherResource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + IAuthenticationResult result = acquireTokenCommon(resource).get(); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + assertTokenFromIdentityProvider(result); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); + result = acquireTokenCommon(anotherResource).get(); - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(AuthenticationErrorCode.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode()); - return; + assertTokenFromIdentityProvider(result); + verify(httpClientMock, times(2)).send(any()); } - - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") - void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); + @Nested + class ManagedIdentityBehaviorTests extends BaseManagedIdentityTest { + //Tests covering specific behavior/scenarios/use cases/etc. for Managed Identity flows - DefaultHttpClientManagedIdentity httpClientMock = mock(DefaultHttpClientManagedIdentity.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + String claimsJson = "{\"default\":\"claim\"}"; - //Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, getMsiErrorResponse())); + // First call, get the token from the identity provider. + IAuthenticationResult result = acquireTokenCommon(resource).get(); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); + assertTokenFromIdentityProvider(result); - //There should be three retries for certain MSI error codes, so there will be four invocations of - // HttpClient's send method: the original call, and the three retries - verify(httpClientMock, times(4)).send(any()); - } + // Second call, get the token from the cache without passing the claims. + result = acquireTokenCommon(resource).get(); + + assertTokenFromCache(result); - clearInvocations(httpClientMock); - //Status codes that aren't on the list, such as 123, should not cause a retry - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(123, getMsiErrorResponseNoRetry())); + String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); + when(httpClientMock.send(expectedRequest(source, resource, true, false, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - try { - miApp.acquireTokenForManagedIdentity( + // Third call, when claims are passed bypass the cache. + result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) + .claims(claimsJson) .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); - //Because there was no retry, there should only be one invocation of HttpClient's send method - verify(httpClientMock, times(1)).send(any()); + assertTokenFromIdentityProvider(result); - return; + verify(httpClientMock, times(2)).send(any()); } - fail("MsalServiceException is expected but not thrown."); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, String endpoint) throws Exception { + initEnvironmentVariables(source, endpoint); + initHttpClientMock(source); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, "")); + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .clientCapabilities(singletonList("cp1")) + .build(); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + miApp.tokenCache.accessTokens.clear(); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + // First call, get the token from the identity provider. + IAuthenticationResult result = acquireTokenCommon(resource).get(); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); + assertTokenFromIdentityProvider(result); - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE, miException.errorCode()); - return; - } + // Second call, get the token from the cache without passing the claims. + result = acquireTokenCommon(resource).get(); - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); - } + assertTokenFromCache(result); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + verify(httpClientMock, times(1)).send(any()); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, "")); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .clientCapabilities(singletonList("cp1")) + .httpClient(httpClientMock) + .build(); - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(AuthenticationErrorCode.MANAGED_IDENTITY_REQUEST_FAILED, miException.errorCode()); - return; - } + String claimsJson = "{\"default\":\"claim\"}"; + // First call, get the token from the identity provider. + IAuthenticationResult result = acquireTokenCommon(resource).get(); - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); - } + assertTokenFromIdentityProvider(result); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + // Second call, get the token from the cache without passing the claims. + result = acquireTokenCommon(resource).get(); - when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network.")); + assertTokenFromCache(result); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); + when(httpClientMock.send(expectedRequest(source, resource, true, true, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - - try { - miApp.acquireTokenForManagedIdentity( + // Third call, when claims are passed bypass the cache. + result = miApp.acquireTokenForManagedIdentity( ManagedIdentityParameters.builder(resource) + .claims(claimsJson) .build()).get(); - } catch (Exception exception) { - assert(exception.getCause() instanceof MsalServiceException); - MsalServiceException miException = (MsalServiceException) exception.getCause(); - assertEquals(source.name(), miException.managedIdentitySource()); - assertEquals(MsalError.MANAGED_IDENTITY_UNREACHABLE_NETWORK, miException.errorCode()); - return; + assertTokenFromIdentityProvider(result); } - fail("MsalServiceException is expected but not thrown."); - verify(httpClientMock, times(1)).send(any()); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataGetSource") + void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, String endpoint, ManagedIdentitySourceType expectedSource) { + setUpTestWithoutHttpClientMock(source, endpoint); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + ManagedIdentitySourceType miClientSourceType = ManagedIdentityClient.getManagedIdentitySource(); + ManagedIdentitySourceType miAppSourceType = ManagedIdentityApplication.getManagedIdentitySource(); + assertEquals(expectedSource, miClientSourceType); + assertEquals(expectedSource, miAppSourceType); } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - - ManagedIdentityApplication miApp2 = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - IAuthenticationResult resultMiApp1 = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + @Test + void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception { + //All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set, + // so any of the MI options should let us verify that it's being set correctly + setUpCommonTest(APP_SERVICE, appServiceEndpoint, ManagedIdentityId.systemAssigned()); - assertNotNull(resultMiApp1.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, resultMiApp1.metadata().tokenSource()); + when(httpClientMock.send(expectedRequest(APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - IAuthenticationResult resultMiApp2 = miApp2.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + AuthenticationResult result = (AuthenticationResult) acquireTokenCommon(resource).get(); - assertNotNull(resultMiApp2.accessToken()); - assertEquals(TokenSource.CACHE, resultMiApp2.metadata().tokenSource()); + long timestampSeconds = (System.currentTimeMillis() / 1000); + long expectedRefreshIn = result.refreshOn() - timestampSeconds; + long actualRefreshIn = (result.expiresOn() - timestampSeconds)/2; - //acquireTokenForManagedIdentity does a cache lookup by default, and all ManagedIdentityApplication's share a cache, - // so calling acquireTokenForManagedIdentity with the same parameters in two different ManagedIdentityApplications - // should return the same token - assertEquals(resultMiApp1.accessToken(), resultMiApp2.accessToken()); - verify(httpClientMock, times(1)).send(any()); - } + assertTokenFromIdentityProvider(result); + //Allow a few seconds of difference to account for execution time + assertTrue((actualRefreshIn - expectedRefreshIn) <= 5); - // managedIdentityTest_WithClaims: Tests that acquiring a token with claims works correctly - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + verify(httpClientMock, times(1)).send(any()); } + } - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); - - String claimsJson = "{\"default\":\"claim\"}"; - - // First call, get the token from the identity provider. - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); - - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + @Nested + class ErrorHandlingTests extends BaseManagedIdentityTest { - // Second call, get the token from the cache without passing the claims. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createData") + void managedIdentityTest_SuccessfulResponse_WithInvalidJson(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, SUCCESSFUL_RESPONSE_INVALID_JSON)); - String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); - when(httpClientMock.send(expectedRequest(source, resource, true, false, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); + } - // Third call, when claims are passed bypass the cache. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(claimsJson) - .build()).get(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataUserAssignedNotSupported") + void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { + setUpCommonTest(source, endpoint, id); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED); + } - verify(httpClientMock, times(2)).send(any()); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") + void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); + + if (environmentVariables.getEnvironmentVariable("SourceType").equals(CLOUD_SHELL.toString())) { + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, CLOUDSHELL_ERROR_RESPONSE)); + } else { + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, MSI_ERROR_RESPONSE_500)); + } - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_REQUEST_FAILED); } - when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataWrongScope") + void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { + IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); + ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .clientCapabilities(singletonList("cp1")) - .build(); + DefaultHttpClientManagedIdentity httpClientMock = mock(DefaultHttpClientManagedIdentity.class); + if (source == SERVICE_FABRIC) { + ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); + } - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + miApp = ManagedIdentityApplication + .builder(ManagedIdentityId.systemAssigned()) + .httpClient(httpClientMock) + .build(); - // First call, get the token from the identity provider. - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + //Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, MSI_ERROR_RESPONSE_500)); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + try { + acquireTokenCommon(resource).get(); + } catch (Exception exception) { + assert(exception.getCause() instanceof MsalServiceException); - // Second call, get the token from the cache without passing the claims. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + //There should be three retries for certain MSI error codes, so there will be four invocations of + // HttpClient's send method: the original call, and the three retries + verify(httpClientMock, times(4)).send(any()); + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); + clearInvocations(httpClientMock); + //Status codes that aren't on the list, such as 123, should not cause a retry + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(123, MSI_ERROR_RESPONSE_NORETRY)); - verify(httpClientMock, times(1)).send(any()); - } + try { + acquireTokenCommon(resource).get(); + } catch (Exception exception) { + assert(exception.getCause() instanceof MsalServiceException); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") - void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, String endpoint) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(source, endpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); - if (source == SERVICE_FABRIC) { - ServiceFabricManagedIdentitySource.setHttpClient(httpClientMock); - } + //Because there was no retry, there should only be one invocation of HttpClient's send method + verify(httpClientMock, times(1)).send(any()); - when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + return; + } - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .clientCapabilities(singletonList("cp1")) - .httpClient(httpClientMock) - .build(); + fail("MsalServiceException is expected but not thrown."); + } - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - String claimsJson = "{\"default\":\"claim\"}"; - // First call, get the token from the identity provider. - IAuthenticationResult result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, "")); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); + } - // Second call, get the token from the cache without passing the claims. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .build()).get(); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - assertNotNull(result.accessToken()); - assertEquals(TokenSource.CACHE, result.metadata().tokenSource()); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, "")); - String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); - when(httpClientMock.send(expectedRequest(source, resource, true, true, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_REQUEST_FAILED); - // Third call, when claims are passed bypass the cache. - result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(claimsJson) - .build()).get(); + verify(httpClientMock, times(1)).send(any()); + } - assertNotNull(result.accessToken()); - assertEquals(TokenSource.IDENTITY_PROVIDER, result.metadata().tokenSource()); - } + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createDataError") + void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception { + setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - @ParameterizedTest - @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData") - void managedIdentity_InvalidClaims(String claimsJson) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(APP_SERVICE, appServiceEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network.")); - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_UNREACHABLE_NETWORK); - CompletableFuture future = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(claimsJson) - .build()); + verify(httpClientMock, times(1)).send(any()); + } - ExecutionException ex = assertThrows(ExecutionException.class, future::get); - assertInstanceOf(MsalClientException.class, ex.getCause()); + @ParameterizedTest + @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData") + void managedIdentity_InvalidClaims(String claimsJson) throws Exception { + setUpCommonTest(APP_SERVICE, appServiceEndpoint, ManagedIdentityId.systemAssigned()); - MsalClientException msalException = (MsalClientException) ex.getCause(); - assertEquals(AuthenticationErrorCode.INVALID_JSON, msalException.errorCode()); + CompletableFuture future = miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(resource) + .claims(claimsJson) + .build()); - // Verify no HTTP requests were made for invalid claims - verify(httpClientMock, never()).send(any()); - } + assertMsalClientException(future, AuthenticationErrorCode.INVALID_JSON); - @Test - void managedIdentityTest_WithEmptyClaims() throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(APP_SERVICE, appServiceEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + // Verify no HTTP requests were made for invalid claims + verify(httpClientMock, never()).send(any()); + } - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); + @Test + void managedIdentityTest_WithEmptyClaims() throws Exception { + setUpCommonTest(APP_SERVICE, appServiceEndpoint, ManagedIdentityId.systemAssigned()); + + try { + miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(resource) + .claims("") + .build()); + } catch (Exception exception) { + assert(exception instanceof IllegalArgumentException); + } - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims("") - .build()); - } catch (Exception exception) { - assert(exception instanceof IllegalArgumentException); - } + try { + miApp.acquireTokenForManagedIdentity( + ManagedIdentityParameters.builder(resource) + .claims(null) + .build()); + } catch (Exception exception) { + assert(exception instanceof IllegalArgumentException); + } - try { - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) - .claims(null) - .build()); - } catch (Exception exception) { - assert(exception instanceof IllegalArgumentException); + // Verify no HTTP requests were made for invalid claims + verify(httpClientMock, never()).send(any()); } - - // Verify no HTTP requests were made for invalid claims - verify(httpClientMock, never()).send(any()); } @Nested - class AzureArc { + class AzureArc extends BaseManagedIdentityTest{ @Test void missingAuthHeader() throws Exception { @@ -886,9 +678,7 @@ void invalidPathWithRealFile(String authHeaderKey) } private void mockHttpResponse(Map> responseHeaders) throws Exception { - IEnvironmentVariables environmentVariables = new EnvironmentVariablesHelper(AZURE_ARC, azureArcEndpoint); - ManagedIdentityApplication.setEnvironmentVariables(environmentVariables); - DefaultHttpClient httpClientMock = mock(DefaultHttpClient.class); + setUpCommonTest(AZURE_ARC, azureArcEndpoint, ManagedIdentityId.systemAssigned()); HttpResponse response = new HttpResponse(); response.statusCode(SC_UNAUTHORIZED); @@ -897,20 +687,10 @@ private void mockHttpResponse(Map> responseHeader when(httpClientMock.send( expectedRequest(AZURE_ARC, resource))).thenReturn( response); - - miApp = ManagedIdentityApplication - .builder(ManagedIdentityId.systemAssigned()) - .httpClient(httpClientMock) - .build(); - - // Clear caching to avoid cross test pollution. - miApp.tokenCache().accessTokens.clear(); } private void assertMsalServiceException(String errorCode, String message) throws Exception { - CompletableFuture future = - miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource).build()); + CompletableFuture future = acquireTokenCommon(resource); ExecutionException ex = assertThrows(ExecutionException.class, future::get); assertInstanceOf(MsalServiceException.class, ex.getCause()); From 664230b7515914dd55bd93f1c704b23889e345d4 Mon Sep 17 00:00:00 2001 From: avdunn Date: Fri, 16 May 2025 12:37:11 -0700 Subject: [PATCH 2/5] Add new constants file --- .../msal4j/ManagedIdentityTestConstants.java | 28 ++++ .../ManagedIdentityTestDataProvider.java | 145 +++++++++--------- .../aad/msal4j/ManagedIdentityTests.java | 120 +++++++-------- 3 files changed, 153 insertions(+), 140 deletions(-) create mode 100644 msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java new file mode 100644 index 00000000..5b2a4d69 --- /dev/null +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.aad.msal4j; + +class ManagedIdentityTestConstants { + // ID types + static final String CLIENT_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"; + static final String RESOURCE_ID = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; + static final String OBJECT_ID = "593b2662-5af7-4a90-a9cb-5a9de615b82f"; + + // Resources + static final String RESOURCE = "https://management.azure.com"; + static final String RESOURCE_DEFAULT_SUFFIX = "https://management.azure.com/.default"; + + // Endpoints + static final String APP_SERVICE_ENDPOINT = "http://127.0.0.1:41564/msi/token"; + static final String IMDS_ENDPOINT = "http://169.254.169.254/metadata/identity/oauth2/token"; + static final String AZURE_ARC_ENDPOINT = "http://localhost:40342/metadata/identity/oauth2/token"; + static final String CLOUDSHELL_ENDPOINT = "http://localhost:40342/metadata/identity/oauth2/token"; + static final String SERVICE_FABRIC_ENDPOINT = "http://localhost:40342/metadata/identity/oauth2/token"; + + // Example responses + static final String SUCCESSFUL_RESPONSE_INVALID_JSON = "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; + static final String MSI_ERROR_RESPONSE_500 = "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; + static final String CLOUDSHELL_ERROR_RESPONSE = "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; + static final String MSI_ERROR_RESPONSE_NORETRY = "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; +} \ No newline at end of file diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java index d70a2555..4abb69c7 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestDataProvider.java @@ -8,114 +8,111 @@ import java.util.stream.Stream; class ManagedIdentityTestDataProvider { - private static final String CLIENT_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"; - private static final String RESOURCE_ID = "/subscriptions/ffa4aaa2-4444-4444-5555-e3ccedd3d046/resourcegroups/UAMI_group/providers/Microsoft.ManagedIdentityClient/userAssignedIdentities/UAMI"; - private static final String OBJECT_ID = "593b2662-5af7-4a90-a9cb-5a9de615b82f"; - public static Stream createData() { + static Stream createData() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityTests.resourceDefaultSuffix), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityTests.resourceDefaultSuffix), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityTests.resourceDefaultSuffix), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, - ManagedIdentityTests.resourceDefaultSuffix), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityTests.resource), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityTests.resourceDefaultSuffix)); + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityTestConstants.RESOURCE_DEFAULT_SUFFIX)); } - public static Stream createDataUserAssigned() { + static Stream createDataUserAssigned() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, - ManagedIdentityId.userAssignedObjectId(OBJECT_ID)), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, + ManagedIdentityId.userAssignedObjectId(ManagedIdentityTestConstants.OBJECT_ID)), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), Arguments.of(ManagedIdentitySourceType.IMDS, null, - ManagedIdentityId.userAssignedObjectId(OBJECT_ID)), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityId.userAssignedResourceId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, - ManagedIdentityId.userAssignedObjectId(OBJECT_ID))); + ManagedIdentityId.userAssignedObjectId(ManagedIdentityTestConstants.OBJECT_ID)), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, + ManagedIdentityId.userAssignedObjectId(ManagedIdentityTestConstants.OBJECT_ID))); } - public static Stream createDataUserAssignedNotSupported() { + static Stream createDataUserAssignedNotSupported() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID)), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityId.userAssignedClientId(CLIENT_ID)), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, - ManagedIdentityId.userAssignedResourceId(RESOURCE_ID))); + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID)), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityId.userAssignedClientId(ManagedIdentityTestConstants.CLIENT_ID)), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, + ManagedIdentityId.userAssignedResourceId(ManagedIdentityTestConstants.RESOURCE_ID))); } - public static Stream createDataWrongScope() { + static Stream createDataWrongScope() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, "https://management.core.windows.net//user_impersonation"), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, "user.read"), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, "https://management.core.windows.net//user_impersonation")); } - public static Stream createDataError() { + static Stream createDataError() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint)); + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT), + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT)); } - public static Stream createDataGetSource() { + static Stream createDataGetSource() { return Stream.of( - Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTests.azureArcEndpoint, ManagedIdentitySourceType.AZURE_ARC), - Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTests.appServiceEndpoint, ManagedIdentitySourceType.APP_SERVICE), - Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTests.cloudShellEndpoint, ManagedIdentitySourceType.CLOUD_SHELL), - Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTests.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS), + Arguments.of(ManagedIdentitySourceType.AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, ManagedIdentitySourceType.AZURE_ARC), + Arguments.of(ManagedIdentitySourceType.APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentitySourceType.APP_SERVICE), + Arguments.of(ManagedIdentitySourceType.CLOUD_SHELL, ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT, ManagedIdentitySourceType.CLOUD_SHELL), + Arguments.of(ManagedIdentitySourceType.IMDS, ManagedIdentityTestConstants.IMDS_ENDPOINT, ManagedIdentitySourceType.DEFAULT_TO_IMDS), Arguments.of(ManagedIdentitySourceType.IMDS, "", ManagedIdentitySourceType.DEFAULT_TO_IMDS), - Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTests.serviceFabricEndpoint, ManagedIdentitySourceType.SERVICE_FABRIC)); + Arguments.of(ManagedIdentitySourceType.SERVICE_FABRIC, ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT, ManagedIdentitySourceType.SERVICE_FABRIC)); } - public static Stream createInvalidClaimsData() { + static Stream createInvalidClaimsData() { return Stream.of( Arguments.of("invalid json format"), Arguments.of("{\"access_token\": }") diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java index 6a1cf4f2..671dc6c7 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java @@ -37,23 +37,6 @@ @TestInstance(TestInstance.Lifecycle.PER_METHOD) class ManagedIdentityTests { - static final String resource = "https://management.azure.com"; - static final String resourceDefaultSuffix = "https://management.azure.com/.default"; - static final String appServiceEndpoint = "http://127.0.0.1:41564/msi/token"; - static final String IMDS_ENDPOINT = "http://169.254.169.254/metadata/identity/oauth2/token"; - static final String azureArcEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - static final String cloudShellEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - static final String serviceFabricEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; - private static ManagedIdentityApplication miApp; - - static final String SUCCESSFUL_RESPONSE_INVALID_JSON = "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; - //Realistic error response that should trigger a retry - static final String MSI_ERROR_RESPONSE_500 = "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; - //Cloud Shell error responses follow a different style, the error info is in a second JSON - static final String CLOUDSHELL_ERROR_RESPONSE = "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; - //Response with an error code that should not trigger a retry - static final String MSI_ERROR_RESPONSE_NORETRY = "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; - private String getSuccessfulResponse(String resource) { long expiresOn = (System.currentTimeMillis() / 1000) + (24 * 3600);//A long-lived, 24 hour token return "{\"access_token\":\"accesstoken\",\"expires_on\":\"" + expiresOn + "\",\"resource\":\"" + resource + "\",\"token_type\":" + @@ -114,22 +97,22 @@ private String configureSourceSpecificParameters(ManagedIdentitySourceType sourc case APP_SERVICE: queryParameters.put("api-version", singletonList("2019-08-01")); headers.put("X-IDENTITY-HEADER", "secret"); - return appServiceEndpoint; + return ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT; case CLOUD_SHELL: headers.put("ContentType", "application/x-www-form-urlencoded"); headers.put("Metadata", "true"); - return cloudShellEndpoint; + return ManagedIdentityTestConstants.CLOUDSHELL_ENDPOINT; case AZURE_ARC: queryParameters.put("api-version", singletonList("2019-11-01")); headers.put("Metadata", "true"); - return azureArcEndpoint; + return ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT; case SERVICE_FABRIC: queryParameters.put("api-version", singletonList("2019-07-01-preview")); headers.put("secret", "secret"); - return serviceFabricEndpoint; + return ManagedIdentityTestConstants.SERVICE_FABRIC_ENDPOINT; case IMDS: case NONE: @@ -137,7 +120,7 @@ private String configureSourceSpecificParameters(ManagedIdentitySourceType sourc default: queryParameters.put("api-version", singletonList("2018-02-01")); headers.put("Metadata", "true"); - return IMDS_ENDPOINT; + return ManagedIdentityTestConstants.IMDS_ENDPOINT; } } @@ -279,9 +262,10 @@ void managedIdentityTest_SystemAssigned_SuccessfulResponse(ManagedIdentitySource void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { setUpCommonTest(source, endpoint, id); - when(httpClientMock.send(expectedRequest(source, resource, id))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, id))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - IAuthenticationResult result = acquireTokenCommon(resource).get(); + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromIdentityProvider(result); verify(httpClientMock, times(1)).send(any()); @@ -292,19 +276,19 @@ void managedIdentityTest_UserAssigned_SuccessfulResponse(ManagedIdentitySourceTy void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoint) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); ManagedIdentityApplication miApp2 = ManagedIdentityApplication .builder(ManagedIdentityId.systemAssigned()) .httpClient(httpClientMock) .build(); - IAuthenticationResult resultMiApp1 = acquireTokenCommon(resource).get(); + IAuthenticationResult resultMiApp1 = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromIdentityProvider(resultMiApp1); IAuthenticationResult resultMiApp2 = miApp2.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) .build()).get(); assertTokenFromCache(resultMiApp2); @@ -321,13 +305,13 @@ void managedIdentity_SharedCache(ManagedIdentitySourceType source, String endpoi void managedIdentityTest_DifferentScopes_RequestsNewToken(ManagedIdentitySourceType source, String endpoint) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); String anotherResource = "https://graph.microsoft.com"; - when(httpClientMock.send(expectedRequest(source, anotherResource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(expectedRequest(source, anotherResource))).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - IAuthenticationResult result = acquireTokenCommon(resource).get(); + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromIdentityProvider(result); @@ -347,26 +331,27 @@ class ManagedIdentityBehaviorTests extends BaseManagedIdentityTest { void managedIdentityTest_WithClaims(ManagedIdentitySourceType source, String endpoint) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(any())).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); String claimsJson = "{\"default\":\"claim\"}"; // First call, get the token from the identity provider. - IAuthenticationResult result = acquireTokenCommon(resource).get(); + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromIdentityProvider(result); // Second call, get the token from the cache without passing the claims. - result = acquireTokenCommon(resource).get(); + result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromCache(result); String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); - when(httpClientMock.send(expectedRequest(source, resource, true, false, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, true, false, expectedTokenHash))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); // Third call, when claims are passed bypass the cache. result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) .claims(claimsJson) .build()).get(); @@ -381,7 +366,8 @@ void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, initEnvironmentVariables(source, endpoint); initHttpClientMock(source); - when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, false, true, null))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); miApp = ManagedIdentityApplication .builder(ManagedIdentityId.systemAssigned()) @@ -392,12 +378,12 @@ void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, miApp.tokenCache.accessTokens.clear(); // First call, get the token from the identity provider. - IAuthenticationResult result = acquireTokenCommon(resource).get(); + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromIdentityProvider(result); // Second call, get the token from the cache without passing the claims. - result = acquireTokenCommon(resource).get(); + result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromCache(result); @@ -409,7 +395,8 @@ void managedIdentityTest_WithCapabilitiesOnly(ManagedIdentitySourceType source, void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, String endpoint) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(expectedRequest(source, resource, false, true, null))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, false, true, null))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); miApp = ManagedIdentityApplication .builder(ManagedIdentityId.systemAssigned()) @@ -419,21 +406,22 @@ void managedIdentity_ClaimsAndCapabilities(ManagedIdentitySourceType source, Str String claimsJson = "{\"default\":\"claim\"}"; // First call, get the token from the identity provider. - IAuthenticationResult result = acquireTokenCommon(resource).get(); + IAuthenticationResult result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromIdentityProvider(result); // Second call, get the token from the cache without passing the claims. - result = acquireTokenCommon(resource).get(); + result = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); assertTokenFromCache(result); String expectedTokenHash = StringHelper.createSha256HashHexString(result.accessToken()); - when(httpClientMock.send(expectedRequest(source, resource, true, true, expectedTokenHash))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE, true, true, expectedTokenHash))) + .thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); // Third call, when claims are passed bypass the cache. result = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) .claims(claimsJson) .build()).get(); @@ -455,11 +443,11 @@ void managedIdentity_GetManagedIdentitySource(ManagedIdentitySourceType source, void managedIdentityTest_RefreshOnHalfOfExpiresOn() throws Exception { //All managed identity flows use the same AcquireTokenByManagedIdentitySupplier where refreshOn is set, // so any of the MI options should let us verify that it's being set correctly - setUpCommonTest(APP_SERVICE, appServiceEndpoint, ManagedIdentityId.systemAssigned()); + setUpCommonTest(APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(expectedRequest(APP_SERVICE, resource))).thenReturn(expectedResponse(200, getSuccessfulResponse(resource))); + when(httpClientMock.send(expectedRequest(APP_SERVICE, ManagedIdentityTestConstants.RESOURCE))).thenReturn(expectedResponse(200, getSuccessfulResponse(ManagedIdentityTestConstants.RESOURCE))); - AuthenticationResult result = (AuthenticationResult) acquireTokenCommon(resource).get(); + AuthenticationResult result = (AuthenticationResult) acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE).get(); long timestampSeconds = (System.currentTimeMillis() / 1000); long expectedRefreshIn = result.refreshOn() - timestampSeconds; @@ -481,7 +469,7 @@ class ErrorHandlingTests extends BaseManagedIdentityTest { void managedIdentityTest_SuccessfulResponse_WithInvalidJson(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, SUCCESSFUL_RESPONSE_INVALID_JSON)); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, ManagedIdentityTestConstants.SUCCESSFUL_RESPONSE_INVALID_JSON)); assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); } @@ -491,7 +479,7 @@ void managedIdentityTest_SuccessfulResponse_WithInvalidJson(ManagedIdentitySourc void managedIdentityTest_UserAssigned_NotSupported(ManagedIdentitySourceType source, String endpoint, ManagedIdentityId id) throws Exception { setUpCommonTest(source, endpoint, id); - assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED); + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.USER_ASSIGNED_MANAGED_IDENTITY_NOT_SUPPORTED); } @ParameterizedTest @@ -500,9 +488,9 @@ void managedIdentityTest_WrongScopes(ManagedIdentitySourceType source, String en setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); if (environmentVariables.getEnvironmentVariable("SourceType").equals(CLOUD_SHELL.toString())) { - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, CLOUDSHELL_ERROR_RESPONSE)); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ManagedIdentityTestConstants.CLOUDSHELL_ERROR_RESPONSE)); } else { - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, MSI_ERROR_RESPONSE_500)); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_500)); } assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_REQUEST_FAILED); @@ -525,7 +513,7 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint .build(); //Several specific 4xx and 5xx errors, such as 500, should trigger MSAL's retry logic - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, MSI_ERROR_RESPONSE_500)); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_500)); try { acquireTokenCommon(resource).get(); @@ -539,7 +527,7 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint clearInvocations(httpClientMock); //Status codes that aren't on the list, such as 123, should not cause a retry - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(123, MSI_ERROR_RESPONSE_NORETRY)); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(123, ManagedIdentityTestConstants.MSI_ERROR_RESPONSE_NORETRY)); try { acquireTokenCommon(resource).get(); @@ -560,9 +548,9 @@ void managedIdentityTest_Retry(ManagedIdentitySourceType source, String endpoint void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, String endpoint) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(500, "")); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE))).thenReturn(expectedResponse(500, "")); - assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); } @ParameterizedTest @@ -570,9 +558,9 @@ void managedIdentity_RequestFailed_NoPayload(ManagedIdentitySourceType source, S void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source, String endpoint) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, "")); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE))).thenReturn(expectedResponse(200, "")); - assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_REQUEST_FAILED); + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.MANAGED_IDENTITY_REQUEST_FAILED); verify(httpClientMock, times(1)).send(any()); } @@ -582,9 +570,9 @@ void managedIdentity_RequestFailed_NullResponse(ManagedIdentitySourceType source void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType source, String endpoint) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(expectedRequest(source, resource))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network.")); + when(httpClientMock.send(expectedRequest(source, ManagedIdentityTestConstants.RESOURCE))).thenThrow(new SocketException("A socket operation was attempted to an unreachable network.")); - assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_UNREACHABLE_NETWORK); + assertMsalServiceException(acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE), source, MsalError.MANAGED_IDENTITY_UNREACHABLE_NETWORK); verify(httpClientMock, times(1)).send(any()); } @@ -592,10 +580,10 @@ void managedIdentity_RequestFailed_UnreachableNetwork(ManagedIdentitySourceType @ParameterizedTest @MethodSource("com.microsoft.aad.msal4j.ManagedIdentityTestDataProvider#createInvalidClaimsData") void managedIdentity_InvalidClaims(String claimsJson) throws Exception { - setUpCommonTest(APP_SERVICE, appServiceEndpoint, ManagedIdentityId.systemAssigned()); + setUpCommonTest(APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentityId.systemAssigned()); CompletableFuture future = miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) .claims(claimsJson) .build()); @@ -607,11 +595,11 @@ void managedIdentity_InvalidClaims(String claimsJson) throws Exception { @Test void managedIdentityTest_WithEmptyClaims() throws Exception { - setUpCommonTest(APP_SERVICE, appServiceEndpoint, ManagedIdentityId.systemAssigned()); + setUpCommonTest(APP_SERVICE, ManagedIdentityTestConstants.APP_SERVICE_ENDPOINT, ManagedIdentityId.systemAssigned()); try { miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) .claims("") .build()); } catch (Exception exception) { @@ -620,7 +608,7 @@ void managedIdentityTest_WithEmptyClaims() throws Exception { try { miApp.acquireTokenForManagedIdentity( - ManagedIdentityParameters.builder(resource) + ManagedIdentityParameters.builder(ManagedIdentityTestConstants.RESOURCE) .claims(null) .build()); } catch (Exception exception) { @@ -678,19 +666,19 @@ void invalidPathWithRealFile(String authHeaderKey) } private void mockHttpResponse(Map> responseHeaders) throws Exception { - setUpCommonTest(AZURE_ARC, azureArcEndpoint, ManagedIdentityId.systemAssigned()); + setUpCommonTest(AZURE_ARC, ManagedIdentityTestConstants.AZURE_ARC_ENDPOINT, ManagedIdentityId.systemAssigned()); HttpResponse response = new HttpResponse(); response.statusCode(SC_UNAUTHORIZED); response.headers().putAll(responseHeaders); when(httpClientMock.send( - expectedRequest(AZURE_ARC, resource))).thenReturn( + expectedRequest(AZURE_ARC, ManagedIdentityTestConstants.RESOURCE))).thenReturn( response); } private void assertMsalServiceException(String errorCode, String message) throws Exception { - CompletableFuture future = acquireTokenCommon(resource); + CompletableFuture future = acquireTokenCommon(ManagedIdentityTestConstants.RESOURCE); ExecutionException ex = assertThrows(ExecutionException.class, future::get); assertInstanceOf(MsalServiceException.class, ex.getCause()); From c6d2fe6189c8040d402d4142b0a5c6c2d6421767 Mon Sep 17 00:00:00 2001 From: Avery-Dunn <62066438+Avery-Dunn@users.noreply.github.com> Date: Tue, 20 May 2025 13:08:09 -0700 Subject: [PATCH 3/5] Update msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java Co-authored-by: Gladwin Johnson <90415114+gladjohn@users.noreply.github.com> --- .../com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java index 5b2a4d69..1abf69d1 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java @@ -21,7 +21,7 @@ class ManagedIdentityTestConstants { static final String SERVICE_FABRIC_ENDPOINT = "http://localhost:40342/metadata/identity/oauth2/token"; // Example responses - static final String SUCCESSFUL_RESPONSE_INVALID_JSON = "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; + static final String RESPONSE_MALFORMED_JSON = "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; static final String MSI_ERROR_RESPONSE_500 = "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; static final String CLOUDSHELL_ERROR_RESPONSE = "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; static final String MSI_ERROR_RESPONSE_NORETRY = "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; From d1f51ed9f275c401d866f6c04edbb64099372182 Mon Sep 17 00:00:00 2001 From: avdunn Date: Tue, 20 May 2025 13:16:01 -0700 Subject: [PATCH 4/5] Update constant --- .../java/com/microsoft/aad/msal4j/ManagedIdentityTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java index 671dc6c7..60e527fc 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java @@ -469,7 +469,7 @@ class ErrorHandlingTests extends BaseManagedIdentityTest { void managedIdentityTest_SuccessfulResponse_WithInvalidJson(ManagedIdentitySourceType source, String endpoint, String resource) throws Exception { setUpCommonTest(source, endpoint, ManagedIdentityId.systemAssigned()); - when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, ManagedIdentityTestConstants.SUCCESSFUL_RESPONSE_INVALID_JSON)); + when(httpClientMock.send(expectedRequest(source, resource))).thenReturn(expectedResponse(200, ManagedIdentityTestConstants.RESPONSE_MALFORMED_JSON)); assertMsalServiceException(acquireTokenCommon(resource), source, MsalError.MANAGED_IDENTITY_RESPONSE_PARSE_FAILURE); } From 8e68c926cda65fbf00a6814aa5f5ad5cb1368df4 Mon Sep 17 00:00:00 2001 From: Avery-Dunn <62066438+Avery-Dunn@users.noreply.github.com> Date: Tue, 20 May 2025 13:18:57 -0700 Subject: [PATCH 5/5] Update msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java index 1abf69d1..02f95754 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTestConstants.java @@ -22,7 +22,7 @@ class ManagedIdentityTestConstants { // Example responses static final String RESPONSE_MALFORMED_JSON = "missing starting bracket \"access_token\":\"accesstoken\",\"token_type\":" + "\"Bearer\",\"client_id\":\"a bunch of problems}"; - static final String MSI_ERROR_RESPONSE_500 = "{\"statusCode\":\"500\",\"message\":\"An unexpected error occured while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; + static final String MSI_ERROR_RESPONSE_500 = "{\"statusCode\":\"500\",\"message\":\"An unexpected error occurred while fetching the AAD Token.\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; static final String CLOUDSHELL_ERROR_RESPONSE = "{\"error\":{\"code\":\"AudienceNotSupported\",\"message\":\"Audience user.read is not a supported MSI token audience.\"}}"; static final String MSI_ERROR_RESPONSE_NORETRY = "{\"statusCode\":\"123\",\"message\":\"Not one of the retryable error responses\",\"correlationId\":\"7d0c9763-ff1d-4842-a3f3-6d49e64f4513\"}"; } \ No newline at end of file