diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index 3dbc20299d..5799a7c3fb 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -227,8 +227,8 @@ public DBObject toDbObject(String inputCollectionName, AggregationOperationConte operationDocuments.add(operation.toDBObject(context)); - if (operation instanceof AggregationOperationContext) { - context = (AggregationOperationContext) operation; + if (operation instanceof FieldsExposingAggregationOperation) { + context = new WrappingExposedFieldsAggregationOperationContext((FieldsExposingAggregationOperation) operation); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java index 848c5cccd4..4626835744 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFields.java @@ -49,6 +49,22 @@ public static ExposedFields from(ExposedField... fields) { return from(Arrays.asList(fields)); } + /** + * Creates a new {@link ExposedFields} instance from the given {@link ExposedFields}. + * + * @param fields must not be {@literal null}. + * @return + */ + public static ExposedFields from(ExposedFields fields) { + + List exposedFields = new ArrayList(); + for (ExposedField field : fields) { + exposedFields.add(field); + } + + return from(exposedFields); + } + /** * Creates a new {@link ExposedFields} instance from the given {@link ExposedField}s. * @@ -134,6 +150,24 @@ public ExposedFields and(ExposedField field) { return new ExposedFields(field.synthetic ? originalFields : result, field.synthetic ? result : syntheticFields); } + /** + * Creates a new {@link ExposedFields} adding the given {@link ExposedFields}. + * + * @param field must not be {@literal null}. + * @return + */ + public ExposedFields and(ExposedFields fields) { + + Assert.notNull(fields, "Exposed fields must not be null!"); + + ExposedFields result = from(this); + for (ExposedField field : fields) { + result = result.and(field); + } + + return result; + } + /** * Returns the field with the given name or {@literal null} if no field with the given name is available. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index 8b6c517a2e..c58eb79ef2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -15,7 +15,6 @@ */ package org.springframework.data.mongodb.core.aggregation; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; import com.mongodb.DBObject; @@ -46,22 +45,4 @@ public DBObject getMappedObject(DBObject dbObject) { public FieldReference getReference(Field field) { return getReference(field.getTarget()); } - - /* - * (non-Javadoc) - * @see org.springframework.data.mongodb.core.aggregation.AggregationOperationContext#getReference(java.lang.String) - */ - @Override - public FieldReference getReference(String name) { - - ExposedField field = getFields().getField(name); - - if (field != null) { - return new FieldReference(field); - } - - throw new IllegalArgumentException(String.format("Invalid reference '%s'!", name)); - } - - protected abstract ExposedFields getFields(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldsExposingAggregationOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldsExposingAggregationOperation.java new file mode 100644 index 0000000000..eeef300e3d --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldsExposingAggregationOperation.java @@ -0,0 +1,27 @@ +/* + * Copyright 2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.aggregation; + +/** + * {@link AggregationOperation} that exposes new {@link ExposedFields} that can be used for later aggregation pipeline + * {@code AggregationOperation}s. + * + * @author Thomas Darimont + */ +public interface FieldsExposingAggregationOperation extends AggregationOperation { + + ExposedFields getFields(); +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java index e3d8f1f3a6..9d16ee1bb1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GroupOperation.java @@ -38,7 +38,7 @@ * @author Oliver Gierke * @since 1.3 */ -public class GroupOperation extends ExposedFieldsAggregationOperationContext implements AggregationOperation { +public class GroupOperation implements FieldsExposingAggregationOperation { private final ExposedFields nonSynthecticFields; private final List operations; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java index c34ba57b60..b924e1af7d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ProjectionOperation.java @@ -41,7 +41,7 @@ * @author Oliver Gierke * @since 1.3 */ -public class ProjectionOperation extends ExposedFieldsAggregationOperationContext implements AggregationOperation { +public class ProjectionOperation implements FieldsExposingAggregationOperation { private static final List NONE = Collections.emptyList(); @@ -152,7 +152,7 @@ public ProjectionOperation andInclude(Fields fields) { * @see org.springframework.data.mongodb.core.aggregation.ExposedFieldsAggregationOperationContext#getFields() */ @Override - protected ExposedFields getFields() { + public ExposedFields getFields() { ExposedFields fields = null; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java index 16b0a6c0ae..31a306f674 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java @@ -36,7 +36,7 @@ * @author Oliver Gierke * @since 1.3 */ -public class TypeBasedAggregationOperationContext implements AggregationOperationContext { +public class TypeBasedAggregationOperationContext extends ExposedFieldsAggregationOperationContext { private final Class type; private final MappingContext, MongoPersistentProperty> mappingContext; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/UnwindOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/UnwindOperation.java index 110bbd190c..5410f79f32 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/UnwindOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/UnwindOperation.java @@ -29,7 +29,7 @@ * @author Oliver Gierke * @since 1.3 */ -public class UnwindOperation extends ExposedFieldsAggregationOperationContext implements AggregationOperation { +public class UnwindOperation implements AggregationOperation { private final ExposedField field; @@ -44,15 +44,6 @@ public UnwindOperation(Field field) { this.field = new ExposedField(field, true); } - /* - * (non-Javadoc) - * @see org.springframework.data.mongodb.core.aggregation.ExposedFieldsAggregationOperationContext#getFields() - */ - @Override - protected ExposedFields getFields() { - return ExposedFields.from(field); - } - /* * (non-Javadoc) * @see org.springframework.data.mongodb.core.aggregation.AggregationOperation#toDBObject(org.springframework.data.mongodb.core.aggregation.AggregationOperationContext) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/WrappingExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/WrappingExposedFieldsAggregationOperationContext.java new file mode 100644 index 0000000000..bafdd6ccfa --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/WrappingExposedFieldsAggregationOperationContext.java @@ -0,0 +1,49 @@ +package org.springframework.data.mongodb.core.aggregation; + +import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; +import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; + +/** + * {@link AggregationOperationContext} that combines the available field references from a given + * {@code AggregationOperationContext} and an {@link FieldsExposingAggregationOperation}. + * + * @author Thomas Darimont + * @since 1.4 + */ +class WrappingExposedFieldsAggregationOperationContext extends ExposedFieldsAggregationOperationContext { + + private final FieldsExposingAggregationOperation fieldExposingOperation; + + /** + * Creates a new {@link WrappingExposedFieldsAggregationOperationContext} from the given + * {@link FieldsExposingAggregationOperation}. + * + * @param fieldExposingOperation + */ + public WrappingExposedFieldsAggregationOperationContext(FieldsExposingAggregationOperation fieldExposingOperation) { + this.fieldExposingOperation = fieldExposingOperation; + } + + /* (non-Javadoc) + * @see org.springframework.data.mongodb.core.aggregation.ExposedFieldsAggregationOperationContext#getFields() + */ + private ExposedFields getFields() { + return fieldExposingOperation.getFields(); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.mongodb.core.aggregation.AggregationOperationContext#getReference(java.lang.String) + */ + @Override + public FieldReference getReference(String name) { + + ExposedField field = getFields().getField(name); + + if (field != null) { + return new FieldReference(field); + } + + throw new IllegalArgumentException(String.format("Invalid reference '%s'!", name)); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java index e017bfb2d2..5255ce6709 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java @@ -84,6 +84,7 @@ private void cleanDb() { mongoTemplate.dropCollection(INPUT_COLLECTION); mongoTemplate.dropCollection(Product.class); mongoTemplate.dropCollection(UserWithLikes.class); + mongoTemplate.dropCollection(DATAMONGO753.class); } /** @@ -452,7 +453,38 @@ public void arithmenticOperatorsInProjectionExample() { assertThat((Double) resultList.get(0).get("netPriceMul2"), is(netPrice * 2)); assertThat((Double) resultList.get(0).get("netPriceDiv119"), is(netPrice / 1.19)); assertThat((Integer) resultList.get(0).get("spaceUnitsMod2"), is(spaceUnits % 2)); + } + /** + * @see DATAMONGO-753 + * @see http + * ://stackoverflow.com/questions/18653574/spring-data-mongodb-aggregation-framework-invalid-reference-in-group + * -operati + */ + @Test + public void allowNestedFieldReferencesAsGroupIdsInGroupExpressions() { + + mongoTemplate.insert(new DATAMONGO753().withPDs(new PD("A", 1), new PD("B", 1), new PD("C", 1))); + mongoTemplate.insert(new DATAMONGO753().withPDs(new PD("B", 1), new PD("B", 1), new PD("C", 1))); + + Aggregation agg = newAggregation( // + unwind("pd"), // + group("pd.pDch") // the nested field expression + .sum("pd.up").as("uplift") // + , project("_id", "uplift")); + + AggregationResults result = mongoTemplate.aggregate(agg, // + DATAMONGO753.class // + , DBObject.class); + List stats = result.getMappedResults(); + + assertThat(stats.size(), is(3)); + assertThat(stats.get(0).get("_id").toString(), is("C")); + assertThat((Integer) stats.get(0).get("uplift"), is(2)); + assertThat(stats.get(1).get("_id").toString(), is("B")); + assertThat((Integer) stats.get(1).get("uplift"), is(3)); + assertThat(stats.get(2).get("_id").toString(), is("A")); + assertThat((Integer) stats.get(2).get("uplift"), is(1)); } private void assertLikeStats(LikeStats like, String id, long count) { @@ -502,4 +534,22 @@ private static void assertTagCount(String tag, int n, TagCount tagCount) { assertThat(tagCount.getN(), is(n)); } + static class DATAMONGO753 { + PD[] pd; + + DATAMONGO753 withPDs(PD... pds) { + this.pd = pds; + return this; + } + } + + static class PD { + String pDch; + int up; + + public PD(String pDch, int up) { + this.pDch = pDch; + this.up = up; + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java index 6ae0ce90ae..0753cfd426 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationUnitTests.java @@ -15,32 +15,75 @@ */ package org.springframework.data.mongodb.core.aggregation; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; +import static org.springframework.data.mongodb.core.query.Criteria.*; + import org.junit.Test; /** * Unit tests for {@link Aggregation}. * * @author Oliver Gierke + * @author Thomas Darimont */ public class AggregationUnitTests { @Test(expected = IllegalArgumentException.class) public void rejectsNullAggregationOperation() { - Aggregation.newAggregation((AggregationOperation[]) null); + newAggregation((AggregationOperation[]) null); } @Test(expected = IllegalArgumentException.class) public void rejectsNullTypedAggregationOperation() { - Aggregation.newAggregation(String.class, (AggregationOperation[]) null); + newAggregation(String.class, (AggregationOperation[]) null); } @Test(expected = IllegalArgumentException.class) public void rejectsNoAggregationOperation() { - Aggregation.newAggregation(new AggregationOperation[0]); + newAggregation(new AggregationOperation[0]); } @Test(expected = IllegalArgumentException.class) public void rejectsNoTypedAggregationOperation() { - Aggregation.newAggregation(String.class, new AggregationOperation[0]); + newAggregation(String.class, new AggregationOperation[0]); + } + + /** + * @see DATAMONGO-753 + */ + @Test(expected = IllegalArgumentException.class) + public void checkForCorrectFieldScopeTransfer() { + + newAggregation( // + project("a", "b"), // + group("a").count().as("cnt"), // a was introduced to the context by the project operation + project("cnt", "b") // b was removed from the context by the group operation + ).toDbObject("foo", Aggregation.DEFAULT_CONTEXT); // -> triggers IllegalArgumentException + } + + /** + * @see DATAMONGO-753 + */ + @Test + public void unwindOperationShouldNotChangeAvailableFields() { + + newAggregation( // + project("a", "b"), // + unwind("a"), // + project("a", "b") // b should still be available + ).toDbObject("foo", Aggregation.DEFAULT_CONTEXT); + } + + /** + * @see DATAMONGO-753 + */ + @Test + public void matchOperationShouldNotChangeAvailableFields() { + + newAggregation( // + project("a", "b"), // + match(where("a").gte(1)), // + project("a", "b") // b should still be available + ).toDbObject("foo", Aggregation.DEFAULT_CONTEXT); } }