Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,20 @@
import com.google.api.generator.engine.ast.Expr;
import com.google.api.generator.engine.ast.ExprStatement;
import com.google.api.generator.engine.ast.ForStatement;
import com.google.api.generator.engine.ast.GeneralForStatement;
import com.google.api.generator.engine.ast.IfStatement;
import com.google.api.generator.engine.ast.MethodDefinition;
import com.google.api.generator.engine.ast.MethodInvocationExpr;
import com.google.api.generator.engine.ast.NewObjectExpr;
import com.google.api.generator.engine.ast.PrimitiveValue;
import com.google.api.generator.engine.ast.Reference;
import com.google.api.generator.engine.ast.ScopeNode;
import com.google.api.generator.engine.ast.Statement;
import com.google.api.generator.engine.ast.TypeNode;
import com.google.api.generator.engine.ast.ValueExpr;
import com.google.api.generator.engine.ast.Variable;
import com.google.api.generator.engine.ast.VariableExpr;
import com.google.api.generator.gapic.model.Field;
import com.google.api.generator.gapic.model.GapicBatchingSettings;
import com.google.api.generator.gapic.model.Message;
import com.google.api.generator.gapic.model.Method;
Expand All @@ -59,6 +64,7 @@ public class BatchingDescriptorComposer {
private static final TypeNode PARTITION_KEY_TYPE = toType(PartitionKey.class);

private static final String ADD_ALL_METHOD_PATTERN = "addAll%s";
private static final String BATCH_FOO_INDEX_PATTERN = "batch%sIndex";
private static final String GET_LIST_METHOD_PATTERN = "get%sList";
private static final String GET_COUNT_METHOD_PATTERN = "get%sCount";

Expand All @@ -67,6 +73,7 @@ public static Expr createBatchingDescriptorFieldDeclExpr(
List<MethodDefinition> javaMethods = new ArrayList<>();
javaMethods.add(createGetBatchPartitionKeyMethod(method, batchingSettings, messageTypes));
javaMethods.add(createGetRequestBuilderMethod(method, batchingSettings));
javaMethods.add(createSplitResponseMethod(method, batchingSettings, messageTypes));

javaMethods.add(createSplitExceptionMethod(method));
javaMethods.add(createCountElementsMethod(method, batchingSettings));
Expand Down Expand Up @@ -236,6 +243,200 @@ private static MethodDefinition createGetRequestBuilderMethod(
.build();
}

private static MethodDefinition createSplitResponseMethod(
Method method, GapicBatchingSettings batchingSettings, Map<String, Message> messageTypes) {
VariableExpr batchResponseVarExpr =
VariableExpr.withVariable(
Variable.builder().setType(method.outputType()).setName("batchResponse").build());

TypeNode batchedRequestIssuerType = toType(BATCHED_REQUEST_ISSUER_REF, method.outputType());
TypeNode batchVarType =
TypeNode.withReference(
ConcreteReference.builder()
.setClazz(Collection.class)
.setGenerics(
Arrays.asList(
ConcreteReference.wildcardWithUpperBound(
batchedRequestIssuerType.reference())))
.build());
VariableExpr batchVarExpr =
VariableExpr.withVariable(
Variable.builder().setType(batchVarType).setName("batch").build());

VariableExpr responderVarExpr =
VariableExpr.withVariable(
Variable.builder().setType(batchedRequestIssuerType).setName("responder").build());

String upperCamelBatchedFieldName =
JavaStyle.toUpperCamelCase(batchingSettings.batchedFieldName());
VariableExpr batchMessageIndexVarExpr =
VariableExpr.withVariable(
Variable.builder().setType(TypeNode.INT).setName("batchMessageIndex").build());

VariableExpr subresponseElementsVarExpr = null;
boolean hasSubresponseField = batchingSettings.subresponseFieldName() != null;

List<Statement> outerForBody = new ArrayList<>();
if (hasSubresponseField) {
Message outputMessage = messageTypes.get(method.outputType().reference().name());
Preconditions.checkNotNull(
outputMessage, String.format("Output message not found for RPC %s", method.name()));

Field subresponseElementField =
outputMessage.fieldMap().get(batchingSettings.subresponseFieldName());
Preconditions.checkNotNull(
subresponseElementField,
String.format(
"Subresponse field %s not found in message %s",
batchingSettings.subresponseFieldName(), outputMessage.name()));
TypeNode subresponseElementType = subresponseElementField.type();
subresponseElementsVarExpr =
VariableExpr.withVariable(
Variable.builder()
.setType(subresponseElementType)
.setName("subresponseElements")
.build());

VariableExpr subresponseCountVarExpr =
VariableExpr.withVariable(
Variable.builder().setType(TypeNode.LONG).setName("subresponseCount").build());

outerForBody.add(
ExprStatement.withExpr(
AssignmentExpr.builder()
.setVariableExpr(subresponseElementsVarExpr.toBuilder().setIsDecl(true).build())
.setValueExpr(
NewObjectExpr.builder()
.setType(
TypeNode.withReference(ConcreteReference.withClazz(ArrayList.class)))
.setIsGeneric(true)
.build())
.build()));

String getFooCountMethodName = "getMessageCount";
outerForBody.add(
ExprStatement.withExpr(
AssignmentExpr.builder()
.setVariableExpr(subresponseCountVarExpr.toBuilder().setIsDecl(true).build())
.setValueExpr(
MethodInvocationExpr.builder()
.setExprReferenceExpr(responderVarExpr)
.setMethodName(getFooCountMethodName)
.setReturnType(subresponseCountVarExpr.type())
.build())
.build()));

List<Expr> innerSubresponseForExprs = new ArrayList<>();
String getSubresponseFieldMethodName =
String.format(
"get%s", JavaStyle.toUpperCamelCase(batchingSettings.subresponseFieldName()));
Expr addMethodArgExpr =
MethodInvocationExpr.builder()
.setExprReferenceExpr(batchResponseVarExpr)
.setMethodName(getSubresponseFieldMethodName)
.setArguments(batchMessageIndexVarExpr)
.build();
innerSubresponseForExprs.add(
MethodInvocationExpr.builder()
.setExprReferenceExpr(subresponseElementsVarExpr)
.setMethodName("add")
.setArguments(addMethodArgExpr)
.build());
// TODO(miraleung): Increment batchMessageIndexVarExpr.

VariableExpr forIndexVarExpr =
VariableExpr.withVariable(Variable.builder().setType(TypeNode.INT).setName("i").build());
GeneralForStatement innerSubresponseForStatement =
GeneralForStatement.incrementWith(
forIndexVarExpr,
subresponseCountVarExpr,
innerSubresponseForExprs.stream()
.map(e -> ExprStatement.withExpr(e))
.collect(Collectors.toList()));

outerForBody.add(innerSubresponseForStatement);
}

TypeNode responseType = method.outputType();
Expr responseBuilderExpr =
MethodInvocationExpr.builder()
.setStaticReferenceType(responseType)
.setMethodName("newBuilder")
.build();
if (hasSubresponseField) {
Preconditions.checkNotNull(
subresponseElementsVarExpr,
String.format(
"subresponseElements variable should not be null for method %s", method.name()));

responseBuilderExpr =
MethodInvocationExpr.builder()
.setExprReferenceExpr(responseBuilderExpr)
.setMethodName(
String.format(
"addAll%s",
JavaStyle.toUpperCamelCase(batchingSettings.subresponseFieldName())))
.setArguments(subresponseElementsVarExpr)
.build();
}
responseBuilderExpr =
MethodInvocationExpr.builder()
.setExprReferenceExpr(responseBuilderExpr)
.setMethodName("build")
.setReturnType(responseType)
.build();

VariableExpr responseVarExpr =
VariableExpr.withVariable(
Variable.builder().setType(responseType).setName("response").build());
outerForBody.add(
ExprStatement.withExpr(
AssignmentExpr.builder()
.setVariableExpr(responseVarExpr.toBuilder().setIsDecl(true).build())
.setValueExpr(responseBuilderExpr)
.build()));

outerForBody.add(
ExprStatement.withExpr(
MethodInvocationExpr.builder()
.setExprReferenceExpr(responderVarExpr)
.setMethodName("setResponse")
.setArguments(responseVarExpr)
.build()));

ForStatement outerForStatement =
ForStatement.builder()
.setLocalVariableExpr(responderVarExpr.toBuilder().setIsDecl(true).build())
.setCollectionExpr(batchVarExpr)
.setBody(outerForBody)
.build();

List<Statement> bodyStatements = new ArrayList<>();
if (hasSubresponseField) {
bodyStatements.add(
ExprStatement.withExpr(
AssignmentExpr.builder()
.setVariableExpr(batchMessageIndexVarExpr.toBuilder().setIsDecl(true).build())
.setValueExpr(
ValueExpr.withValue(
PrimitiveValue.builder().setType(TypeNode.INT).setValue("0").build()))
.build()));
}
bodyStatements.add(outerForStatement);

return MethodDefinition.builder()
.setIsOverride(true)
.setScope(ScopeNode.PUBLIC)
.setReturnType(TypeNode.VOID)
.setName("splitResponse")
.setArguments(
Arrays.asList(batchResponseVarExpr, batchVarExpr).stream()
.map(v -> v.toBuilder().setIsDecl(true).build())
.collect(Collectors.toList()))
.setBody(bodyStatements)
.build();
}

private static MethodDefinition createSplitExceptionMethod(Method method) {
VariableExpr throwableVarExpr =
VariableExpr.withVariable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ public void batchingDescriptor_hasSubresponseField() {
"};\n",
"}\n",
"@Override\n",
"public void splitResponse(",
"PublishResponse batchResponse, ",
"Collection<? extends BatchedRequestIssuer<PublishResponse>> batch) {\n",
"int batchMessageIndex = 0;\n",
"for (BatchedRequestIssuer<PublishResponse> responder : batch) {\n",
"List<String> subresponseElements = new ArrayList<>();\n",
"long subresponseCount = responder.getMessageCount();\n",
"for (int i = 0; i < subresponseCount; i++) {\n",
"subresponseElements.add(batchResponse.getMessageIds(batchMessageIndex));\n",
"}\n",
"PublishResponse response = ",
"PublishResponse.newBuilder().addAllMessageIds(subresponseElements).build();\n",
"responder.setResponse(response);\n",
"}\n",
"}\n",
"@Override\n",
"public void splitException(",
"Throwable throwable, ",
"Collection<? extends BatchedRequestIssuer<PublishResponse>> batch) {\n",
Expand Down Expand Up @@ -226,6 +242,14 @@ public void batchingDescriptor_noSubresponseField() {
"};\n",
"}\n",
"@Override\n",
"public void splitResponse(WriteLogEntriesResponse batchResponse, ",
"Collection<? extends BatchedRequestIssuer<WriteLogEntriesResponse>> batch) {\n",
"for (BatchedRequestIssuer<WriteLogEntriesResponse> responder : batch) {\n",
"WriteLogEntriesResponse response = WriteLogEntriesResponse.newBuilder().build();\n",
"responder.setResponse(response);\n",
"}\n",
"}\n",
"@Override\n",
"public void splitException(",
"Throwable throwable, ",
"Collection<? extends BatchedRequestIssuer<WriteLogEntriesResponse>> batch) {\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,19 @@ private static List<Service> parseServices(
+ " }\n"
+ "\n"
+ " @Override\n"
+ " public void splitResponse(\n"
+ " WriteLogEntriesResponse batchResponse,\n"
+ " Collection<? extends BatchedRequestIssuer<WriteLogEntriesResponse>>"
+ " batch) {\n"
+ " for (BatchedRequestIssuer<WriteLogEntriesResponse> responder : batch)"
+ " {\n"
+ " WriteLogEntriesResponse response ="
+ " WriteLogEntriesResponse.newBuilder().build();\n"
+ " responder.setResponse(response);\n"
+ " }\n"
+ " }\n"
+ "\n"
+ " @Override\n"
+ " public void splitException(\n"
+ " Throwable throwable,\n"
+ " Collection<? extends BatchedRequestIssuer<WriteLogEntriesResponse>>"
Expand Down Expand Up @@ -1397,6 +1410,7 @@ private static List<Service> parseServices(
+ "import com.google.pubsub.v1.Topic;\n"
+ "import com.google.pubsub.v1.UpdateTopicRequest;\n"
+ "import java.io.IOException;\n"
+ "import java.util.ArrayList;\n"
+ "import java.util.Collection;\n"
+ "import java.util.List;\n"
+ "import java.util.Objects;\n"
Expand Down Expand Up @@ -1662,6 +1676,25 @@ private static List<Service> parseServices(
+ " }\n"
+ "\n"
+ " @Override\n"
+ " public void splitResponse(\n"
+ " PublishResponse batchResponse,\n"
+ " Collection<? extends BatchedRequestIssuer<PublishResponse>> batch) {\n"
+ " int batchMessageIndex = 0;\n"
+ " for (BatchedRequestIssuer<PublishResponse> responder : batch) {\n"
+ " List<String> subresponseElements = new ArrayList<>();\n"
+ " long subresponseCount = responder.getMessageCount();\n"
+ " for (int i = 0; i < subresponseCount; i++) {\n"
+ " "
+ " subresponseElements.add(batchResponse.getMessageIds(batchMessageIndex));\n"
+ " }\n"
+ " PublishResponse response =\n"
+ " "
+ " PublishResponse.newBuilder().addAllMessageIds(subresponseElements).build();\n"
+ " responder.setResponse(response);\n"
+ " }\n"
+ " }\n"
+ "\n"
+ " @Override\n"
+ " public void splitException(\n"
+ " Throwable throwable,\n"
+ " Collection<? extends BatchedRequestIssuer<PublishResponse>> batch) {\n"
Expand Down