From f8ae8d640c393dcaf162c5d34291442893664f8d Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 25 Apr 2025 15:52:30 +0200 Subject: [PATCH 1/3] Prepare issue branch. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index a6dc167a03..2607252e7c 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-commons - 4.0.0-SNAPSHOT + 4.0.x-GH-3279-SNAPSHOT Spring Data Core Core Spring concepts underpinning every Spring Data module. From 8e1b87aed771ec9501fcfc32816618096a1a5bcd Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 24 Apr 2025 07:36:55 +0200 Subject: [PATCH 2/3] Refine Repository Composition retrieval during AOT Add module identifier and base repository implementation properties. Fix fragment function previously overriding already set property due to name clash. Extend tests for bean definition resolution and code block creation. --- .../aot/generate/AotRepositoryBuilder.java | 7 +- .../aot/generate/MethodContributor.java | 2 +- ...toryBeanDefinitionPropertiesDecorator.java | 2 +- .../config/AotRepositoryInformation.java | 36 +++- .../RepositoryBeanDefinitionReader.java | 133 ++++++++++---- ...RepositoryRegistrationAotContribution.java | 21 +-- .../core/RepositoryInformation.java | 5 + .../core/RepositoryInformationSupport.java | 2 +- .../support/RepositoryFactoryBeanSupport.java | 16 +- .../core/support/RepositoryFragment.java | 2 +- src/test/java/example/UserRepository.java | 2 +- .../java/example/UserRepositoryExtension.java | 25 +++ .../example/UserRepositoryExtensionImpl.java | 29 +++ .../AotRepositoryBuilderUnitTests.java | 157 ++++++++++++++++ .../AotRepositoryMethodBuilderUnitTests.java | 88 +++++++++ .../MethodCapturingRepositoryContributor.java | 57 ++++++ .../RepositoryContributorUnitTests.java | 167 +++++++++++++++++- .../RepositoryBeanDefinitionReaderTests.java | 122 +++++++++++++ 18 files changed, 800 insertions(+), 73 deletions(-) create mode 100644 src/test/java/example/UserRepositoryExtension.java create mode 100644 src/test/java/example/UserRepositoryExtensionImpl.java create mode 100644 src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java create mode 100644 src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java create mode 100644 src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java create mode 100644 src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java index c1ea88e7b1..199ca89f63 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilder.java @@ -138,9 +138,8 @@ public AotBundle build() { this.customizer.customize(repositoryInformation, generationMetadata, builder); JavaFile javaFile = JavaFile.builder(packageName(), builder.build()).build(); - // TODO: module identifier AotRepositoryMetadata metadata = new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(), - "", repositoryType, methodMetadata); + repositoryInformation.moduleName() != null ? repositoryInformation.moduleName() : "", repositoryType, methodMetadata); return new AotBundle(javaFile, metadata.toJson()); } @@ -148,14 +147,14 @@ public AotBundle build() { private void contributeMethod(Method method, RepositoryComposition repositoryComposition, List methodMetadata, TypeSpec.Builder builder) { - if (repositoryInformation.isCustomMethod(method) || repositoryInformation.isBaseClassMethod(method)) { + if (repositoryInformation.isCustomMethod(method) || (repositoryInformation.isBaseClassMethod(method) && !repositoryInformation.isQueryMethod(method))) { RepositoryFragment fragment = repositoryComposition.findFragment(method); if (fragment != null) { methodMetadata.add(getFragmentMetadata(method, fragment)); + return; } - return; } if (method.isBridge() || method.isDefault() || java.lang.reflect.Modifier.isStatic(method.getModifiers())) { diff --git a/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java b/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java index cfd29faf02..b30b2fa5ab 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/MethodContributor.java @@ -36,7 +36,7 @@ public abstract class MethodContributor { private final M queryMethod; private final QueryMetadata metadata; - private MethodContributor(M queryMethod, QueryMetadata metadata) { + MethodContributor(M queryMethod, QueryMetadata metadata) { this.queryMethod = queryMethod; this.metadata = metadata; } diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java index 1326ac4370..d25e0f1cb3 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryBeanDefinitionPropertiesDecorator.java @@ -55,7 +55,7 @@ public CodeBlock decorate() { // bring in properties as usual builder.add(inheritedProperties.get()); - builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragments\", new $T() {\n", + builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragmentsFunction\", new $T() {\n", RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class); builder.indent(); builder.add("public $T getRepositoryFragments($T beanFactory, $T context) {\n", diff --git a/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java b/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java index c4ea580ab8..1ddcbde9a4 100644 --- a/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java +++ b/src/main/java/org/springframework/data/repository/config/AotRepositoryInformation.java @@ -21,10 +21,12 @@ import java.util.Set; import java.util.function.Supplier; +import org.jspecify.annotations.Nullable; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.RepositoryInformationSupport; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryComposition; +import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.util.Lazy; @@ -36,16 +38,31 @@ */ class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation { + private final @Nullable String moduleName; private final Supplier>> fragments; - private Lazy baseComposition = Lazy.of(() -> { - return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())); - }); - AotRepositoryInformation(Supplier repositoryMetadata, Supplier> repositoryBaseClass, - Supplier>> fragments) { + private final Lazy repositoryComposition; + private final Lazy baseComposition; + + AotRepositoryInformation(@Nullable String moduleName, Supplier repositoryMetadata, + Supplier> repositoryBaseClass, Supplier>> fragments) { super(repositoryMetadata, repositoryBaseClass); + + this.moduleName = moduleName; this.fragments = fragments; + + this.repositoryComposition = Lazy + .of(() -> RepositoryComposition.fromMetadata(getMetadata()).append(RepositoryFragments.from(getFragments()))); + + this.baseComposition = Lazy.of(() -> { + + RepositoryComposition targetRepoComposition = repositoryComposition.get(); + + return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())) // + .withArgumentConverter(targetRepoComposition.getArgumentConverter()) // + .withMethodLookup(targetRepoComposition.getMethodLookup()); + }); } /** @@ -57,10 +74,9 @@ public Set> getFragments() { return new LinkedHashSet<>(fragments.get()); } - // Not required during AOT processing. @Override public boolean isCustomMethod(Method method) { - return false; + return repositoryComposition.get().findMethod(method).isPresent(); } @Override @@ -75,7 +91,11 @@ public Method getTargetClassMethod(Method method) { @Override public RepositoryComposition getRepositoryComposition() { - return baseComposition.get().append(RepositoryComposition.RepositoryFragments.from(fragments.get())); + return repositoryComposition.get(); } + @Override + public @Nullable String moduleName() { + return moduleName; + } } diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java index 0ac1ae991a..1209903d3b 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReader.java @@ -16,17 +16,21 @@ package org.springframework.data.repository.config; import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; import java.util.List; -import java.util.function.Supplier; -import java.util.stream.Collectors; +import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.core.ResolvableType; +import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.PagingAndSortingRepository; import org.springframework.data.repository.core.RepositoryInformation; -import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.core.support.AbstractRepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFragment; -import org.springframework.data.util.Lazy; +import org.springframework.data.repository.core.support.RepositoryFragment.ImplementedRepositoryFragment; import org.springframework.util.ClassUtils; /** @@ -38,49 +42,108 @@ */ class RepositoryBeanDefinitionReader { - static RepositoryInformation readRepositoryInformation(RepositoryConfiguration metadata, - ConfigurableListableBeanFactory beanFactory) { - - return new AotRepositoryInformation(metadataSupplier(metadata, beanFactory), - repositoryBaseClass(metadata, beanFactory), fragments(metadata, beanFactory)); + /** + * @return + */ + static RepositoryInformation repositoryInformation(RepositoryConfiguration repoConfig, RegisteredBean repoBean) { + return repositoryInformation(repoConfig, repoBean.getMergedBeanDefinition(), repoBean.getBeanFactory()); } - private static Supplier>> fragments(RepositoryConfiguration metadata, + /** + * @param source the RepositoryFactoryBeanSupport bean definition. + * @param beanFactory + * @return + */ + @SuppressWarnings("NullAway") + static RepositoryInformation repositoryInformation(RepositoryConfiguration repoConfig, BeanDefinition source, ConfigurableListableBeanFactory beanFactory) { - if (metadata instanceof RepositoryFragmentConfigurationProvider provider) { - - return Lazy.of(() -> { - return provider.getFragmentConfiguration().stream().flatMap(it -> { - - List> fragments = new ArrayList<>(1); + RepositoryMetadata metadata = AbstractRepositoryMetadata + .getMetadata(forName(repoConfig.getRepositoryInterface(), beanFactory)); + Class repositoryBaseClass = readRepositoryBaseClass(source, beanFactory); + List> fragmentList = readRepositoryFragments(source, beanFactory); + if (source.getPropertyValues().contains("customImplementation")) { + + Object o = source.getPropertyValues().get("customImplementation"); + if (o instanceof RuntimeBeanReference rbr) { + BeanDefinition customImplBeanDefintion = beanFactory.getBeanDefinition(rbr.getBeanName()); + Class beanType = forName(customImplBeanDefintion.getBeanClassName(), beanFactory); + ResolvableType[] interfaces = ResolvableType.forClass(beanType).getInterfaces(); + if (interfaces.length == 1) { + fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType)); + } else { + boolean found = false; + for (ResolvableType i : interfaces) { + if (beanType.getSimpleName().contains(i.resolve().getSimpleName())) { + fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType)); + found = true; + break; + } + } + if (!found) { + fragmentList.add(RepositoryFragment.implemented(beanType)); + } + } + } + } - fragments.add(RepositoryFragment.implemented(forName(it.getClassName(), beanFactory))); + String moduleName = (String) source.getPropertyValues().get("moduleName"); + AotRepositoryInformation repositoryInformation = new AotRepositoryInformation(moduleName, () -> metadata, + () -> repositoryBaseClass, () -> fragmentList); + return repositoryInformation; + } - if (it.getInterfaceName() != null) { - fragments.add(RepositoryFragment.structural(forName(it.getInterfaceName(), beanFactory))); - } + @SuppressWarnings("NullAway") + private static Class readRepositoryBaseClass(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) { - return fragments.stream(); - }).collect(Collectors.toList()); - }); + Object repoBaseClassName = source.getPropertyValues().get("repositoryBaseClass"); + if (repoBaseClassName != null) { + return forName(repoBaseClassName.toString(), beanFactory); } - - return Lazy.of(Collections::emptyList); + if (source.getPropertyValues().contains("moduleBaseClass")) { + return forName((String) source.getPropertyValues().get("moduleBaseClass"), beanFactory); + } + return Dummy.class; } - @SuppressWarnings({ "rawtypes", "unchecked" }) - private static Supplier> repositoryBaseClass(RepositoryConfiguration metadata, + @SuppressWarnings("NullAway") + private static List> readRepositoryFragments(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) { - return Lazy.of(() -> (Class) metadata.getRepositoryBaseClassName().map(it -> forName(it.toString(), beanFactory)) - .orElse(Object.class)); + RuntimeBeanReference beanReference = (RuntimeBeanReference) source.getPropertyValues().get("repositoryFragments"); + BeanDefinition fragments = beanFactory.getBeanDefinition(beanReference.getBeanName()); + + ValueHolder fragmentBeanNameList = fragments.getConstructorArgumentValues().getArgumentValue(0, List.class); + List fragmentBeanNames = (List) fragmentBeanNameList.getValue(); + + List> fragmentList = new ArrayList<>(); + for (String beanName : fragmentBeanNames) { + + BeanDefinition fragmentBeanDefinition = beanFactory.getBeanDefinition(beanName); + ValueHolder argumentValue = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(0, + String.class); + ValueHolder argumentValue1 = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(1, null, null, + null); + Object fragmentClassName = argumentValue.getValue(); + + try { + Class type = ClassUtils.forName(fragmentClassName.toString(), beanFactory.getBeanClassLoader()); + + if (argumentValue1 != null && argumentValue1.getValue() instanceof RuntimeBeanReference rbf) { + BeanDefinition implBeanDef = beanFactory.getBeanDefinition(rbf.getBeanName()); + Class implClass = ClassUtils.forName(implBeanDef.getBeanClassName(), beanFactory.getBeanClassLoader()); + fragmentList.add(new RepositoryFragment.ImplementedRepositoryFragment(type, implClass)); + } else { + fragmentList.add(RepositoryFragment.structural(type)); + } + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + return fragmentList; } - private static Supplier metadataSupplier( - RepositoryConfiguration metadata, ConfigurableListableBeanFactory beanFactory) { - return Lazy.of(() -> new DefaultRepositoryMetadata(forName(metadata.getRepositoryInterface(), beanFactory))); - } + static abstract class Dummy implements CrudRepository, PagingAndSortingRepository {} static Class forName(String name, ConfigurableListableBeanFactory beanFactory) { try { diff --git a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java index 92405a0aeb..40b2cc43a7 100644 --- a/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java +++ b/src/main/java/org/springframework/data/repository/config/RepositoryRegistrationAotContribution.java @@ -28,7 +28,6 @@ import java.util.function.Predicate; import org.jspecify.annotations.Nullable; - import org.springframework.aop.SpringProxy; import org.springframework.aop.framework.Advised; import org.springframework.aot.generate.GenerationContext; @@ -49,7 +48,6 @@ import org.springframework.data.repository.Repository; import org.springframework.data.repository.aot.generate.RepositoryContributor; import org.springframework.data.repository.core.RepositoryInformation; -import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.util.Predicates; import org.springframework.data.util.QTypeContributor; @@ -90,8 +88,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. * @see RepositoryRegistrationAotProcessor */ - protected RepositoryRegistrationAotContribution( - RepositoryRegistrationAotProcessor processor) { + protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcessor processor) { Assert.notNull(processor, "RepositoryRegistrationAotProcessor must not be null"); @@ -108,8 +105,7 @@ protected RepositoryRegistrationAotContribution( * @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}. * @see RepositoryRegistrationAotProcessor */ - public static RepositoryRegistrationAotContribution fromProcessor( - RepositoryRegistrationAotProcessor processor) { + public static RepositoryRegistrationAotContribution fromProcessor(RepositoryRegistrationAotProcessor processor) { return new RepositoryRegistrationAotContribution(processor); } @@ -255,7 +251,8 @@ private void contributeRepositoryInfo(AotRepositoryContext repositoryContext, Ge }); implementation.ifPresent(impl -> { - contribution.getRuntimeHints().reflection().registerType(impl.getClass(), hint -> { + Class typeToRegister = impl instanceof Class c ? c : impl.getClass(); + contribution.getRuntimeHints().reflection().registerType(typeToRegister, hint -> { hint.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS); @@ -365,18 +362,16 @@ public Predicate> typeFilter() { // like only document ones. // TODO: A @SuppressWarnings("rawtypes") private DefaultAotRepositoryContext buildAotRepositoryContext(RegisteredBean bean, - RepositoryConfiguration repositoryMetadata) { + RepositoryConfiguration repositoryConfiguration) { DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext( AotContext.from(getBeanFactory(), getRepositoryRegistrationAotProcessor().getEnvironment())); - RepositoryFactoryBeanSupport rfbs = bean.getBeanFactory().getBean("&" + bean.getBeanName(), - RepositoryFactoryBeanSupport.class); - repositoryContext.setBeanName(bean.getBeanName()); - repositoryContext.setBasePackages(repositoryMetadata.getBasePackages().toSet()); + repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet()); repositoryContext.setIdentifyingAnnotations(resolveIdentifyingAnnotations()); - repositoryContext.setRepositoryInformation(rfbs.getRepositoryInformation()); + repositoryContext + .setRepositoryInformation(RepositoryBeanDefinitionReader.repositoryInformation(repositoryConfiguration, bean)); return repositoryContext; } diff --git a/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java b/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java index e3f77cc339..3ebee41f24 100644 --- a/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java +++ b/src/main/java/org/springframework/data/repository/core/RepositoryInformation.java @@ -18,6 +18,7 @@ import java.lang.reflect.Method; import java.util.List; +import org.jspecify.annotations.Nullable; import org.springframework.data.repository.core.support.RepositoryComposition; /** @@ -105,4 +106,8 @@ default boolean hasQueryMethods() { */ RepositoryComposition getRepositoryComposition(); + default @Nullable String moduleName() { + return null; + } + } diff --git a/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java b/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java index 269563dc1f..94660dec28 100644 --- a/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java +++ b/src/main/java/org/springframework/data/repository/core/RepositoryInformationSupport.java @@ -184,7 +184,7 @@ protected boolean isQueryMethodCandidate(Method method) { return true; } - private RepositoryMetadata getMetadata() { + protected RepositoryMetadata getMetadata() { return metadata.get(); } diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java index d2f449c6e5..a0f19c5fc2 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFactoryBeanSupport.java @@ -95,6 +95,10 @@ public abstract class RepositoryFactoryBeanSupport, private @Nullable Lazy repository; private @Nullable RepositoryMetadata repositoryMetadata; + // AOT bean factory hint? + private @Nullable String moduleBaseClass; + private @Nullable String moduleName; + /** * Creates a new {@link RepositoryFactoryBeanSupport} for the given repository interface. * @@ -155,7 +159,7 @@ public void setCustomImplementation(Object customImplementation) { * @param repositoryFragments */ public void setRepositoryFragments(RepositoryFragments repositoryFragments) { - setRepositoryFragments(RepositoryFragmentsFunction.just(repositoryFragments)); + setRepositoryFragmentsFunction(RepositoryFragmentsFunction.just(repositoryFragments)); } /** @@ -165,7 +169,7 @@ public void setRepositoryFragments(RepositoryFragments repositoryFragments) { * @param fragmentsFunction * @since 4.0 */ - public void setRepositoryFragments(RepositoryFragmentsFunction fragmentsFunction) { + public void setRepositoryFragmentsFunction(RepositoryFragmentsFunction fragmentsFunction) { this.fragments.add(fragmentsFunction); } @@ -257,6 +261,14 @@ public void setApplicationEventPublisher(ApplicationEventPublisher publisher) { this.publisher = publisher; } + public void setModuleBaseClass(String moduleBaseClass) { + this.moduleBaseClass = moduleBaseClass; + } + + public void setModuleName(String moduleName) { + this.moduleName = moduleName; + } + @Override @SuppressWarnings("unchecked") public EntityInformation getEntityInformation() { diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java index f89b80b847..7b34326a8a 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryFragment.java @@ -265,7 +265,7 @@ public ImplementedRepositoryFragment(@Nullable Class interfaceClass, T implem Assert.notNull(implementation, "Implementation object must not be null"); - if (interfaceClass != null) { + if (interfaceClass != null && !(implementation instanceof Class)) { Assert .isTrue(ClassUtils.isAssignableValue(interfaceClass, implementation), diff --git a/src/test/java/example/UserRepository.java b/src/test/java/example/UserRepository.java index d87b9237ad..d9b35863ef 100644 --- a/src/test/java/example/UserRepository.java +++ b/src/test/java/example/UserRepository.java @@ -24,7 +24,7 @@ /** * @author Christoph Strobl */ -public interface UserRepository extends CrudRepository { +public interface UserRepository extends CrudRepository, UserRepositoryExtension { User findByFirstname(String firstname); diff --git a/src/test/java/example/UserRepositoryExtension.java b/src/test/java/example/UserRepositoryExtension.java new file mode 100644 index 0000000000..6123aed839 --- /dev/null +++ b/src/test/java/example/UserRepositoryExtension.java @@ -0,0 +1,25 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package example; + +import example.UserRepository.User; + +/** + * @author Christoph Strobl + */ +public interface UserRepositoryExtension { + User findUserByExtensionMethod(); +} diff --git a/src/test/java/example/UserRepositoryExtensionImpl.java b/src/test/java/example/UserRepositoryExtensionImpl.java new file mode 100644 index 0000000000..8e6ccb2419 --- /dev/null +++ b/src/test/java/example/UserRepositoryExtensionImpl.java @@ -0,0 +1,29 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package example; + +import example.UserRepository.User; + +/** + * @author Christoph Strobl + */ +public class UserRepositoryExtensionImpl implements UserRepositoryExtension { + + @Override + public User findUserByExtensionMethod() { + return null; + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java new file mode 100644 index 0000000000..f57dc41c13 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryBuilderUnitTests.java @@ -0,0 +1,157 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.aot.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import example.UserRepository; +import example.UserRepository.User; + +import java.util.TimeZone; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.data.geo.Metric; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.query.QueryMethod; +import org.springframework.data.util.TypeInformation; +import org.springframework.javapoet.MethodSpec; +import org.springframework.javapoet.TypeName; +import org.springframework.stereotype.Repository; + +/** + * @author Christoph Strobl + */ +class AotRepositoryBuilderUnitTests { + + RepositoryInformation repositoryInformation; + + @BeforeEach + void beforeEach() { + + repositoryInformation = mock(RepositoryInformation.class); + doReturn(UserRepository.class).when(repositoryInformation).getRepositoryInterface(); + } + + @Test // GH-3279 + void writesClassSkeleton() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + assertThat(repoBuilder.build().javaFile().toString()) + .contains("package %s;".formatted(UserRepository.class.getPackageName())) // same package as source repo + .contains("@Generated") // marked as generated source + .contains("public class %sImpl__Aot".formatted(UserRepository.class.getSimpleName())) // target name + .contains("public UserRepositoryImpl__Aot()"); // default constructor if not arguments to wire + } + + @Test // GH-3279 + void appliesCtorArguments() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + repoBuilder.withConstructorCustomizer(ctor -> { + ctor.addParameter("param1", Metric.class); + ctor.addParameter("param2", String.class); + ctor.addParameter("ctorScoped", TypeName.OBJECT, false); + }); + assertThat(repoBuilder.build().javaFile().toString()) // + .contains("private final Metric param1;") // + .contains("private final String param2;") // + .doesNotContain("private final Object ctorScoped;") // + .contains("public UserRepositoryImpl__Aot(Metric param1, String param2, Object ctorScoped)") // + .contains("this.param1 = param1") // + .contains("this.param2 = param2") // + .doesNotContain("this.ctorScoped = ctorScoped"); + } + + @Test // GH-3279 + void appliesCtorCodeBlock() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + repoBuilder.withConstructorCustomizer(ctor -> { + ctor.customize((info, code) -> { + code.addStatement("throw new $T($S)", IllegalStateException.class, "initialization error"); + }); + }); + assertThat(repoBuilder.build().javaFile().toString()).containsIgnoringWhitespaces( + "UserRepositoryImpl__Aot() { throw new IllegalStateException(\"initialization error\"); }"); + } + + @Test // GH-3279 + void appliesClassCustomizations() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + + repoBuilder.withClassCustomizer((info, metadata, clazz) -> { + + clazz.addField(Float.class, "f", Modifier.PRIVATE, Modifier.STATIC); + clazz.addField(Double.class, "d", Modifier.PUBLIC); + clazz.addField(TimeZone.class, "t", Modifier.FINAL); + + clazz.addAnnotation(Repository.class); + + clazz.addMethod(MethodSpec.methodBuilder("oops").build()); + }); + + assertThat(repoBuilder.build().javaFile().toString()) // + .contains("@Repository") // + .contains("private static Float f;") // + .contains("public Double d;") // + .contains("final TimeZone t;") // + .containsIgnoringWhitespaces("void oops() { }"); + } + + @Test // GH-3279 + void appliesQueryMethodContributor() { + + AotRepositoryBuilder repoBuilder = AotRepositoryBuilder.forRepository(repositoryInformation, + new SpelAwareProxyProjectionFactory()); + + when(repositoryInformation.isQueryMethod(Mockito.argThat(arg -> arg.getName().equals("findByFirstname")))) + .thenReturn(true); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + + repoBuilder.withQueryMethodContributor((method, info) -> { + + return new MethodContributor<>(mock(QueryMethod.class), null) { + + @Override + public MethodSpec contribute(AotQueryMethodGenerationContext context) { + return MethodSpec.methodBuilder("oops").build(); + } + + @Override + public boolean contributesMethodSpec() { + return true; + } + }; + }); + + assertThat(repoBuilder.build().javaFile().toString()) // + .containsIgnoringWhitespaces("void oops() { }"); + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java new file mode 100644 index 0000000000..b0f19b807e --- /dev/null +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.aot.generate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; + +import example.UserRepository; +import example.UserRepository.User; + +import java.lang.reflect.Method; +import java.util.List; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.core.ResolvableType; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.util.TypeInformation; +import org.springframework.javapoet.ParameterSpec; +import org.springframework.javapoet.ParameterizedTypeName; + +/** + * @author Christoph Strobl + */ +class AotRepositoryMethodBuilderUnitTests { + + RepositoryInformation repositoryInformation; + AotQueryMethodGenerationContext methodGenerationContext; + + @BeforeEach + void beforeEach() { + repositoryInformation = Mockito.mock(RepositoryInformation.class); + methodGenerationContext = Mockito.mock(AotQueryMethodGenerationContext.class); + + when(methodGenerationContext.getRepositoryInformation()).thenReturn(repositoryInformation); + } + + @Test // GH-3279 + void generatesMethodSkeletonBasedOnGenerationMetadata() throws NoSuchMethodException { + + Method method = UserRepository.class.getMethod("findByFirstname", String.class); + when(methodGenerationContext.getMethod()).thenReturn(method); + when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class)); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); + methodMetadata.addParameter(ParameterSpec.builder(String.class, "firstname").build()); + when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); + + AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext); + assertThat(builder.buildMethod().toString()) // + .containsPattern("public .*User findByFirstname\\(.*String firstname\\)"); + } + + @Test // GH-3279 + void generatesMethodWithGenerics() throws NoSuchMethodException { + + Method method = UserRepository.class.getMethod("findByFirstnameIn", List.class); + when(methodGenerationContext.getMethod()).thenReturn(method); + when(methodGenerationContext.getReturnType()) + .thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class)); + doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); + MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); + methodMetadata + .addParameter(ParameterSpec.builder(ParameterizedTypeName.get(List.class, String.class), "firstnames").build()); + when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); + + AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext); + assertThat(builder.buildMethod().toString()) // + .containsPattern("public .*List<.*User> findByFirstnameIn\\(") // + .containsPattern(".*List<.*String> firstnames\\)"); + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java b/src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java new file mode 100644 index 0000000000..033c7fbe18 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/aot/generate/MethodCapturingRepositoryContributor.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.aot.generate; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.lang.reflect.Method; +import java.util.List; + +import org.assertj.core.api.MapAssert; +import org.jspecify.annotations.Nullable; +import org.springframework.data.repository.config.AotRepositoryContext; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.query.QueryMethod; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; + +/** + * @author Christoph Strobl + */ +public class MethodCapturingRepositoryContributor extends RepositoryContributor { + + MultiValueMap capturedInvocations; + + public MethodCapturingRepositoryContributor(AotRepositoryContext repositoryContext) { + super(repositoryContext); + this.capturedInvocations = new LinkedMultiValueMap<>(3); + } + + @Override + protected @Nullable MethodContributor contributeQueryMethod(Method method, + RepositoryInformation repositoryInformation) { + capturedInvocations.add(method.getName(), method); + return null; + } + + void verifyContributionFor(String methodName) { + assertThat(capturedInvocations).containsKey(methodName); + } + + MapAssert> verifyContributedMethods() { + return assertThat(capturedInvocations); + } +} diff --git a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java index b77ac6346e..133281fe0c 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/RepositoryContributorUnitTests.java @@ -15,31 +15,45 @@ */ package org.springframework.data.repository.aot.generate; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.when; import example.UserRepository; +import example.UserRepositoryExtension; +import example.UserRepositoryExtensionImpl; import java.lang.reflect.Method; import java.util.Map; +import java.util.Optional; +import java.util.Set; import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; - +import org.mockito.Mockito; import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.core.test.tools.TestCompiler; import org.springframework.data.aot.CodeContributionAssert; +import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.config.AotRepositoryContext; import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.core.support.RepositoryComposition; +import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments; +import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.query.QueryMethod; import org.springframework.javapoet.CodeBlock; import org.springframework.util.ClassUtils; /** + * Unit tests targeting {@link RepositoryContributor}. + * * @author Christoph Strobl */ class RepositoryContributorUnitTests { - @Test - void testCompile() { + @Test // GH-3279 + void createsCompilableClassStub() { DummyModuleAotRepositoryContext aotContext = new DummyModuleAotRepositoryContext(UserRepository.class, null); RepositoryContributor repositoryContributor = new RepositoryContributor(aotContext) { @@ -55,8 +69,7 @@ void testCompile() { public Map serialize() { return Map.of(); } - }) - .contribute(context -> { + }).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); if (!ClassUtils.isVoidType(method.getReturnType())) { @@ -81,4 +94,146 @@ public Map serialize() { new CodeContributionAssert(generationContext).contributesReflectionFor(expectedTypeName); } + @Test // GH-3279 + void callsMethodContributionForQueryMethod() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributionFor("findByFirstname"); + } + + @Test // GH-3279 + void doesNotContributeBaseClassMethods() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.getRepositoryComposition()) + .thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class))); + when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname")))) + .thenReturn(true); + when(repositoryInformation.isQueryMethod(argThat(it -> !it.getName().equals("findByFirstname")))).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findByFirstname"); + } + + @Test // GH-3279 + void doesNotContributeFragmentMethod() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.getRepositoryComposition()) + .thenReturn(RepositoryComposition.of(RepositoryFragment.structural(UserRepository.class)) + .append(RepositoryFragments + .from(Set.of(new RepositoryFragment.ImplementedRepositoryFragment(UserRepositoryExtension.class, + UserRepositoryExtensionImpl.class))))); + + when(repositoryInformation.isCustomMethod(argThat(it -> it.getName().equals("findUserByExtensionMethod")))) + .thenReturn(true); + when(repositoryInformation.isQueryMethod(argThat(it -> it.getName().equals("findByFirstname")))).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributedMethods().isNotEmpty().doesNotContainKey("findUserByExtensionMethod"); + } + + @Test // GH-3279 + void contributesBaseClassMethodIfQueryMethod() { + + AotRepositoryContext repositoryContext = Mockito.mock(AotRepositoryContext.class); + RepositoryInformation repositoryInformation = Mockito.mock(RepositoryInformation.class); + + when(repositoryContext.getRepositoryInformation()).thenReturn(repositoryInformation); + when(repositoryInformation.getRepositoryInterface()).thenReturn((Class) UserRepository.class); + when(repositoryInformation.getRepositoryComposition()) + .thenReturn(RepositoryComposition.of(RepositoryFragment.structural(RepoBaseClass.class))); + when(repositoryInformation.isBaseClassMethod(argThat(it -> it.getName().equals("findByFirstname")))) + .thenReturn(true); + when(repositoryInformation.isQueryMethod(any())).thenReturn(true); + + MethodCapturingRepositoryContributor contributor = new MethodCapturingRepositoryContributor(repositoryContext); + contributor.contribute(new TestGenerationContext(UserRepository.class)); + + contributor.verifyContributedMethods().containsKey("findByFirstname").hasSizeGreaterThan(1); + } + + static class RepoBaseClass implements CrudRepository { + + private CrudRepository delegate; + + public S save(S entity) { + return this.delegate.save(entity); + } + + @Override + public Iterable saveAll(Iterable entities) { + return this.delegate.saveAll(entities); + } + + public Optional findById(ID id) { + return this.delegate.findById(id); + } + + @Override + public boolean existsById(ID id) { + return this.delegate.existsById(id); + } + + @Override + public Iterable findAll() { + return this.delegate.findAll(); + } + + @Override + public Iterable findAllById(Iterable ids) { + return this.delegate.findAllById(ids); + } + + @Override + public long count() { + return this.delegate.count(); + } + + @Override + public void deleteById(ID id) { + this.delegate.deleteById(id); + } + + @Override + public void delete(T entity) { + this.delegate.delete(entity); + } + + @Override + public void deleteAllById(Iterable ids) { + this.delegate.deleteAllById(ids); + } + + @Override + public void deleteAll(Iterable entities) { + this.delegate.deleteAll(entities); + } + + @Override + public void deleteAll() { + this.delegate.deleteAll(); + } + } } diff --git a/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java b/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java new file mode 100644 index 0000000000..54379865c1 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/config/RepositoryBeanDefinitionReaderTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.config; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.data.aot.sample.ConfigWithCustomImplementation; +import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass; +import org.springframework.data.aot.sample.ConfigWithCustomRepositoryBaseClass.CustomerRepositoryWithCustomBaseRepo; +import org.springframework.data.aot.sample.ConfigWithSimpleCrudRepository; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; + +/** + * @author Christoph Strobl + */ +class RepositoryBeanDefinitionReaderTests { + + @Test // GH-3279 + void readsSimpleConfigFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getRepositoryInterface()).isEqualTo(ConfigWithSimpleCrudRepository.MyRepo.class); + assertThat(repositoryInformation.getDomainType()).isEqualTo(ConfigWithSimpleCrudRepository.Person.class); + assertThat(repositoryInformation.getFragments()).isEmpty(); + } + + @Test // GH-3279 + void readsCustomRepoBaseClassFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomRepositoryBaseClass.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Class repositoryInterfaceType = CustomerRepositoryWithCustomBaseRepo.class; + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getRepositoryBaseClass()) + .isEqualTo(ConfigWithCustomRepositoryBaseClass.RepoBaseClass.class); + } + + @Test // GH-3279 + void readsFragmentsFromBeanFactory() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithCustomImplementation.class); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Class repositoryInterfaceType = ConfigWithCustomImplementation.RepositoryWithCustomImplementation.class; + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(repositoryInterfaceType.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + repoFactoryBean.getMergedBeanDefinition(), repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getFragments()).satisfiesExactly(fragment -> { + assertThat(fragment.getSignatureContributor()) + .isEqualTo(ConfigWithCustomImplementation.CustomImplInterface.class); + }); + } + + @Test // GH-3279 + void fallsBackToModuleBaseClassIfSetAndNoRepoBaseDefined() { + + RegisteredBean repoFactoryBean = repositoryFactory(ConfigWithSimpleCrudRepository.class); + RootBeanDefinition rootBeanDefinition = repoFactoryBean.getMergedBeanDefinition().cloneBeanDefinition(); + // need to unset because its defined as non default + rootBeanDefinition.getPropertyValues().removePropertyValue("repositoryBaseClass"); + rootBeanDefinition.getPropertyValues().add("moduleBaseClass", ModuleBase.class.getName()); + + RepositoryConfiguration repoConfig = mock(RepositoryConfiguration.class); + Mockito.when(repoConfig.getRepositoryInterface()).thenReturn(ConfigWithSimpleCrudRepository.MyRepo.class.getName()); + + RepositoryInformation repositoryInformation = RepositoryBeanDefinitionReader.repositoryInformation(repoConfig, + rootBeanDefinition, repoFactoryBean.getBeanFactory()); + + assertThat(repositoryInformation.getRepositoryBaseClass()).isEqualTo(ModuleBase.class); + } + + static RegisteredBean repositoryFactory(Class configClass) { + + AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext(); + applicationContext.register(configClass); + applicationContext.refreshForAotProcessing(new RuntimeHints()); + + String[] beanNamesForType = applicationContext.getBeanNamesForType(RepositoryFactoryBeanSupport.class); + if (beanNamesForType.length != 1) { + throw new IllegalStateException("Unable to find repository FactoryBean"); + } + + return RegisteredBean.of(applicationContext.getBeanFactory(), beanNamesForType[0]); + } + + static class ModuleBase {} +} From 9f76af389a187849ca0cad84fc2512814e4fcbbc Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 29 Apr 2025 10:30:23 +0200 Subject: [PATCH 3/3] Make NullAway go away. Ignore warnings for already checked constructs null away does not understand. --- src/main/java/org/springframework/data/mapping/Parameter.java | 1 + .../org/springframework/data/repository/query/Parameter.java | 1 + 2 files changed, 2 insertions(+) diff --git a/src/main/java/org/springframework/data/mapping/Parameter.java b/src/main/java/org/springframework/data/mapping/Parameter.java index bf6221faad..ca4afea4e6 100644 --- a/src/main/java/org/springframework/data/mapping/Parameter.java +++ b/src/main/java/org/springframework/data/mapping/Parameter.java @@ -118,6 +118,7 @@ public boolean hasName() { * @since 3.5 * @see org.springframework.core.ParameterNameDiscoverer */ + @SuppressWarnings("NullAway") public String getRequiredName() { if (!hasName()) { diff --git a/src/main/java/org/springframework/data/repository/query/Parameter.java b/src/main/java/org/springframework/data/repository/query/Parameter.java index 0907d0f035..b52cbb3df1 100644 --- a/src/main/java/org/springframework/data/repository/query/Parameter.java +++ b/src/main/java/org/springframework/data/repository/query/Parameter.java @@ -125,6 +125,7 @@ public boolean isDynamicProjectionParameter() { * * @return */ + @SuppressWarnings("NullAway") public String getPlaceholder() { if (isNamedParameter()) {