diff --git a/driver-core/src/main/com/mongodb/client/model/expressions/Expressions.java b/driver-core/src/main/com/mongodb/client/model/expressions/Expressions.java index 6a2f4442219..a11cef3dfa4 100644 --- a/driver-core/src/main/com/mongodb/client/model/expressions/Expressions.java +++ b/driver-core/src/main/com/mongodb/client/model/expressions/Expressions.java @@ -20,6 +20,7 @@ import org.bson.BsonArray; import org.bson.BsonBoolean; import org.bson.BsonDateTime; +import org.bson.BsonDecimal128; import org.bson.BsonDocument; import org.bson.BsonDouble; import org.bson.BsonInt32; @@ -28,6 +29,7 @@ import org.bson.BsonString; import org.bson.BsonValue; import org.bson.conversions.Bson; +import org.bson.types.Decimal128; import java.time.Instant; import java.util.ArrayList; @@ -35,6 +37,7 @@ import java.util.List; import java.util.stream.Collectors; +import static com.mongodb.client.model.expressions.MqlExpression.AstPlaceholder; /** * Convenience methods related to {@link Expression}. @@ -52,7 +55,7 @@ private Expressions() {} */ public static BooleanExpression of(final boolean of) { // we intentionally disallow ofBoolean(null) - return new MqlExpression<>((codecRegistry) -> new BsonBoolean(of)); + return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonBoolean(of))); } /** @@ -63,16 +66,21 @@ public static BooleanExpression of(final boolean of) { * @return the integer expression */ public static IntegerExpression of(final int of) { - return new MqlExpression<>((codecRegistry) -> new BsonInt32(of)); + return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonInt32(of))); } public static IntegerExpression of(final long of) { - return new MqlExpression<>((codecRegistry) -> new BsonInt64(of)); + return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonInt64(of))); } public static NumberExpression of(final double of) { - return new MqlExpression<>((codecRegistry) -> new BsonDouble(of)); + return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonDouble(of))); + } + public static NumberExpression of(final Decimal128 of) { + Assertions.notNull("Decimal128", of); + return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonDecimal128(of))); } public static DateExpression of(final Instant of) { - return new MqlExpression<>((codecRegistry) -> new BsonDateTime(of.toEpochMilli())); + Assertions.notNull("Instant", of); + return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonDateTime(of.toEpochMilli()))); } /** @@ -84,7 +92,7 @@ public static DateExpression of(final Instant of) { */ public static StringExpression of(final String of) { Assertions.notNull("String", of); - return new MqlExpression<>((codecRegistry) -> new BsonString(of)); + return new MqlExpression<>((codecRegistry) -> new AstPlaceholder(new BsonString(of))); } /** @@ -99,7 +107,7 @@ public static ArrayExpression ofBooleanArray(final boolean... for (boolean b : array) { result.add(new BsonBoolean(b)); } - return new MqlExpression<>((cr) -> new BsonArray(result)); + return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonArray(result))); } @@ -107,7 +115,7 @@ public static ArrayExpression ofIntegerArray(final int... ofI List array = Arrays.stream(ofIntegerArray) .mapToObj(BsonInt32::new) .collect(Collectors.toList()); - return new MqlExpression<>((cr) -> new BsonArray(array)); + return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonArray(array))); } public static DocumentExpression ofDocument(final Bson document) { @@ -115,11 +123,26 @@ public static DocumentExpression ofDocument(final Bson document) { // All documents are wrapped in a $literal. If we don't wrap, we need to // check for empty documents and documents that are actually expressions // (and need to be wrapped in $literal anyway). This would be brittle. - return new MqlExpression<>((cr) -> new BsonDocument("$literal", - document.toBsonDocument(BsonDocument.class, cr))); + return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonDocument("$literal", + document.toBsonDocument(BsonDocument.class, cr)))); } public static R ofNull() { - return new MqlExpression<>((cr) -> new BsonNull()).assertImplementsAllExpressions(); + return new MqlExpression<>((cr) -> new AstPlaceholder(new BsonNull())) + .assertImplementsAllExpressions(); + } + + static NumberExpression numberToExpression(final Number number) { + if (number instanceof Integer) { + return of((int) number); + } else if (number instanceof Long) { + return of((long) number); + } else if (number instanceof Double) { + return of((double) number); + } else if (number instanceof Decimal128) { + return of((Decimal128) number); + } else { + throw new IllegalArgumentException("Number must be one of: Integer, Long, Double, Decimal128"); + } } } diff --git a/driver-core/src/main/com/mongodb/client/model/expressions/IntegerExpression.java b/driver-core/src/main/com/mongodb/client/model/expressions/IntegerExpression.java index 6b53b9d62b8..eab541c0474 100644 --- a/driver-core/src/main/com/mongodb/client/model/expressions/IntegerExpression.java +++ b/driver-core/src/main/com/mongodb/client/model/expressions/IntegerExpression.java @@ -20,5 +20,26 @@ * Expresses an integer value. */ public interface IntegerExpression extends NumberExpression { + IntegerExpression multiply(IntegerExpression i); + default IntegerExpression multiply(final int multiply) { + return this.multiply(Expressions.of(multiply)); + } + + IntegerExpression add(IntegerExpression i); + + default IntegerExpression add(final int add) { + return this.add(Expressions.of(add)); + } + + IntegerExpression subtract(IntegerExpression i); + + default IntegerExpression subtract(final int subtract) { + return this.subtract(Expressions.of(subtract)); + } + + IntegerExpression max(IntegerExpression i); + IntegerExpression min(IntegerExpression i); + + IntegerExpression abs(); } diff --git a/driver-core/src/main/com/mongodb/client/model/expressions/MqlExpression.java b/driver-core/src/main/com/mongodb/client/model/expressions/MqlExpression.java index 70c075ab87d..302dd7e0610 100644 --- a/driver-core/src/main/com/mongodb/client/model/expressions/MqlExpression.java +++ b/driver-core/src/main/com/mongodb/client/model/expressions/MqlExpression.java @@ -29,9 +29,9 @@ final class MqlExpression implements Expression, BooleanExpression, IntegerExpression, NumberExpression, StringExpression, DateExpression, DocumentExpression, ArrayExpression { - private final Function fn; + private final Function fn; - MqlExpression(final Function fn) { + MqlExpression(final Function fn) { this.fn = fn; } @@ -41,33 +41,41 @@ final class MqlExpression * {@link MqlExpressionCodec}. */ BsonValue toBsonValue(final CodecRegistry codecRegistry) { - return fn.apply(codecRegistry); + return fn.apply(codecRegistry).bsonValue; } - private Function astDoc(final String name, final BsonDocument value) { - return (cr) -> new BsonDocument(name, value); + private AstPlaceholder astDoc(final String name, final BsonDocument value) { + return new AstPlaceholder(new BsonDocument(name, value)); } - private Function ast(final String name) { - return (cr) -> new BsonDocument(name, this.toBsonValue(cr)); + static final class AstPlaceholder { + private final BsonValue bsonValue; + + AstPlaceholder(final BsonValue bsonValue) { + this.bsonValue = bsonValue; + } + } + + private Function ast(final String name) { + return (cr) -> new AstPlaceholder(new BsonDocument(name, this.toBsonValue(cr))); } - private Function ast(final String name, final Expression param1) { + private Function ast(final String name, final Expression param1) { return (cr) -> { BsonArray value = new BsonArray(); value.add(this.toBsonValue(cr)); value.add(extractBsonValue(cr, param1)); - return new BsonDocument(name, value); + return new AstPlaceholder(new BsonDocument(name, value)); }; } - private Function ast(final String name, final Expression param1, final Expression param2) { + private Function ast(final String name, final Expression param1, final Expression param2) { return (cr) -> { BsonArray value = new BsonArray(); value.add(this.toBsonValue(cr)); value.add(extractBsonValue(cr, param1)); value.add(extractBsonValue(cr, param2)); - return new BsonDocument(name, value); + return new AstPlaceholder(new BsonDocument(name, value)); }; } @@ -89,12 +97,12 @@ R assertImplementsAllExpressions() { return (R) this; } - private static R newMqlExpression(final Function ast) { + private static R newMqlExpression(final Function ast) { return new MqlExpression<>(ast).assertImplementsAllExpressions(); } private R variable(final String variable) { - return newMqlExpression((cr) -> new BsonString(variable)); + return newMqlExpression((cr) -> new AstPlaceholder(new BsonString(variable))); } /** @see BooleanExpression */ @@ -159,7 +167,7 @@ public ArrayExpression map(final Function((cr) -> astDoc("$map", new BsonDocument() .append("input", this.toBsonValue(cr)) - .append("in", extractBsonValue(cr, in.apply(varThis)))).apply(cr)); + .append("in", extractBsonValue(cr, in.apply(varThis))))); } @Override @@ -167,7 +175,7 @@ public ArrayExpression filter(final Function((cr) -> astDoc("$filter", new BsonDocument() .append("input", this.toBsonValue(cr)) - .append("cond", extractBsonValue(cr, cond.apply(varThis)))).apply(cr)); + .append("cond", extractBsonValue(cr, cond.apply(varThis))))); } @Override @@ -177,7 +185,81 @@ public T reduce(final T initialValue, final BinaryOperator in) { return newMqlExpression((cr) -> astDoc("$reduce", new BsonDocument() .append("input", this.toBsonValue(cr)) .append("initialValue", extractBsonValue(cr, initialValue)) - .append("in", extractBsonValue(cr, in.apply(varThis, varValue)))).apply(cr)); + .append("in", extractBsonValue(cr, in.apply(varThis, varValue))))); + } + + + /** @see IntegerExpression + * @see NumberExpression */ + + @Override + public IntegerExpression multiply(final NumberExpression n) { + return newMqlExpression(ast("$multiply", n)); + } + + @Override + public NumberExpression add(final NumberExpression n) { + return new MqlExpression<>(ast("$add", n)); + } + + @Override + public NumberExpression divide(final NumberExpression n) { + return new MqlExpression<>(ast("$divide", n)); + } + + @Override + public NumberExpression max(final NumberExpression n) { + return new MqlExpression<>(ast("$max", n)); + } + + @Override + public NumberExpression min(final NumberExpression n) { + return new MqlExpression<>(ast("$min", n)); + } + + @Override + public IntegerExpression round() { + return new MqlExpression<>(ast("$round")); + } + + @Override + public NumberExpression round(final IntegerExpression place) { + return new MqlExpression<>(ast("$round", place)); + } + + @Override + public IntegerExpression multiply(final IntegerExpression i) { + return new MqlExpression<>(ast("$multiply", i)); + } + + @Override + public IntegerExpression abs() { + return newMqlExpression(ast("$abs")); + } + + @Override + public NumberExpression subtract(final NumberExpression n) { + return new MqlExpression<>(ast("$subtract", n)); + } + + @Override + public IntegerExpression add(final IntegerExpression i) { + return new MqlExpression<>(ast("$add", i)); + } + + @Override + public IntegerExpression subtract(final IntegerExpression i) { + return new MqlExpression<>(ast("$subtract", i)); + } + + @Override + public IntegerExpression max(final IntegerExpression i) { + return new MqlExpression<>(ast("$max", i)); + } + + @Override + public IntegerExpression min(final IntegerExpression i) { + return new MqlExpression<>(ast("$min", i)); } diff --git a/driver-core/src/main/com/mongodb/client/model/expressions/NumberExpression.java b/driver-core/src/main/com/mongodb/client/model/expressions/NumberExpression.java index 25084473853..bbef9f5768f 100644 --- a/driver-core/src/main/com/mongodb/client/model/expressions/NumberExpression.java +++ b/driver-core/src/main/com/mongodb/client/model/expressions/NumberExpression.java @@ -21,4 +21,37 @@ */ public interface NumberExpression extends Expression { + NumberExpression multiply(NumberExpression n); + + default NumberExpression multiply(final Number multiply) { + return this.multiply(Expressions.numberToExpression(multiply)); + } + + NumberExpression divide(NumberExpression n); + + default NumberExpression divide(final Number divide) { + return this.divide(Expressions.numberToExpression(divide)); + } + + NumberExpression add(NumberExpression n); + + default NumberExpression add(final Number add) { + return this.add(Expressions.numberToExpression(add)); + } + + NumberExpression subtract(NumberExpression n); + + default NumberExpression subtract(final Number subtract) { + return this.subtract(Expressions.numberToExpression(subtract)); + } + + NumberExpression max(NumberExpression n); + + NumberExpression min(NumberExpression n); + + IntegerExpression round(); + + NumberExpression round(IntegerExpression place); + + NumberExpression abs(); } diff --git a/driver-core/src/test/functional/com/mongodb/client/model/expressions/ArithmeticExpressionsFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/client/model/expressions/ArithmeticExpressionsFunctionalTest.java new file mode 100644 index 00000000000..4e156c86516 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/model/expressions/ArithmeticExpressionsFunctionalTest.java @@ -0,0 +1,243 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.model.expressions; + +import org.bson.types.Decimal128; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.RoundingMode; + +import static com.mongodb.client.model.expressions.Expressions.numberToExpression; +import static com.mongodb.client.model.expressions.Expressions.of; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@SuppressWarnings("ConstantConditions") +class ArithmeticExpressionsFunctionalTest extends AbstractExpressionsFunctionalTest { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/#arithmetic-expression-operators + + @Test + public void literalsTest() { + assertExpression(1, of(1), "1"); + assertExpression(1L, of(1L)); + assertExpression(1.0, of(1.0)); + assertExpression(Decimal128.parse("1.0"), of(Decimal128.parse("1.0"))); + assertThrows(IllegalArgumentException.class, () -> of((Decimal128) null)); + + // expression equality differs from bson equality + assertExpression(true, of(1L).eq(of(1.0))); + assertExpression(true, of(1L).eq(of(1))); + + // bson equality; underlying type is preserved + // this behaviour is not defined by the API, but tested for clarity + assertEquals(toBsonValue(1), evaluate(of(1))); + assertEquals(toBsonValue(1L), evaluate(of(1L))); + assertEquals(toBsonValue(1.0), evaluate(of(1.0))); + assertNotEquals(toBsonValue(1), evaluate(of(1L))); + assertNotEquals(toBsonValue(1.0), evaluate(of(1L))); + + // Number conversions; used internally + assertExpression(1, numberToExpression(1)); + assertExpression(1L, numberToExpression(1L)); + assertExpression(1.0, numberToExpression(1.0)); + assertExpression(Decimal128.parse("1.0"), numberToExpression(Decimal128.parse("1.0"))); + assertThrows(IllegalArgumentException.class, + () -> assertExpression("n/a", numberToExpression(BigDecimal.valueOf(1)))); + } + + @Test + public void multiplyTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/multiply/ + assertExpression( + 2.0 * 2, + of(2.0).multiply(of(2)), + "{'$multiply': [2.0, 2]}"); + + // mixing integers and numbers + IntegerExpression oneInt = of(1); + NumberExpression oneNum = of(1.0); + IntegerExpression resultInt = oneInt.multiply(oneInt); + NumberExpression resultNum = oneNum.multiply(oneNum); + // compile time error if these were IntegerExpressions: + NumberExpression r2 = oneNum.multiply(oneInt); + NumberExpression r3 = oneInt.multiply(oneNum); + assertExpression(1, resultInt); + // 1 is also a valid expected value in our API + assertExpression(1.0, resultNum); + assertExpression(1.0, r2); + assertExpression(1.0, r3); + + // convenience + assertExpression(2.0, of(1.0).multiply(2.0)); + assertExpression(2L, of(1).multiply(2L)); + assertExpression(2, of(1).multiply(2)); + } + + @Test + public void divideTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/divide/ + assertExpression( + 1.0 / 2.0, + of(1.0).divide(of(2.0)), + "{'$divide': [1.0, 2.0]}"); + // unlike Java's 1/2==0, dividing any type of numbers always yields an + // equal result, in this case represented using a double. + assertExpression( + 0.5, + of(1).divide(of(2)), + "{'$divide': [1, 2]}"); + + // however, there are differences in evaluation between numbers + // represented using Decimal128 and double: + assertExpression( + 2.5242187499999997, + of(3.231).divide(of(1.28))); + assertExpression( + Decimal128.parse("2.52421875"), + of(Decimal128.parse("3.231")).divide(of(Decimal128.parse("1.28")))); + assertExpression( + Decimal128.parse("2.52421875"), + of(Decimal128.parse("3.231")).divide(of(1.28))); + assertExpression( + Decimal128.parse("2.524218750000"), + of(3.231).divide(of(Decimal128.parse("1.28")))); + + // convenience + assertExpression(0.5, of(1.0).divide(2.0)); + assertExpression(0.5, of(1).divide(2.0)); + assertExpression(0.5, of(1).divide(2L)); + assertExpression(0.5, of(1).divide(2)); + + // divide always returns a Number, so the method is not on IntegerExpression + } + + @Test + public void addTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/add/ + IntegerExpression actual = of(2).add(of(2)); + assertExpression( + 2 + 2, actual, + "{'$add': [2, 2]}"); + assertExpression( + 2.0 + 2, + of(2.0).add(of(2)), + "{'$add': [2.0, 2]}"); + + // convenience + assertExpression(3.0, of(1.0).add(2.0)); + assertExpression(3L, of(1).add(2L)); + assertExpression(3, of(1).add(2)); + + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/sum/ + // sum's alternative behaviour exists for purposes of reduction, but is + // inconsistent with multiply, and potentially confusing. Unimplemented. + } + + @Test + public void subtractTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/subtract/ + IntegerExpression actual = of(2).subtract(of(2)); + assertExpression( + 0, + actual, + "{'$subtract': [2, 2]} "); + assertExpression( + 2.0 - 2, + of(2.0).subtract(of(2)), + "{'$subtract': [2.0, 2]} "); + + // convenience + assertExpression(-1.0, of(1.0).subtract(2.0)); + assertExpression(-1, of(1).subtract(2)); + } + + @Test + public void maxTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/max/ + IntegerExpression actual = of(-2).max(of(2)); + assertExpression( + Math.max(-2, 2), + actual, + "{'$max': [-2, 2]}"); + assertExpression( + Math.max(-2.0, 2.0), + of(-2.0).max(of(2.0)), + "{'$max': [-2.0, 2.0]}"); + } + + @Test + public void minTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/min/ (63) + IntegerExpression actual = of(-2).min(of(2)); + assertExpression( + Math.min(-2, 2), + actual, + "{'$min': [-2, 2]}"); + assertExpression( + Math.min(-2.0, 2.0), + of(-2.0).min(of(2.0)), + "{'$min': [-2.0, 2.0]}"); + } + + @Test + public void roundTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/round/ + IntegerExpression actual = of(5.5).round(); + assertExpression( + 6.0, + actual, + "{'$round': 5.5} "); + NumberExpression actualNum = of(5.5).round(of(0)); + assertExpression( + new BigDecimal("5.5").setScale(0, RoundingMode.HALF_EVEN).doubleValue(), + actualNum, + "{'$round': [5.5, 0]} "); + // unlike Java, uses banker's rounding (half_even) + assertExpression( + 2.0, + of(2.5).round(), + "{'$round': 2.5} "); + assertExpression( + new BigDecimal("-5.5").setScale(0, RoundingMode.HALF_EVEN).doubleValue(), + of(-5.5).round()); + // to place + assertExpression( + 555.55, + of(555.555).round(of(2)), + "{'$round': [555.555, 2]} "); + assertExpression( + 600.0, + of(555.555).round(of(-2)), + "{'$round': [555.555, -2]} "); + } + + @Test + public void absTest() { + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/round/ (?) + assertExpression( + Math.abs(-2.0), + of(-2.0).abs(), + "{'$abs': -2.0}"); + // integer + IntegerExpression abs = of(-2).abs(); + assertExpression( + Math.abs(-2), abs, + "{'$abs': -2}"); + } +} diff --git a/driver-core/src/test/functional/com/mongodb/client/model/expressions/ArrayExpressionsFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/client/model/expressions/ArrayExpressionsFunctionalTest.java index acee30a0628..b056a9fb8d0 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/expressions/ArrayExpressionsFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/expressions/ArrayExpressionsFunctionalTest.java @@ -25,7 +25,7 @@ import static com.mongodb.client.model.expressions.Expressions.of; import static com.mongodb.client.model.expressions.Expressions.ofBooleanArray; -@SuppressWarnings({"PointlessBooleanExpression", "ConstantConditions", "Convert2MethodRef"}) +@SuppressWarnings({"ConstantConditions", "Convert2MethodRef"}) class ArrayExpressionsFunctionalTest extends AbstractExpressionsFunctionalTest { // https://www.mongodb.com/docs/manual/reference/operator/aggregation/#array-expression-operators // (Incomplete) diff --git a/driver-core/src/test/functional/com/mongodb/client/model/expressions/ComparisonExpressionsFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/client/model/expressions/ComparisonExpressionsFunctionalTest.java index b2d32c352b0..984119adaa8 100644 --- a/driver-core/src/test/functional/com/mongodb/client/model/expressions/ComparisonExpressionsFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/client/model/expressions/ComparisonExpressionsFunctionalTest.java @@ -42,7 +42,8 @@ class ComparisonExpressionsFunctionalTest extends AbstractExpressionsFunctionalT static R ofRem() { // $$REMOVE is intentionally not exposed to users - return new MqlExpression<>((cr) -> new BsonString("$$REMOVE")).assertImplementsAllExpressions(); + return new MqlExpression<>((cr) -> new MqlExpression.AstPlaceholder(new BsonString("$$REMOVE"))) + .assertImplementsAllExpressions(); } // https://www.mongodb.com/docs/manual/reference/bson-type-comparison-order/#std-label-bson-types-comparison-order