Skip to content

Commit 26054fd

Browse files
committed
AOT contribution for @PersistenceContext and @PersistenceUnit
Closes gh-28364
1 parent 10d2549 commit 26054fd

File tree

8 files changed

+491
-22
lines changed

8 files changed

+491
-22
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.beans.factory.generator;
18+
19+
import java.lang.reflect.Field;
20+
import java.lang.reflect.Modifier;
21+
22+
import org.springframework.aot.generator.ProtectedAccess.Options;
23+
import org.springframework.javapoet.CodeBlock;
24+
import org.springframework.javapoet.support.MultiStatement;
25+
import org.springframework.util.ReflectionUtils;
26+
27+
/**
28+
* Support for generating {@link Field} access.
29+
*
30+
* @author Stephane Nicoll
31+
* @since 6.0
32+
*/
33+
public class BeanFieldGenerator {
34+
35+
/**
36+
* The {@link Options} to use to access a field.
37+
*/
38+
public static final Options FIELD_OPTIONS = Options.defaults()
39+
.useReflection(member -> Modifier.isPrivate(member.getModifiers())).build();
40+
41+
42+
/**
43+
* Generate the necessary code to set the specified field. Use reflection
44+
* using {@link ReflectionUtils} if necessary.
45+
* @param field the field to set
46+
* @param value a code representation of the field value
47+
* @return the code to set the specified field
48+
*/
49+
public MultiStatement generateSetValue(String target, Field field, CodeBlock value) {
50+
MultiStatement statement = new MultiStatement();
51+
boolean useReflection = Modifier.isPrivate(field.getModifiers());
52+
if (useReflection) {
53+
String fieldName = String.format("%sField", field.getName());
54+
statement.addStatement("$T $L = $T.findField($T.class, $S)", Field.class, fieldName, ReflectionUtils.class,
55+
field.getDeclaringClass(), field.getName());
56+
statement.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, fieldName);
57+
statement.addStatement("$T.setField($L, $L, $L)", ReflectionUtils.class, fieldName, target, value);
58+
}
59+
else {
60+
statement.addStatement("$L.$L = $L", target, field.getName(), value);
61+
}
62+
return statement;
63+
}
64+
65+
}

spring-beans/src/main/java/org/springframework/beans/factory/generator/InjectionGenerator.java

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import org.springframework.javapoet.CodeBlock;
3535
import org.springframework.javapoet.CodeBlock.Builder;
3636
import org.springframework.util.ClassUtils;
37-
import org.springframework.util.ReflectionUtils;
3837

3938
/**
4039
* Generate the necessary code to {@link #generateInstantiation(Executable)
@@ -53,14 +52,13 @@
5352
*/
5453
public class InjectionGenerator {
5554

56-
private static final Options FIELD_INJECTION_OPTIONS = Options.defaults()
57-
.useReflection(member -> Modifier.isPrivate(member.getModifiers())).build();
58-
5955
private static final Options METHOD_INJECTION_OPTIONS = Options.defaults()
6056
.useReflection(member -> false).build();
6157

6258
private final BeanParameterGenerator parameterGenerator = new BeanParameterGenerator();
6359

60+
private final BeanFieldGenerator fieldGenerator = new BeanFieldGenerator();
61+
6462

6563
/**
6664
* Generate the necessary code to instantiate an object using the specified
@@ -110,7 +108,7 @@ public Options getProtectedAccessInjectionOptions(Member member) {
110108
return METHOD_INJECTION_OPTIONS;
111109
}
112110
if (member instanceof Field) {
113-
return FIELD_INJECTION_OPTIONS;
111+
return BeanFieldGenerator.FIELD_OPTIONS;
114112
}
115113
throw new IllegalArgumentException("Could not handle member " + member);
116114
}
@@ -230,24 +228,13 @@ CodeBlock generateFieldInjection(Field injectionPoint, boolean required) {
230228
code.add("instanceContext.field($S", injectionPoint.getName());
231229
code.add(")\n").indent().indent();
232230
if (required) {
233-
code.add(".invoke(beanFactory, (attributes) ->");
234-
}
235-
else {
236-
code.add(".resolve(beanFactory, false).ifResolved((attributes) ->");
237-
}
238-
boolean hasAssignment = Modifier.isPrivate(injectionPoint.getModifiers());
239-
if (hasAssignment) {
240-
code.beginControlFlow("");
241-
String fieldName = String.format("%sField", injectionPoint.getName());
242-
code.addStatement("$T $L = $T.findField($T.class, $S)", Field.class, fieldName, ReflectionUtils.class,
243-
injectionPoint.getDeclaringClass(), injectionPoint.getName());
244-
code.addStatement("$T.makeAccessible($L)", ReflectionUtils.class, fieldName);
245-
code.addStatement("$T.setField($L, bean, attributes.get(0))", ReflectionUtils.class, fieldName);
246-
code.unindent().add("}");
231+
code.add(".invoke(beanFactory, ");
247232
}
248233
else {
249-
code.add(" bean.$L = attributes.get(0)", injectionPoint.getName());
234+
code.add(".resolve(beanFactory, false).ifResolved(");
250235
}
236+
code.add(this.fieldGenerator.generateSetValue("bean", injectionPoint,
237+
CodeBlock.of("attributes.get(0)")).toLambdaBody("(attributes) ->"));
251238
code.add(")").unindent().unindent();
252239
return code.build();
253240
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.beans.factory.generator;
18+
19+
import java.lang.reflect.Field;
20+
21+
import org.junit.jupiter.api.Test;
22+
23+
import org.springframework.javapoet.CodeBlock;
24+
import org.springframework.javapoet.support.CodeSnippet;
25+
import org.springframework.javapoet.support.MultiStatement;
26+
import org.springframework.util.ReflectionUtils;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
30+
/**
31+
* Tests for {@link BeanFieldGenerator}.
32+
*
33+
* @author Stephane Nicoll
34+
*/
35+
class BeanFieldGeneratorTests {
36+
37+
private final BeanFieldGenerator generator = new BeanFieldGenerator();
38+
39+
@Test
40+
void generateSetFieldWithPublicField() {
41+
MultiStatement statement = this.generator.generateSetValue("bean",
42+
field(SampleBean.class, "one"), CodeBlock.of("$S", "test"));
43+
assertThat(CodeSnippet.process(statement.toCodeBlock())).isEqualTo("""
44+
bean.one = "test";
45+
""");
46+
}
47+
48+
@Test
49+
void generateSetFieldWithPrivateField() {
50+
MultiStatement statement = this.generator.generateSetValue("example",
51+
field(SampleBean.class, "two"), CodeBlock.of("42"));
52+
CodeSnippet code = CodeSnippet.of(statement.toCodeBlock());
53+
assertThat(code.getSnippet()).isEqualTo("""
54+
Field twoField = ReflectionUtils.findField(BeanFieldGeneratorTests.SampleBean.class, "two");
55+
ReflectionUtils.makeAccessible(twoField);
56+
ReflectionUtils.setField(twoField, example, 42);
57+
""");
58+
assertThat(code.hasImport(ReflectionUtils.class)).isTrue();
59+
assertThat(code.hasImport(BeanFieldGeneratorTests.class)).isTrue();
60+
}
61+
62+
63+
private Field field(Class<?> type, String name) {
64+
Field field = ReflectionUtils.findField(type, name);
65+
assertThat(field).isNotNull();
66+
return field;
67+
}
68+
69+
70+
public static class SampleBean {
71+
72+
public String one;
73+
74+
private int two;
75+
76+
}
77+
78+
}

spring-core/src/main/java/org/springframework/javapoet/support/MultiStatement.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ public boolean isEmpty() {
4444
return this.statements.isEmpty();
4545
}
4646

47+
/**
48+
* Add the statements defined in the specified multi statement to this instance.
49+
* @param multiStatement the statements to add
50+
* @return {@code this}, to facilitate method chaining
51+
*/
52+
public MultiStatement add(MultiStatement multiStatement) {
53+
this.statements.addAll(multiStatement.statements);
54+
return this;
55+
}
56+
4757
/**
4858
* Add the specified {@link CodeBlock codeblock} rendered as-is.
4959
* @param codeBlock the code block to add

spring-core/src/test/java/org/springframework/javapoet/support/MultiStatementTests.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,17 @@ void multiStatementsWithAddAllAndLambda() {
150150
"}");
151151
}
152152

153+
@Test
154+
void addWithAnotherMultiStatement() {
155+
MultiStatement statements = new MultiStatement();
156+
statements.addStatement(CodeBlock.of("test.invoke()"));
157+
MultiStatement another = new MultiStatement();
158+
another.addStatement(CodeBlock.of("test.another()"));
159+
statements.add(another);
160+
assertThat(statements.toCodeBlock().toString()).isEqualTo("""
161+
test.invoke();
162+
test.another();
163+
""");
164+
}
165+
153166
}

spring-orm/spring-orm.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies {
1111
optional("org.eclipse.persistence:org.eclipse.persistence.jpa")
1212
optional("org.hibernate:hibernate-core-jakarta")
1313
optional("jakarta.servlet:jakarta.servlet-api")
14+
testImplementation(project(":spring-core-test"))
1415
testImplementation(testFixtures(project(":spring-beans")))
1516
testImplementation(testFixtures(project(":spring-context")))
1617
testImplementation(testFixtures(project(":spring-core")))

0 commit comments

Comments
 (0)