diff --git a/build.gradle b/build.gradle index 1caa0915a35..b53d7ca96aa 100644 --- a/build.gradle +++ b/build.gradle @@ -110,7 +110,6 @@ subprojects { tasks.withType(JavaCompile) { options.encoding = "UTF-8" options.compilerArgs.add("-parameters") - options.release = 8 } } @@ -139,6 +138,13 @@ allprojects { } } } + + tasks.withType(JavaCompile).configureEach { + javaCompiler = javaToolchains.compilerFor { + languageVersion = JavaLanguageVersion.of(8) + } + } + } if (hasProperty('buildScan')) { diff --git a/config/spring-security-config.gradle b/config/spring-security-config.gradle index 2ce079d305b..ffb1186721c 100644 --- a/config/spring-security-config.gradle +++ b/config/spring-security-config.gradle @@ -113,7 +113,6 @@ dependencies { testRuntimeOnly 'org.hsqldb:hsqldb' } - rncToXsd { rncDir = file('src/main/resources/org/springframework/security/config/') xsdDir = rncDir @@ -130,3 +129,33 @@ tasks.withType(KotlinCompile).configureEach { } build.dependsOn rncToXsd + +compileTestJava { + exclude "org/springframework/security/config/annotation/web/configurers/saml2/**", "org/springframework/security/config/http/Saml2*" +} + +task compileSaml2TestJava(type: JavaCompile) { + javaCompiler = javaToolchains.compilerFor { + languageVersion = JavaLanguageVersion.of(11) + } + source = sourceSets.test.java.srcDirs + include "org/springframework/security/config/annotation/web/configurers/saml2/**", "org/springframework/security/config/http/Saml2*" + classpath = sourceSets.test.compileClasspath + destinationDirectory = new File("${buildDir}/classes/java/test") + options.sourcepath = sourceSets.test.java.getSourceDirectories() +} + +task saml2Tests(type: Test) { + javaLauncher = javaToolchains.launcherFor { + languageVersion = JavaLanguageVersion.of(11) + } + filter { + includeTestsMatching "org.springframework.security.config.annotation.web.configurers.saml2.*" + } + useJUnitPlatform() + dependsOn compileSaml2TestJava +} + +tasks.named('test') { + finalizedBy 'saml2Tests' +} diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index ee8e598ff11..a4cfb815dc8 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -32,8 +32,6 @@ import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer; import org.springframework.security.core.Authentication; import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest; -import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider; -import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider; import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; @@ -55,6 +53,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; /** @@ -112,6 +111,8 @@ public final class Saml2LoginConfigurer> extends AbstractAuthenticationFilterConfigurer, Saml2WebSsoAuthenticationFilter> { + private static final String OPEN_SAML_4_VERSION = "4"; + private String loginPage; private String authenticationRequestUri = "/saml2/authenticate/{registrationId}"; @@ -320,11 +321,9 @@ private Saml2AuthenticationRequestFactory getAuthenticationRequestFactory(B http return resolver; } if (version().startsWith("4")) { - return new OpenSaml4AuthenticationRequestFactory(); - } - else { - return new OpenSamlAuthenticationRequestFactory(); + return OpenSaml4LoginSupportFactory.getAuthenticationRequestFactory(); } + return new OpenSamlAuthenticationRequestFactory(); } private Saml2AuthenticationRequestContextResolver getAuthenticationRequestContextResolver(B http) { @@ -354,17 +353,9 @@ private AuthenticationConverter getAuthenticationConverter(B http) { return authenticationConverterBean; } - private String version() { - String version = Version.getVersion(); - if (version != null) { - return version; - } - return Version.getVersion(); - } - private void registerDefaultAuthenticationProvider(B http) { if (version().startsWith("4")) { - http.authenticationProvider(postProcess(new OpenSaml4AuthenticationProvider())); + http.authenticationProvider(postProcess(OpenSaml4LoginSupportFactory.getAuthenticationProvider())); } else { http.authenticationProvider(postProcess(new OpenSamlAuthenticationProvider())); @@ -414,6 +405,19 @@ private Saml2AuthenticationRequestRepository return repository; } + private String version() { + String version = Version.getVersion(); + if (StringUtils.hasText(version)) { + return version; + } + boolean openSaml4ClassPresent = ClassUtils + .isPresent("org.opensaml.core.xml.persist.impl.PassthroughSourceStrategy", null); + if (openSaml4ClassPresent) { + return OPEN_SAML_4_VERSION; + } + throw new IllegalStateException("cannot determine OpenSAML version"); + } + private C getSharedOrBean(B http, Class clazz) { C shared = http.getSharedObject(clazz); if (shared != null) { @@ -441,4 +445,33 @@ private void setSharedObject(B http, Class clazz, C object) { } } + private static class OpenSaml4LoginSupportFactory { + + private static Saml2AuthenticationRequestFactory getAuthenticationRequestFactory() { + try { + Class authenticationRequestFactory = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory", + OpenSaml4LoginSupportFactory.class.getClassLoader()); + return (Saml2AuthenticationRequestFactory) authenticationRequestFactory.getDeclaredConstructor() + .newInstance(); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4AuthenticationRequestFactory", ex); + } + } + + private static AuthenticationProvider getAuthenticationProvider() { + try { + Class authenticationProvider = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider", + OpenSaml4LoginSupportFactory.class.getClassLoader()); + return (AuthenticationProvider) authenticationProvider.getDeclaredConstructor().newInstance(); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4AuthenticationProvider", ex); + } + } + + } + } diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java index 9e4c4eac0f3..a5251bc94b1 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LogoutConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,8 +47,6 @@ import org.springframework.security.saml2.provider.service.web.authentication.logout.HttpSessionLogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml3LogoutRequestResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml3LogoutResponseResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver; -import org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestFilter; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestRepository; import org.springframework.security.saml2.provider.service.web.authentication.logout.Saml2LogoutRequestResolver; @@ -67,6 +65,8 @@ import org.springframework.security.web.util.matcher.AndRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.ClassUtils; +import org.springframework.util.StringUtils; /** * Adds SAML 2.0 logout support. @@ -113,6 +113,8 @@ public final class Saml2LogoutConfigurer> extends AbstractHttpConfigurer, H> { + private static final String OPEN_SAML_4_VERSION = "4"; + private ApplicationContext context; private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; @@ -304,6 +306,19 @@ private Saml2LogoutResponseResolver createSaml2LogoutResponseResolver( return this.logoutResponseConfigurer.logoutResponseResolver(relyingPartyRegistrationResolver); } + private String version() { + String version = Version.getVersion(); + if (StringUtils.hasText(version)) { + return version; + } + boolean openSaml4ClassPresent = ClassUtils + .isPresent("org.opensaml.core.xml.persist.impl.PassthroughSourceStrategy", null); + if (openSaml4ClassPresent) { + return OPEN_SAML_4_VERSION; + } + throw new IllegalStateException("cannot determine OpenSAML version"); + } + private C getBeanOrNull(Class clazz) { if (this.context == null) { return null; @@ -314,14 +329,6 @@ private C getBeanOrNull(Class clazz) { return this.context.getBean(clazz); } - private String version() { - String version = Version.getVersion(); - if (version != null) { - return version; - } - return Version.getVersion(); - } - /** * A configurer for SAML 2.0 LogoutRequest components */ @@ -402,7 +409,7 @@ private Saml2LogoutRequestResolver logoutRequestResolver( return this.logoutRequestResolver; } if (version().startsWith("4")) { - return new OpenSaml4LogoutRequestResolver(relyingPartyRegistrationResolver); + return OpenSaml4LogoutSupportFactory.getLogoutRequestResolver(relyingPartyRegistrationResolver); } return new OpenSaml3LogoutRequestResolver(relyingPartyRegistrationResolver); } @@ -470,13 +477,13 @@ private Saml2LogoutResponseValidator logoutResponseValidator() { private Saml2LogoutResponseResolver logoutResponseResolver( RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { - if (this.logoutResponseResolver == null) { - if (version().startsWith("4")) { - return new OpenSaml4LogoutResponseResolver(relyingPartyRegistrationResolver); - } - return new OpenSaml3LogoutResponseResolver(relyingPartyRegistrationResolver); + if (this.logoutResponseResolver != null) { + return this.logoutResponseResolver; } - return this.logoutResponseResolver; + if (version().startsWith("4")) { + return OpenSaml4LogoutSupportFactory.getLogoutResponseResolver(relyingPartyRegistrationResolver); + } + return new OpenSaml3LogoutResponseResolver(relyingPartyRegistrationResolver); } } @@ -519,4 +526,38 @@ public void logout(HttpServletRequest request, HttpServletResponse response, Aut } + private static class OpenSaml4LogoutSupportFactory { + + private static Saml2LogoutResponseResolver getLogoutResponseResolver( + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + try { + Class logoutResponseResolver = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutResponseResolver", + OpenSaml4LogoutSupportFactory.class.getClassLoader()); + return (Saml2LogoutResponseResolver) logoutResponseResolver + .getDeclaredConstructor(RelyingPartyRegistrationResolver.class) + .newInstance(relyingPartyRegistrationResolver); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4LogoutResponseResolver", ex); + } + } + + private static Saml2LogoutRequestResolver getLogoutRequestResolver( + RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) { + try { + Class logoutRequestResolver = ClassUtils.forName( + "org.springframework.security.saml2.provider.service.web.authentication.logout.OpenSaml4LogoutRequestResolver", + OpenSaml4LogoutSupportFactory.class.getClassLoader()); + return (Saml2LogoutRequestResolver) logoutRequestResolver + .getDeclaredConstructor(RelyingPartyRegistrationResolver.class) + .newInstance(relyingPartyRegistrationResolver); + } + catch (ReflectiveOperationException ex) { + throw new IllegalStateException("Could not instantiate OpenSaml4LogoutRequestResolver", ex); + } + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java index 7868021f15f..d54a0a286e2 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/NamespaceHttpX509Tests.java @@ -21,13 +21,11 @@ import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; -import javax.security.auth.x500.X500Principal; import javax.servlet.http.HttpServletRequest; -import org.bouncycastle.asn1.x500.X500Name; -import org.bouncycastle.asn1.x500.style.BCStyle; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import sun.security.x509.X500Name; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -242,8 +240,12 @@ protected void configure(HttpSecurity http) throws Exception { } private String extractCommonName(X509Certificate certificate) { - X500Principal principal = certificate.getSubjectX500Principal(); - return new X500Name(principal.getName()).getRDNs(BCStyle.CN)[0].getFirst().getValue().toString(); + try { + return ((X500Name) certificate.getSubjectDN()).getCommonName(); + } + catch (Exception ex) { + throw new IllegalArgumentException(ex); + } } } diff --git a/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java b/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java index d2d9b0e665a..fe2bfbadb13 100644 --- a/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java +++ b/core/src/test/java/org/springframework/security/core/SpringSecurityCoreVersionTests.java @@ -18,7 +18,6 @@ import java.lang.reflect.Field; import java.lang.reflect.Method; -import java.lang.reflect.Modifier; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -61,24 +60,15 @@ public class SpringSecurityCoreVersionTests { @BeforeEach public void setup() throws Exception { - setFinalStaticField(SpringSecurityCoreVersion.class, "logger", this.logger); + Field logger = ReflectionUtils.findField(SpringSecurityCoreVersion.class, "logger"); + StaticFinalReflectionUtils.setField(logger, this.logger); } @AfterEach public void cleanup() throws Exception { System.clearProperty(getDisableChecksProperty()); - setFinalStaticField(SpringSecurityCoreVersion.class, "logger", - LogFactory.getLog(SpringSecurityCoreVersion.class)); - } - - private static void setFinalStaticField(Class clazz, String fieldName, Object value) - throws ReflectiveOperationException { - Field field = clazz.getDeclaredField(fieldName); - field.setAccessible(true); - Field modifiers = Field.class.getDeclaredField("modifiers"); - modifiers.setAccessible(true); - modifiers.setInt(field, field.getModifiers() & ~Modifier.FINAL); - field.set(null, value); + Field logger = ReflectionUtils.findField(SpringSecurityCoreVersion.class, "logger"); + StaticFinalReflectionUtils.setField(logger, LogFactory.getLog(SpringSecurityCoreVersion.class)); } @Test diff --git a/core/src/test/java/org/springframework/security/core/StaticFinalReflectionUtils.java b/core/src/test/java/org/springframework/security/core/StaticFinalReflectionUtils.java new file mode 100644 index 00000000000..1cff2226804 --- /dev/null +++ b/core/src/test/java/org/springframework/security/core/StaticFinalReflectionUtils.java @@ -0,0 +1,115 @@ +/* + * Copyright 2008 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.core; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.security.AccessController; +import java.security.PrivilegedAction; + +import sun.misc.Unsafe; + +import org.springframework.objenesis.instantiator.util.UnsafeUtils; + +/** + * Used for setting static variables even if they are private static final. + * + * The code in this class has been adopted from Powermock's WhiteboxImpl. + * + * @author Rob Winch + */ +final class StaticFinalReflectionUtils { + + /** + * Used to support setting static fields that are final using Java's Unsafe. If the + * field is not static final, use + * {@link org.springframework.test.util.ReflectionTestUtils}. + * @param field the field to set + * @param newValue the new value + */ + static void setField(final Field field, final Object newValue) { + try { + field.setAccessible(true); + int fieldModifiersMask = field.getModifiers(); + boolean isFinalModifierPresent = (fieldModifiersMask & Modifier.FINAL) == Modifier.FINAL; + if (isFinalModifierPresent) { + AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Object run() { + try { + Unsafe unsafe = UnsafeUtils.getUnsafe(); + long offset = unsafe.staticFieldOffset(field); + Object base = unsafe.staticFieldBase(field); + setFieldUsingUnsafe(base, field.getType(), offset, newValue, unsafe); + return null; + } + catch (Throwable thrown) { + throw new RuntimeException(thrown); + } + } + }); + } + else { + field.set(null, newValue); + } + } + catch (SecurityException ex) { + throw new RuntimeException(ex); + } + catch (IllegalAccessException ex) { + throw new RuntimeException(ex); + } + catch (IllegalArgumentException ex) { + throw new RuntimeException(ex); + } + } + + private static void setFieldUsingUnsafe(Object base, Class type, long offset, Object newValue, Unsafe unsafe) { + if (type == Integer.TYPE) { + unsafe.putInt(base, offset, ((Integer) newValue)); + } + else if (type == Short.TYPE) { + unsafe.putShort(base, offset, ((Short) newValue)); + } + else if (type == Long.TYPE) { + unsafe.putLong(base, offset, ((Long) newValue)); + } + else if (type == Byte.TYPE) { + unsafe.putByte(base, offset, ((Byte) newValue)); + } + else if (type == Boolean.TYPE) { + unsafe.putBoolean(base, offset, ((Boolean) newValue)); + } + else if (type == Float.TYPE) { + unsafe.putFloat(base, offset, ((Float) newValue)); + } + else if (type == Double.TYPE) { + unsafe.putDouble(base, offset, ((Double) newValue)); + } + else if (type == Character.TYPE) { + unsafe.putChar(base, offset, ((Character) newValue)); + } + else { + unsafe.putObject(base, offset, newValue); + } + } + + private StaticFinalReflectionUtils() { + } + +} diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index 8cf5aff9de7..c9cc0e6f684 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -36,10 +36,19 @@ configurations { } compileOpensaml4MainJava { + javaCompiler = javaToolchains.compilerFor { + languageVersion = JavaLanguageVersion.of(11) + } sourceCompatibility = '11' targetCompatibility = '11' } +compileOpensaml4TestJava { + javaCompiler = javaToolchains.compilerFor { + languageVersion = JavaLanguageVersion.of(11) + } +} + dependencies { management platform(project(":spring-security-dependencies")) api project(':spring-security-web')