Skip to content

Commit f1b7952

Browse files
committed
Use merged bean definitions for EntityCallback type lookup.
We now use the merged bean definition to resolve the defined EntityCallback type. Previously, we used just the bean definition that might have contained no type hints because of ASM-parsed configuration classes. Closes #2853
1 parent 2cbf0fb commit f1b7952

File tree

3 files changed

+34
-42
lines changed

3 files changed

+34
-42
lines changed

src/main/java/org/springframework/data/mapping/callback/DefaultEntityCallbacks.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.commons.logging.Log;
2323
import org.apache.commons.logging.LogFactory;
2424
import org.springframework.beans.factory.BeanFactory;
25+
import org.springframework.context.support.GenericApplicationContext;
2526
import org.springframework.core.ResolvableType;
2627
import org.springframework.util.Assert;
2728
import org.springframework.util.ClassUtils;
@@ -57,7 +58,8 @@ class DefaultEntityCallbacks implements EntityCallbacks {
5758
* @param beanFactory must not be {@literal null}.
5859
*/
5960
DefaultEntityCallbacks(BeanFactory beanFactory) {
60-
this.callbackDiscoverer = new EntityCallbackDiscoverer(beanFactory);
61+
this.callbackDiscoverer = new EntityCallbackDiscoverer(
62+
beanFactory instanceof GenericApplicationContext ac ? ac.getBeanFactory() : beanFactory);
6163
}
6264

6365
@Override
@@ -93,8 +95,7 @@ public void addEntityCallback(EntityCallback<?> callback) {
9395
this.callbackDiscoverer.addEntityCallback(callback);
9496
}
9597

96-
static class SimpleEntityCallbackInvoker
97-
implements org.springframework.data.mapping.callback.EntityCallbackInvoker {
98+
static class SimpleEntityCallbackInvoker implements org.springframework.data.mapping.callback.EntityCallbackInvoker {
9899

99100
@Override
100101
public <T> T invokeCallback(EntityCallback<T> callback, T entity,

src/main/java/org/springframework/data/mapping/callback/EntityCallbackDiscoverer.java

+30-38
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.lang.reflect.Modifier;
2020
import java.util.ArrayList;
2121
import java.util.Collection;
22+
import java.util.Comparator;
2223
import java.util.LinkedHashSet;
2324
import java.util.List;
2425
import java.util.Map;
@@ -28,10 +29,9 @@
2829

2930
import org.springframework.aop.framework.AopProxyUtils;
3031
import org.springframework.beans.factory.BeanFactory;
31-
import org.springframework.beans.factory.ListableBeanFactory;
3232
import org.springframework.beans.factory.config.BeanDefinition;
3333
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
34-
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
34+
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
3535
import org.springframework.core.ResolvableType;
3636
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
3737
import org.springframework.lang.Nullable;
@@ -56,7 +56,6 @@ class EntityCallbackDiscoverer {
5656
private final Map<Class<?>, ResolvableType> entityTypeCache = new ConcurrentReferenceHashMap<>(64);
5757

5858
@Nullable private ClassLoader beanClassLoader;
59-
@Nullable private BeanFactory beanFactory;
6059

6160
private Object retrievalMutex = this.defaultRetriever;
6261

@@ -104,12 +103,13 @@ void removeEntityCallback(EntityCallback<?> callback) {
104103
* Return a {@link Collection} of all {@link EntityCallback}s matching the given entity type. Non-matching callbacks
105104
* get excluded early.
106105
*
107-
* @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on
108-
* cached matching information.
106+
* @param entity the entity to be called back for. Allows for excluding non-matching callbacks early, based on cached
107+
* matching information.
109108
* @param callbackType the source callback type.
110109
* @return a {@link Collection} of {@link EntityCallback}s.
111110
* @see EntityCallback
112111
*/
112+
@SuppressWarnings({ "unchecked", "rawtypes" })
113113
<T extends S, S> Collection<EntityCallback<S>> getEntityCallbacks(Class<T> entity, ResolvableType callbackType) {
114114

115115
Class<?> sourceType = entity;
@@ -121,7 +121,7 @@ <T extends S, S> Collection<EntityCallback<S>> getEntityCallbacks(Class<T> entit
121121
return (Collection) retriever.getEntityCallbacks();
122122
}
123123

124-
if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity.getClass(), this.beanClassLoader)
124+
if (this.beanClassLoader == null || ClassUtils.isCacheSafe(entity, this.beanClassLoader)
125125
&& (sourceType == null || ClassUtils.isCacheSafe(sourceType, this.beanClassLoader))) {
126126

127127
// Fully synchronized building and caching of a CallbackRetriever
@@ -163,8 +163,8 @@ ResolvableType resolveDeclaredEntityType(Class<?> callbackType) {
163163
* @param retriever the {@link CallbackRetriever}, if supposed to populate one (for caching purposes)
164164
* @return the pre-filtered list of entity callbacks for the given entity and callback type.
165165
*/
166-
private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType entityType,
167-
ResolvableType callbackType, @Nullable CallbackRetriever retriever) {
166+
private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType entityType, ResolvableType callbackType,
167+
@Nullable CallbackRetriever retriever) {
168168

169169
List<EntityCallback<?>> allCallbacks = new ArrayList<>();
170170
Set<EntityCallback<?>> callbacks;
@@ -198,16 +198,14 @@ private Collection<EntityCallback<?>> retrieveEntityCallbacks(ResolvableType ent
198198
}
199199

200200
/**
201-
* Set the {@link BeanFactory} and optionally {@link #setBeanClassLoader(ClassLoader) class loader} if not set.
202-
* Pre-loads {@link EntityCallback} beans by scanning the {@link BeanFactory}.
201+
* Set the {@link BeanFactory} and optionally class loader if not set. Pre-loads {@link EntityCallback} beans by
202+
* scanning the {@link BeanFactory}.
203203
*
204204
* @param beanFactory must not be {@literal null}.
205205
* @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory(BeanFactory)
206206
*/
207207
public void setBeanFactory(BeanFactory beanFactory) {
208208

209-
this.beanFactory = beanFactory;
210-
211209
if (beanFactory instanceof ConfigurableBeanFactory cbf) {
212210

213211
if (this.beanClassLoader == null) {
@@ -228,10 +226,8 @@ static Method lookupCallbackMethod(Class<?> callbackType, Class<?> entityType, O
228226

229227
ReflectionUtils.doWithMethods(callbackType, methods::add, method -> {
230228

231-
if (!Modifier.isPublic(method.getModifiers())
232-
|| method.getParameterCount() != args.length + 1
233-
|| method.isBridge()
234-
|| ReflectionUtils.isObjectMethod(method)) {
229+
if (!Modifier.isPublic(method.getModifiers()) || method.getParameterCount() != args.length + 1
230+
|| method.isBridge() || ReflectionUtils.isObjectMethod(method)) {
235231
return false;
236232
}
237233

@@ -242,9 +238,8 @@ static Method lookupCallbackMethod(Class<?> callbackType, Class<?> entityType, O
242238
return methods.iterator().next();
243239
}
244240

245-
throw new IllegalStateException(
246-
"%s does not define a callback method accepting %s and %s additional arguments".formatted(
247-
ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length));
241+
throw new IllegalStateException("%s does not define a callback method accepting %s and %s additional arguments"
242+
.formatted(ClassUtils.getShortName(callbackType), ClassUtils.getShortName(entityType), args.length));
248243
}
249244

250245
static <T> BiFunction<EntityCallback<T>, T, Object> computeCallbackInvokerFunction(EntityCallback<T> callback,
@@ -267,10 +262,10 @@ static <T> BiFunction<EntityCallback<T>, T, Object> computeCallbackInvokerFuncti
267262
* Filter a callback early through checking its generically declared entity type before trying to instantiate it.
268263
* <p>
269264
* If this method returns {@literal true} for a given callback as a first pass, the callback instance will get
270-
* retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)}
271-
* call afterwards.
265+
* retrieved and fully evaluated through a {@link #supportsEvent(EntityCallback, ResolvableType, ResolvableType)} call
266+
* afterwards.
272267
*
273-
* @param callback the callback's type as determined by the BeanFactory.
268+
* @param callbackType the callback's type as determined by the BeanFactory.
274269
* @param entityType the entity type to check.
275270
* @return whether the given callback should be included in the candidates for the given callback type.
276271
*/
@@ -286,11 +281,9 @@ static boolean supportsEvent(ResolvableType callbackType, ResolvableType entityT
286281
* @param callbackType the source type to check against.
287282
* @return whether the given callback should be included in the candidates for the given callback type.
288283
*/
289-
static boolean supportsEvent(EntityCallback<?> callback, ResolvableType entityType,
290-
ResolvableType callbackType) {
284+
static boolean supportsEvent(EntityCallback<?> callback, ResolvableType entityType, ResolvableType callbackType) {
291285

292-
return callback instanceof EntityCallbackAdapter<?> provider
293-
? provider.supports(callbackType, entityType)
286+
return callback instanceof EntityCallbackAdapter<?> provider ? provider.supports(callbackType, entityType)
294287
: callbackType.isInstance(callback) && supportsEvent(ResolvableType.forInstance(callback), entityType);
295288
}
296289

@@ -310,13 +303,11 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
310303

311304
// We need both a ListableBeanFactory and BeanDefinitionRegistry here for advanced inspection.
312305
// If we don't get that, use simple inspection.
313-
if (!(beanFactory instanceof ListableBeanFactory && beanFactory instanceof BeanDefinitionRegistry)) {
306+
if (!(beanFactory instanceof ConfigurableListableBeanFactory bf)) {
314307
beanFactory.getBeanProvider(EntityCallback.class).stream().forEach(entityCallbacks::add);
315308
return;
316309
}
317310

318-
var bf = (ListableBeanFactory & BeanDefinitionRegistry) beanFactory;
319-
320311
for (var beanName : bf.getBeanNamesForType(EntityCallback.class)) {
321312

322313
EntityCallback<?> bean = EntityCallback.class.cast(bf.getBean(beanName));
@@ -328,7 +319,7 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
328319
entityCallbacks.add(bean);
329320
} else {
330321

331-
BeanDefinition definition = bf.getBeanDefinition(beanName);
322+
BeanDefinition definition = bf.getMergedBeanDefinition(beanName);
332323
entityCallbacks.add(new EntityCallbackAdapter<>(bean, definition.getResolvableType()));
333324
}
334325
}
@@ -340,8 +331,8 @@ void discoverEntityCallbacks(BeanFactory beanFactory) {
340331
*
341332
* @author Oliver Drotbohm
342333
*/
343-
private static record EntityCallbackAdapter<T>(EntityCallback<T> delegate, ResolvableType type)
344-
implements EntityCallback<T> {
334+
private record EntityCallbackAdapter<T> (EntityCallback<T> delegate,
335+
ResolvableType type) implements EntityCallback<T> {
345336

346337
boolean supports(ResolvableType callbackType, ResolvableType entityType) {
347338
return callbackType.isInstance(delegate) && supportsEvent(type, entityType);
@@ -351,15 +342,16 @@ boolean supports(ResolvableType callbackType, ResolvableType entityType) {
351342
/**
352343
* Cache key for {@link EntityCallback}, based on event type and source type.
353344
*/
354-
private static record CallbackCacheKey(ResolvableType callbackType, @Nullable Class<?> entityType)
355-
implements Comparable<CallbackCacheKey> {
345+
private record CallbackCacheKey(ResolvableType callbackType,
346+
@Nullable Class<?> entityType) implements Comparable<CallbackCacheKey> {
347+
348+
private static final Comparator<CallbackCacheKey> COMPARATOR = Comparators.<CallbackCacheKey> nullsHigh() //
349+
.thenComparing(it -> it.callbackType.toString()) //
350+
.thenComparing(it -> it.entityType.getName());
356351

357352
@Override
358353
public int compareTo(CallbackCacheKey other) {
359-
360-
return Comparators.<CallbackCacheKey> nullsHigh()
361-
.thenComparing(it -> callbackType.toString())
362-
.thenComparing(it -> entityType.getName()).compare(this, other);
354+
return COMPARATOR.compare(this, other);
363355
}
364356
}
365357

src/test/java/org/springframework/data/mapping/callback/DefaultEntityCallbacksUnitTests.java

-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ void skipsInvocationUsingJava18ReflectiveTypeRejection() {
170170
void detectsMultipleCallbacksWithinOneClass() {
171171

172172
var ctx = new AnnotationConfigApplicationContext(MultipleCallbacksInOneClassConfig.class);
173-
174173
var callbacks = new DefaultEntityCallbacks(ctx);
175174

176175
var personDocument = new PersonDocument(null, "Walter", null);

0 commit comments

Comments
 (0)