diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 31dc5a58e68e..103ddc7d04fd 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -84,8 +84,10 @@ package object dsl { def > (other: Expression) = GreaterThan(expr, other) def >= (other: Expression) = GreaterThanOrEqual(expr, other) def === (other: Expression) = EqualTo(expr, other) + def -=- (other: Expression) = EqualTo(expr, other) def <=> (other: Expression) = EqualNullSafe(expr, other) def !== (other: Expression) = Not(EqualTo(expr, other)) + def !=- (other: Expression) = Not(EqualTo(expr, other)) def in(list: Expression*) = In(expr, list) @@ -149,7 +151,47 @@ package object dsl { def lower(e: Expression) = Lower(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } - // TODO more implicit class for literal? + + implicit class LiteralOnTheLeft[T](x: T) { + def literal = Literal(x) + + def + (other: Expression) = Add(literal, other) + def - (other: Expression) = Subtract(literal, other) + def * (other: Expression) = Multiply(literal, other) + def / (other: Expression) = Divide(literal, other) + def % (other: Expression) = Remainder(literal, other) + + def && (other: Expression) = And(literal, other) + def || (other: Expression) = Or(literal, other) + def < (other: Expression) = LessThan(literal, other) + def <= (other: Expression) = LessThanOrEqual(literal, other) + def > (other: Expression) = GreaterThan(literal, other) + def >= (other: Expression) = GreaterThanOrEqual(literal, other) + /* === not allowed because it conflicts with scalatest. */ + def -=- (other: Expression) = EqualTo(literal, other) + def <=> (other: Expression) = EqualNullSafe(literal, other) + /* !== not allowed because it conflicts with scalatest. */ + def !=- (other: Expression) = Not(EqualTo(literal, other)) + + def + (other: Symbol) = Add(literal, other) + def - (other: Symbol) = Subtract(literal, other) + def * (other: Symbol) = Multiply(literal, other) + def / (other: Symbol) = Divide(literal, other) + def % (other: Symbol) = Remainder(literal, other) + + def && (other: Symbol) = And(literal, other) + def || (other: Symbol) = Or(literal, other) + def < (other: Symbol) = LessThan(literal, other) + def <= (other: Symbol) = LessThanOrEqual(literal, other) + def > (other: Symbol) = GreaterThan(literal, other) + def >= (other: Symbol) = GreaterThanOrEqual(literal, other) + /* === not allowed because it conflicts with scalatest. */ + def -=- (other: Symbol) = EqualTo(literal, other) + def <=> (other: Symbol) = EqualNullSafe(literal, other) + /* !== not allowed because it conflicts with scalatest. */ + def !=- (other: Symbol) = Not(EqualTo(literal, other)) + } + implicit class DslString(val s: String) extends ImplicitOperators { override def expr: Expression = Literal(s) def attr = analysis.UnresolvedAttribute(s) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 918996f11da2..d2ddb7049675 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -774,4 +774,24 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /* + * Testing the DSL conversions which allow literals on the left hand + * side of an expression. The DSL conversions collide with the + * scalatest === operator so we can the scalatest conversion + * explicitly: assert(X === Y) --> assert(EQ(X).===(Y)) + */ + import org.scalatest.Assertions.{convertToEqualizer => EQ} + test("expressions with a literal on the left") { + assert(EQ(-1 + 'x).===(Add(-1, 'x))) + assert(EQ(3 + 'x * 'y).===(Add(3, Multiply('x, 'y)))) + assert(EQ(0 < 'x).===(LessThan(0, 'x))) + assert(EQ(1.5 -=- 'x).===(EqualTo(1.5, 'x))) + assert(EQ(false !=- 'x).===(Not(EqualTo(false, 'x)))) + assert(EQ("a string" >= 'x).===(GreaterThanOrEqual("a string", 'x))) + //assert(EQ(RichDate("2014-11-05") > 'date).===(GreaterThan(RichDate("2014-11-05"), 'date))) + //assert(EQ(RichTimestamp("2014-11-05 12:34:56.789") < 'now) + // .===(LessThan(RichTimestamp("2014-11-05 12:34:56.789"), 'now))) + } + }