@@ -455,37 +455,35 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
455455 test(" isInCollection: Scala Collection" ) {
456456 val df = Seq ((1 , " x" ), (2 , " y" ), (3 , " z" )).toDF(" a" , " b" )
457457
458+ Seq (1 , 2 ).foreach { conf =>
459+ withSQLConf(SQLConf .OPTIMIZER_INSET_CONVERSION_THRESHOLD .key -> conf.toString) {
460+ if (conf <= 1 ) {
461+ assert($" a" .isInCollection(Seq (3 , 1 )).expr.isInstanceOf [InSet ], " Expect expr to be InSet" )
462+ } else {
463+ assert($" a" .isInCollection(Seq (3 , 1 )).expr.isInstanceOf [In ], " Expect expr to be In" )
464+ }
458465
459- // Test with different types of collections
460- checkAnswer(df.filter($" a" .isInCollection(Seq (3 , 1 ))),
461- df.collect().toSeq.filter(r => r.getInt(0 ) == 3 || r.getInt(0 ) == 1 ))
462- checkAnswer(df.filter($" a" .isInCollection(Seq (1 , 2 ).toSet)),
463- df.collect().toSeq.filter(r => r.getInt(0 ) == 1 || r.getInt(0 ) == 2 ))
464- checkAnswer(df.filter($" a" .isInCollection(Seq (3 , 2 ).toArray)),
465- df.collect().toSeq.filter(r => r.getInt(0 ) == 3 || r.getInt(0 ) == 2 ))
466- checkAnswer(df.filter($" a" .isInCollection(Seq (3 , 1 ).toList)),
467- df.collect().toSeq.filter(r => r.getInt(0 ) == 3 || r.getInt(0 ) == 1 ))
468-
469- assert($" a" .isInCollection(Seq (3 , 1 )).expr.isInstanceOf [In ], " Expect expr to be In" )
470-
471- withSQLConf(SQLConf .OPTIMIZER_INSET_CONVERSION_THRESHOLD .key -> " 1" ) {
472- checkAnswer(df.filter($" a" .isInCollection(Seq (3 , 1 ))),
473- df.collect().toSeq.filter(r => r.getInt(0 ) == 3 || r.getInt(0 ) == 1 ))
474- checkAnswer(df.filter($" a" .isInCollection(Seq (1 , 2 ).toSet)),
475- df.collect().toSeq.filter(r => r.getInt(0 ) == 1 || r.getInt(0 ) == 2 ))
476-
477- assert($" a" .isInCollection(Seq (3 , 1 )).expr.isInstanceOf [InSet ], " Expect expr to be InSet" )
478- }
466+ // Test with different types of collections
467+ checkAnswer(df.filter($" a" .isInCollection(Seq (3 , 1 ))),
468+ df.collect().toSeq.filter(r => r.getInt(0 ) == 3 || r.getInt(0 ) == 1 ))
469+ checkAnswer(df.filter($" a" .isInCollection(Seq (1 , 2 ).toSet)),
470+ df.collect().toSeq.filter(r => r.getInt(0 ) == 1 || r.getInt(0 ) == 2 ))
471+ checkAnswer(df.filter($" a" .isInCollection(Seq (3 , 2 ).toArray)),
472+ df.collect().toSeq.filter(r => r.getInt(0 ) == 3 || r.getInt(0 ) == 2 ))
473+ checkAnswer(df.filter($" a" .isInCollection(Seq (3 , 1 ).toList)),
474+ df.collect().toSeq.filter(r => r.getInt(0 ) == 3 || r.getInt(0 ) == 1 ))
479475
480- val df2 = Seq ((1 , Seq (1 )), (2 , Seq (2 )), (3 , Seq (3 ))).toDF(" a" , " b" )
476+ val df2 = Seq ((1 , Seq (1 )), (2 , Seq (2 )), (3 , Seq (3 ))).toDF(" a" , " b" )
481477
482- val e = intercept[AnalysisException ] {
483- df2.filter($" a" .isInCollection(Seq ($" b" )))
484- }
485- Seq (" cannot resolve" , " due to data type mismatch: Arguments must be same type but were" )
486- .foreach { s =>
487- assert(e.getMessage.toLowerCase(Locale .ROOT ).contains(s.toLowerCase(Locale .ROOT )))
478+ val e = intercept[AnalysisException ] {
479+ df2.filter($" a" .isInCollection(Seq ($" b" )))
480+ }
481+ Seq (" cannot resolve" ,
482+ " due to data type mismatch: Arguments must be same type but were" ).foreach { s =>
483+ assert(e.getMessage.toLowerCase(Locale .ROOT ).contains(s.toLowerCase(Locale .ROOT )))
484+ }
488485 }
486+ }
489487 }
490488
491489 test(" &&" ) {
0 commit comments