diff --git a/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java b/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java index bcdb809df3..05b40f77f6 100644 --- a/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java +++ b/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java @@ -24,15 +24,20 @@ import com.google.api.generator.engine.ast.Expr; import com.google.api.generator.engine.ast.ExprStatement; import com.google.api.generator.engine.ast.ForStatement; +import com.google.api.generator.engine.ast.GeneralForStatement; import com.google.api.generator.engine.ast.IfStatement; import com.google.api.generator.engine.ast.MethodDefinition; import com.google.api.generator.engine.ast.MethodInvocationExpr; import com.google.api.generator.engine.ast.NewObjectExpr; +import com.google.api.generator.engine.ast.PrimitiveValue; import com.google.api.generator.engine.ast.Reference; import com.google.api.generator.engine.ast.ScopeNode; +import com.google.api.generator.engine.ast.Statement; import com.google.api.generator.engine.ast.TypeNode; +import com.google.api.generator.engine.ast.ValueExpr; import com.google.api.generator.engine.ast.Variable; import com.google.api.generator.engine.ast.VariableExpr; +import com.google.api.generator.gapic.model.Field; import com.google.api.generator.gapic.model.GapicBatchingSettings; import com.google.api.generator.gapic.model.Message; import com.google.api.generator.gapic.model.Method; @@ -59,6 +64,7 @@ public class BatchingDescriptorComposer { private static final TypeNode PARTITION_KEY_TYPE = toType(PartitionKey.class); private static final String ADD_ALL_METHOD_PATTERN = "addAll%s"; + private static final String BATCH_FOO_INDEX_PATTERN = "batch%sIndex"; private static final String GET_LIST_METHOD_PATTERN = "get%sList"; private static final String GET_COUNT_METHOD_PATTERN = "get%sCount"; @@ -67,6 +73,7 @@ public static Expr createBatchingDescriptorFieldDeclExpr( List javaMethods = new ArrayList<>(); javaMethods.add(createGetBatchPartitionKeyMethod(method, batchingSettings, messageTypes)); javaMethods.add(createGetRequestBuilderMethod(method, batchingSettings)); + javaMethods.add(createSplitResponseMethod(method, batchingSettings, messageTypes)); javaMethods.add(createSplitExceptionMethod(method)); javaMethods.add(createCountElementsMethod(method, batchingSettings)); @@ -236,6 +243,200 @@ private static MethodDefinition createGetRequestBuilderMethod( .build(); } + private static MethodDefinition createSplitResponseMethod( + Method method, GapicBatchingSettings batchingSettings, Map messageTypes) { + VariableExpr batchResponseVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(method.outputType()).setName("batchResponse").build()); + + TypeNode batchedRequestIssuerType = toType(BATCHED_REQUEST_ISSUER_REF, method.outputType()); + TypeNode batchVarType = + TypeNode.withReference( + ConcreteReference.builder() + .setClazz(Collection.class) + .setGenerics( + Arrays.asList( + ConcreteReference.wildcardWithUpperBound( + batchedRequestIssuerType.reference()))) + .build()); + VariableExpr batchVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(batchVarType).setName("batch").build()); + + VariableExpr responderVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(batchedRequestIssuerType).setName("responder").build()); + + String upperCamelBatchedFieldName = + JavaStyle.toUpperCamelCase(batchingSettings.batchedFieldName()); + VariableExpr batchMessageIndexVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(TypeNode.INT).setName("batchMessageIndex").build()); + + VariableExpr subresponseElementsVarExpr = null; + boolean hasSubresponseField = batchingSettings.subresponseFieldName() != null; + + List outerForBody = new ArrayList<>(); + if (hasSubresponseField) { + Message outputMessage = messageTypes.get(method.outputType().reference().name()); + Preconditions.checkNotNull( + outputMessage, String.format("Output message not found for RPC %s", method.name())); + + Field subresponseElementField = + outputMessage.fieldMap().get(batchingSettings.subresponseFieldName()); + Preconditions.checkNotNull( + subresponseElementField, + String.format( + "Subresponse field %s not found in message %s", + batchingSettings.subresponseFieldName(), outputMessage.name())); + TypeNode subresponseElementType = subresponseElementField.type(); + subresponseElementsVarExpr = + VariableExpr.withVariable( + Variable.builder() + .setType(subresponseElementType) + .setName("subresponseElements") + .build()); + + VariableExpr subresponseCountVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(TypeNode.LONG).setName("subresponseCount").build()); + + outerForBody.add( + ExprStatement.withExpr( + AssignmentExpr.builder() + .setVariableExpr(subresponseElementsVarExpr.toBuilder().setIsDecl(true).build()) + .setValueExpr( + NewObjectExpr.builder() + .setType( + TypeNode.withReference(ConcreteReference.withClazz(ArrayList.class))) + .setIsGeneric(true) + .build()) + .build())); + + String getFooCountMethodName = "getMessageCount"; + outerForBody.add( + ExprStatement.withExpr( + AssignmentExpr.builder() + .setVariableExpr(subresponseCountVarExpr.toBuilder().setIsDecl(true).build()) + .setValueExpr( + MethodInvocationExpr.builder() + .setExprReferenceExpr(responderVarExpr) + .setMethodName(getFooCountMethodName) + .setReturnType(subresponseCountVarExpr.type()) + .build()) + .build())); + + List innerSubresponseForExprs = new ArrayList<>(); + String getSubresponseFieldMethodName = + String.format( + "get%s", JavaStyle.toUpperCamelCase(batchingSettings.subresponseFieldName())); + Expr addMethodArgExpr = + MethodInvocationExpr.builder() + .setExprReferenceExpr(batchResponseVarExpr) + .setMethodName(getSubresponseFieldMethodName) + .setArguments(batchMessageIndexVarExpr) + .build(); + innerSubresponseForExprs.add( + MethodInvocationExpr.builder() + .setExprReferenceExpr(subresponseElementsVarExpr) + .setMethodName("add") + .setArguments(addMethodArgExpr) + .build()); + // TODO(miraleung): Increment batchMessageIndexVarExpr. + + VariableExpr forIndexVarExpr = + VariableExpr.withVariable(Variable.builder().setType(TypeNode.INT).setName("i").build()); + GeneralForStatement innerSubresponseForStatement = + GeneralForStatement.incrementWith( + forIndexVarExpr, + subresponseCountVarExpr, + innerSubresponseForExprs.stream() + .map(e -> ExprStatement.withExpr(e)) + .collect(Collectors.toList())); + + outerForBody.add(innerSubresponseForStatement); + } + + TypeNode responseType = method.outputType(); + Expr responseBuilderExpr = + MethodInvocationExpr.builder() + .setStaticReferenceType(responseType) + .setMethodName("newBuilder") + .build(); + if (hasSubresponseField) { + Preconditions.checkNotNull( + subresponseElementsVarExpr, + String.format( + "subresponseElements variable should not be null for method %s", method.name())); + + responseBuilderExpr = + MethodInvocationExpr.builder() + .setExprReferenceExpr(responseBuilderExpr) + .setMethodName( + String.format( + "addAll%s", + JavaStyle.toUpperCamelCase(batchingSettings.subresponseFieldName()))) + .setArguments(subresponseElementsVarExpr) + .build(); + } + responseBuilderExpr = + MethodInvocationExpr.builder() + .setExprReferenceExpr(responseBuilderExpr) + .setMethodName("build") + .setReturnType(responseType) + .build(); + + VariableExpr responseVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(responseType).setName("response").build()); + outerForBody.add( + ExprStatement.withExpr( + AssignmentExpr.builder() + .setVariableExpr(responseVarExpr.toBuilder().setIsDecl(true).build()) + .setValueExpr(responseBuilderExpr) + .build())); + + outerForBody.add( + ExprStatement.withExpr( + MethodInvocationExpr.builder() + .setExprReferenceExpr(responderVarExpr) + .setMethodName("setResponse") + .setArguments(responseVarExpr) + .build())); + + ForStatement outerForStatement = + ForStatement.builder() + .setLocalVariableExpr(responderVarExpr.toBuilder().setIsDecl(true).build()) + .setCollectionExpr(batchVarExpr) + .setBody(outerForBody) + .build(); + + List bodyStatements = new ArrayList<>(); + if (hasSubresponseField) { + bodyStatements.add( + ExprStatement.withExpr( + AssignmentExpr.builder() + .setVariableExpr(batchMessageIndexVarExpr.toBuilder().setIsDecl(true).build()) + .setValueExpr( + ValueExpr.withValue( + PrimitiveValue.builder().setType(TypeNode.INT).setValue("0").build())) + .build())); + } + bodyStatements.add(outerForStatement); + + return MethodDefinition.builder() + .setIsOverride(true) + .setScope(ScopeNode.PUBLIC) + .setReturnType(TypeNode.VOID) + .setName("splitResponse") + .setArguments( + Arrays.asList(batchResponseVarExpr, batchVarExpr).stream() + .map(v -> v.toBuilder().setIsDecl(true).build()) + .collect(Collectors.toList())) + .setBody(bodyStatements) + .build(); + } + private static MethodDefinition createSplitExceptionMethod(Method method) { VariableExpr throwableVarExpr = VariableExpr.withVariable( diff --git a/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java b/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java index e412e0a832..844f8b3cbd 100644 --- a/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java +++ b/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java @@ -129,6 +129,22 @@ public void batchingDescriptor_hasSubresponseField() { "};\n", "}\n", "@Override\n", + "public void splitResponse(", + "PublishResponse batchResponse, ", + "Collection> batch) {\n", + "int batchMessageIndex = 0;\n", + "for (BatchedRequestIssuer responder : batch) {\n", + "List subresponseElements = new ArrayList<>();\n", + "long subresponseCount = responder.getMessageCount();\n", + "for (int i = 0; i < subresponseCount; i++) {\n", + "subresponseElements.add(batchResponse.getMessageIds(batchMessageIndex));\n", + "}\n", + "PublishResponse response = ", + "PublishResponse.newBuilder().addAllMessageIds(subresponseElements).build();\n", + "responder.setResponse(response);\n", + "}\n", + "}\n", + "@Override\n", "public void splitException(", "Throwable throwable, ", "Collection> batch) {\n", @@ -226,6 +242,14 @@ public void batchingDescriptor_noSubresponseField() { "};\n", "}\n", "@Override\n", + "public void splitResponse(WriteLogEntriesResponse batchResponse, ", + "Collection> batch) {\n", + "for (BatchedRequestIssuer responder : batch) {\n", + "WriteLogEntriesResponse response = WriteLogEntriesResponse.newBuilder().build();\n", + "responder.setResponse(response);\n", + "}\n", + "}\n", + "@Override\n", "public void splitException(", "Throwable throwable, ", "Collection> batch) {\n", diff --git a/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java b/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java index 10456faed3..52b025ead1 100644 --- a/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java +++ b/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java @@ -999,6 +999,19 @@ private static List parseServices( + " }\n" + "\n" + " @Override\n" + + " public void splitResponse(\n" + + " WriteLogEntriesResponse batchResponse,\n" + + " Collection>" + + " batch) {\n" + + " for (BatchedRequestIssuer responder : batch)" + + " {\n" + + " WriteLogEntriesResponse response =" + + " WriteLogEntriesResponse.newBuilder().build();\n" + + " responder.setResponse(response);\n" + + " }\n" + + " }\n" + + "\n" + + " @Override\n" + " public void splitException(\n" + " Throwable throwable,\n" + " Collection>" @@ -1397,6 +1410,7 @@ private static List parseServices( + "import com.google.pubsub.v1.Topic;\n" + "import com.google.pubsub.v1.UpdateTopicRequest;\n" + "import java.io.IOException;\n" + + "import java.util.ArrayList;\n" + "import java.util.Collection;\n" + "import java.util.List;\n" + "import java.util.Objects;\n" @@ -1662,6 +1676,25 @@ private static List parseServices( + " }\n" + "\n" + " @Override\n" + + " public void splitResponse(\n" + + " PublishResponse batchResponse,\n" + + " Collection> batch) {\n" + + " int batchMessageIndex = 0;\n" + + " for (BatchedRequestIssuer responder : batch) {\n" + + " List subresponseElements = new ArrayList<>();\n" + + " long subresponseCount = responder.getMessageCount();\n" + + " for (int i = 0; i < subresponseCount; i++) {\n" + + " " + + " subresponseElements.add(batchResponse.getMessageIds(batchMessageIndex));\n" + + " }\n" + + " PublishResponse response =\n" + + " " + + " PublishResponse.newBuilder().addAllMessageIds(subresponseElements).build();\n" + + " responder.setResponse(response);\n" + + " }\n" + + " }\n" + + "\n" + + " @Override\n" + " public void splitException(\n" + " Throwable throwable,\n" + " Collection> batch) {\n"