Skip to content

Refine Repository Composition retrieval during AOT. #3282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: 4.0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-commons</artifactId>
<version>4.0.0-SNAPSHOT</version>
<version>4.0.x-GH-3279-SNAPSHOT</version>

<name>Spring Data Core</name>
<description>Core Spring concepts underpinning every Spring Data module.</description>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ public boolean hasName() {
* @since 3.5
* @see org.springframework.core.ParameterNameDiscoverer
*/
@SuppressWarnings("NullAway")
public String getRequiredName() {

if (!hasName()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,23 @@ public AotBundle build() {
this.customizer.customize(repositoryInformation, generationMetadata, builder);
JavaFile javaFile = JavaFile.builder(packageName(), builder.build()).build();

// TODO: module identifier
AotRepositoryMetadata metadata = new AotRepositoryMetadata(repositoryInformation.getRepositoryInterface().getName(),
"", repositoryType, methodMetadata);
repositoryInformation.moduleName() != null ? repositoryInformation.moduleName() : "", repositoryType, methodMetadata);

return new AotBundle(javaFile, metadata.toJson());
}

private void contributeMethod(Method method, RepositoryComposition repositoryComposition,
List<AotRepositoryMethod> methodMetadata, TypeSpec.Builder builder) {

if (repositoryInformation.isCustomMethod(method) || repositoryInformation.isBaseClassMethod(method)) {
if (repositoryInformation.isCustomMethod(method) || (repositoryInformation.isBaseClassMethod(method) && !repositoryInformation.isQueryMethod(method))) {

RepositoryFragment<?> fragment = repositoryComposition.findFragment(method);

if (fragment != null) {
methodMetadata.add(getFragmentMetadata(method, fragment));
return;
}
return;
}

if (method.isBridge() || method.isDefault() || java.lang.reflect.Modifier.isStatic(method.getModifiers())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class MethodContributor<M extends QueryMethod> {
private final M queryMethod;
private final QueryMetadata metadata;

private MethodContributor(M queryMethod, QueryMetadata metadata) {
MethodContributor(M queryMethod, QueryMetadata metadata) {
this.queryMethod = queryMethod;
this.metadata = metadata;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public CodeBlock decorate() {
// bring in properties as usual
builder.add(inheritedProperties.get());

builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragments\", new $T() {\n",
builder.add("beanDefinition.getPropertyValues().addPropertyValue(\"repositoryFragmentsFunction\", new $T() {\n",
RepositoryFactoryBeanSupport.RepositoryFragmentsFunction.class);
builder.indent();
builder.add("public $T getRepositoryFragments($T beanFactory, $T context) {\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import java.util.Set;
import java.util.function.Supplier;

import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.RepositoryInformationSupport;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryComposition;
import org.springframework.data.repository.core.support.RepositoryComposition.RepositoryFragments;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Lazy;

Expand All @@ -36,16 +38,31 @@
*/
class AotRepositoryInformation extends RepositoryInformationSupport implements RepositoryInformation {

private final @Nullable String moduleName;
private final Supplier<Collection<RepositoryFragment<?>>> fragments;
private Lazy<RepositoryComposition> baseComposition = Lazy.of(() -> {
return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass()));
});

AotRepositoryInformation(Supplier<RepositoryMetadata> repositoryMetadata, Supplier<Class<?>> repositoryBaseClass,
Supplier<Collection<RepositoryFragment<?>>> fragments) {
private final Lazy<RepositoryComposition> repositoryComposition;
private final Lazy<RepositoryComposition> baseComposition;

AotRepositoryInformation(@Nullable String moduleName, Supplier<RepositoryMetadata> repositoryMetadata,
Supplier<Class<?>> repositoryBaseClass, Supplier<Collection<RepositoryFragment<?>>> fragments) {

super(repositoryMetadata, repositoryBaseClass);

this.moduleName = moduleName;
this.fragments = fragments;

this.repositoryComposition = Lazy
.of(() -> RepositoryComposition.fromMetadata(getMetadata()).append(RepositoryFragments.from(getFragments())));

this.baseComposition = Lazy.of(() -> {

RepositoryComposition targetRepoComposition = repositoryComposition.get();

return RepositoryComposition.of(RepositoryFragment.structural(getRepositoryBaseClass())) //
.withArgumentConverter(targetRepoComposition.getArgumentConverter()) //
.withMethodLookup(targetRepoComposition.getMethodLookup());
});
}

/**
Expand All @@ -57,10 +74,9 @@ public Set<RepositoryFragment<?>> getFragments() {
return new LinkedHashSet<>(fragments.get());
}

// Not required during AOT processing.
@Override
public boolean isCustomMethod(Method method) {
return false;
return repositoryComposition.get().findMethod(method).isPresent();
}

@Override
Expand All @@ -75,7 +91,11 @@ public Method getTargetClassMethod(Method method) {

@Override
public RepositoryComposition getRepositoryComposition() {
return baseComposition.get().append(RepositoryComposition.RepositoryFragments.from(fragments.get()));
return repositoryComposition.get();
}

@Override
public @Nullable String moduleName() {
return moduleName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@
package org.springframework.data.repository.config;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.ConstructorArgumentValues.ValueHolder;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.ResolvableType;
import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.PagingAndSortingRepository;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Lazy;
import org.springframework.data.repository.core.support.RepositoryFragment.ImplementedRepositoryFragment;
import org.springframework.util.ClassUtils;

/**
Expand All @@ -38,49 +42,108 @@
*/
class RepositoryBeanDefinitionReader {

static RepositoryInformation readRepositoryInformation(RepositoryConfiguration<?> metadata,
ConfigurableListableBeanFactory beanFactory) {

return new AotRepositoryInformation(metadataSupplier(metadata, beanFactory),
repositoryBaseClass(metadata, beanFactory), fragments(metadata, beanFactory));
/**
* @return
*/
static RepositoryInformation repositoryInformation(RepositoryConfiguration<?> repoConfig, RegisteredBean repoBean) {
return repositoryInformation(repoConfig, repoBean.getMergedBeanDefinition(), repoBean.getBeanFactory());
}

private static Supplier<Collection<RepositoryFragment<?>>> fragments(RepositoryConfiguration<?> metadata,
/**
* @param source the RepositoryFactoryBeanSupport bean definition.
* @param beanFactory
* @return
*/
@SuppressWarnings("NullAway")
static RepositoryInformation repositoryInformation(RepositoryConfiguration<?> repoConfig, BeanDefinition source,
ConfigurableListableBeanFactory beanFactory) {

if (metadata instanceof RepositoryFragmentConfigurationProvider provider) {

return Lazy.of(() -> {
return provider.getFragmentConfiguration().stream().flatMap(it -> {

List<RepositoryFragment<?>> fragments = new ArrayList<>(1);
RepositoryMetadata metadata = AbstractRepositoryMetadata
.getMetadata(forName(repoConfig.getRepositoryInterface(), beanFactory));
Class<?> repositoryBaseClass = readRepositoryBaseClass(source, beanFactory);
List<RepositoryFragment<?>> fragmentList = readRepositoryFragments(source, beanFactory);
if (source.getPropertyValues().contains("customImplementation")) {

Object o = source.getPropertyValues().get("customImplementation");
if (o instanceof RuntimeBeanReference rbr) {
BeanDefinition customImplBeanDefintion = beanFactory.getBeanDefinition(rbr.getBeanName());
Class<?> beanType = forName(customImplBeanDefintion.getBeanClassName(), beanFactory);
ResolvableType[] interfaces = ResolvableType.forClass(beanType).getInterfaces();
if (interfaces.length == 1) {
fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType));
} else {
boolean found = false;
for (ResolvableType i : interfaces) {
if (beanType.getSimpleName().contains(i.resolve().getSimpleName())) {
fragmentList.add(new ImplementedRepositoryFragment(interfaces[0].toClass(), beanType));
found = true;
break;
}
}
if (!found) {
fragmentList.add(RepositoryFragment.implemented(beanType));
}
}
}
}

fragments.add(RepositoryFragment.implemented(forName(it.getClassName(), beanFactory)));
String moduleName = (String) source.getPropertyValues().get("moduleName");
AotRepositoryInformation repositoryInformation = new AotRepositoryInformation(moduleName, () -> metadata,
() -> repositoryBaseClass, () -> fragmentList);
return repositoryInformation;
}

if (it.getInterfaceName() != null) {
fragments.add(RepositoryFragment.structural(forName(it.getInterfaceName(), beanFactory)));
}
@SuppressWarnings("NullAway")
private static Class<?> readRepositoryBaseClass(BeanDefinition source, ConfigurableListableBeanFactory beanFactory) {

return fragments.stream();
}).collect(Collectors.toList());
});
Object repoBaseClassName = source.getPropertyValues().get("repositoryBaseClass");
if (repoBaseClassName != null) {
return forName(repoBaseClassName.toString(), beanFactory);
}

return Lazy.of(Collections::emptyList);
if (source.getPropertyValues().contains("moduleBaseClass")) {
return forName((String) source.getPropertyValues().get("moduleBaseClass"), beanFactory);
}
return Dummy.class;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private static Supplier<Class<?>> repositoryBaseClass(RepositoryConfiguration metadata,
@SuppressWarnings("NullAway")
private static List<RepositoryFragment<?>> readRepositoryFragments(BeanDefinition source,
ConfigurableListableBeanFactory beanFactory) {

return Lazy.of(() -> (Class<?>) metadata.getRepositoryBaseClassName().map(it -> forName(it.toString(), beanFactory))
.orElse(Object.class));
RuntimeBeanReference beanReference = (RuntimeBeanReference) source.getPropertyValues().get("repositoryFragments");
BeanDefinition fragments = beanFactory.getBeanDefinition(beanReference.getBeanName());

ValueHolder fragmentBeanNameList = fragments.getConstructorArgumentValues().getArgumentValue(0, List.class);
List<String> fragmentBeanNames = (List<String>) fragmentBeanNameList.getValue();

List<RepositoryFragment<?>> fragmentList = new ArrayList<>();
for (String beanName : fragmentBeanNames) {

BeanDefinition fragmentBeanDefinition = beanFactory.getBeanDefinition(beanName);
ValueHolder argumentValue = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(0,
String.class);
ValueHolder argumentValue1 = fragmentBeanDefinition.getConstructorArgumentValues().getArgumentValue(1, null, null,
null);
Object fragmentClassName = argumentValue.getValue();

try {
Class<?> type = ClassUtils.forName(fragmentClassName.toString(), beanFactory.getBeanClassLoader());

if (argumentValue1 != null && argumentValue1.getValue() instanceof RuntimeBeanReference rbf) {
BeanDefinition implBeanDef = beanFactory.getBeanDefinition(rbf.getBeanName());
Class implClass = ClassUtils.forName(implBeanDef.getBeanClassName(), beanFactory.getBeanClassLoader());
fragmentList.add(new RepositoryFragment.ImplementedRepositoryFragment(type, implClass));
} else {
fragmentList.add(RepositoryFragment.structural(type));
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
return fragmentList;
}

private static Supplier<org.springframework.data.repository.core.RepositoryMetadata> metadataSupplier(
RepositoryConfiguration<?> metadata, ConfigurableListableBeanFactory beanFactory) {
return Lazy.of(() -> new DefaultRepositoryMetadata(forName(metadata.getRepositoryInterface(), beanFactory)));
}
static abstract class Dummy implements CrudRepository<Object, Object>, PagingAndSortingRepository<Object, Object> {}

static Class<?> forName(String name, ConfigurableListableBeanFactory beanFactory) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.util.function.Predicate;

import org.jspecify.annotations.Nullable;

import org.springframework.aop.SpringProxy;
import org.springframework.aop.framework.Advised;
import org.springframework.aot.generate.GenerationContext;
Expand All @@ -49,7 +48,6 @@
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.aot.generate.RepositoryContributor;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.core.support.RepositoryFragment;
import org.springframework.data.util.Predicates;
import org.springframework.data.util.QTypeContributor;
Expand Down Expand Up @@ -90,8 +88,7 @@ public class RepositoryRegistrationAotContribution implements BeanRegistrationAo
* @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}.
* @see RepositoryRegistrationAotProcessor
*/
protected RepositoryRegistrationAotContribution(
RepositoryRegistrationAotProcessor processor) {
protected RepositoryRegistrationAotContribution(RepositoryRegistrationAotProcessor processor) {

Assert.notNull(processor, "RepositoryRegistrationAotProcessor must not be null");

Expand All @@ -108,8 +105,7 @@ protected RepositoryRegistrationAotContribution(
* @throws IllegalArgumentException if the {@link RepositoryRegistrationAotProcessor} is {@literal null}.
* @see RepositoryRegistrationAotProcessor
*/
public static RepositoryRegistrationAotContribution fromProcessor(
RepositoryRegistrationAotProcessor processor) {
public static RepositoryRegistrationAotContribution fromProcessor(RepositoryRegistrationAotProcessor processor) {
return new RepositoryRegistrationAotContribution(processor);
}

Expand Down Expand Up @@ -255,7 +251,8 @@ private void contributeRepositoryInfo(AotRepositoryContext repositoryContext, Ge
});

implementation.ifPresent(impl -> {
contribution.getRuntimeHints().reflection().registerType(impl.getClass(), hint -> {
Class<?> typeToRegister = impl instanceof Class c ? c : impl.getClass();
contribution.getRuntimeHints().reflection().registerType(typeToRegister, hint -> {

hint.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS);

Expand Down Expand Up @@ -365,18 +362,16 @@ public Predicate<Class<?>> typeFilter() { // like only document ones. // TODO: A

@SuppressWarnings("rawtypes")
private DefaultAotRepositoryContext buildAotRepositoryContext(RegisteredBean bean,
RepositoryConfiguration<?> repositoryMetadata) {
RepositoryConfiguration<?> repositoryConfiguration) {

DefaultAotRepositoryContext repositoryContext = new DefaultAotRepositoryContext(
AotContext.from(getBeanFactory(), getRepositoryRegistrationAotProcessor().getEnvironment()));

RepositoryFactoryBeanSupport rfbs = bean.getBeanFactory().getBean("&" + bean.getBeanName(),
RepositoryFactoryBeanSupport.class);

repositoryContext.setBeanName(bean.getBeanName());
repositoryContext.setBasePackages(repositoryMetadata.getBasePackages().toSet());
repositoryContext.setBasePackages(repositoryConfiguration.getBasePackages().toSet());
repositoryContext.setIdentifyingAnnotations(resolveIdentifyingAnnotations());
repositoryContext.setRepositoryInformation(rfbs.getRepositoryInformation());
repositoryContext
.setRepositoryInformation(RepositoryBeanDefinitionReader.repositoryInformation(repositoryConfiguration, bean));

return repositoryContext;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.lang.reflect.Method;
import java.util.List;

import org.jspecify.annotations.Nullable;
import org.springframework.data.repository.core.support.RepositoryComposition;

/**
Expand Down Expand Up @@ -105,4 +106,8 @@ default boolean hasQueryMethods() {
*/
RepositoryComposition getRepositoryComposition();

default @Nullable String moduleName() {
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ protected boolean isQueryMethodCandidate(Method method) {
return true;
}

private RepositoryMetadata getMetadata() {
protected RepositoryMetadata getMetadata() {
return metadata.get();
}

Expand Down
Loading