2121from datafusion import col , lit , functions as F
2222from util import get_answer_file
2323
24+
2425def df_selection (col_name , col_type ):
2526 if col_type == pa .float64 () or isinstance (col_type , pa .Decimal128Type ):
2627 return F .round (col (col_name ), lit (2 )).alias (col_name )
@@ -29,14 +30,16 @@ def df_selection(col_name, col_type):
2930 else :
3031 return col (col_name )
3132
33+
3234def load_schema (col_name , col_type ):
3335 if col_type == pa .int64 () or col_type == pa .int32 ():
3436 return col_name , pa .string ()
3537 elif isinstance (col_type , pa .Decimal128Type ):
3638 return col_name , pa .float64 ()
3739 else :
3840 return col_name , col_type
39-
41+
42+
4043def expected_selection (col_name , col_type ):
4144 if col_type == pa .int64 () or col_type == pa .int32 ():
4245 return F .trim (col (col_name )).cast (col_type ).alias (col_name )
@@ -45,20 +48,23 @@ def expected_selection(col_name, col_type):
4548 else :
4649 return col (col_name )
4750
51+
4852def selections_and_schema (original_schema ):
49- columns = [ (c , original_schema .field (c ).type ) for c in original_schema .names ]
53+ columns = [(c , original_schema .field (c ).type ) for c in original_schema .names ]
5054
51- df_selections = [ df_selection (c , t ) for (c , t ) in columns ]
52- expected_schema = [ load_schema (c , t ) for (c , t ) in columns ]
53- expected_selections = [ expected_selection (c , t ) for (c , t ) in columns ]
55+ df_selections = [df_selection (c , t ) for (c , t ) in columns ]
56+ expected_schema = [load_schema (c , t ) for (c , t ) in columns ]
57+ expected_selections = [expected_selection (c , t ) for (c , t ) in columns ]
5458
5559 return (df_selections , expected_schema , expected_selections )
5660
61+
5762def check_q17 (df ):
5863 raw_value = float (df .collect ()[0 ]["avg_yearly" ][0 ].as_py ())
5964 value = round (raw_value , 2 )
6065 assert abs (value - 348406.05 ) < 0.001
6166
67+
6268@pytest .mark .parametrize (
6369 ("query_code" , "answer_file" ),
6470 [
@@ -72,9 +78,7 @@ def check_q17(df):
7278 ("q08_market_share" , "q8" ),
7379 ("q09_product_type_profit_measure" , "q9" ),
7480 ("q10_returned_item_reporting" , "q10" ),
75- pytest .param (
76- "q11_important_stock_identification" , "q11" ,
77- ),
81+ ("q11_important_stock_identification" , "q11" ),
7882 ("q12_ship_mode_order_priority" , "q12" ),
7983 ("q13_customer_distribution" , "q13" ),
8084 ("q14_promotion_effect" , "q14" ),
@@ -97,13 +101,20 @@ def test_tpch_query_vs_answer_file(query_code: str, answer_file: str):
97101 if answer_file == "q17" :
98102 return check_q17 (df )
99103
100- (df_selections , expected_schema , expected_selections ) = selections_and_schema (df .schema ())
104+ (df_selections , expected_schema , expected_selections ) = selections_and_schema (
105+ df .schema ()
106+ )
101107
102108 df = df .select (* df_selections )
103109
104110 read_schema = pa .schema (expected_schema )
105111
106- df_expected = module .ctx .read_csv (get_answer_file (answer_file ), schema = read_schema , delimiter = "|" , file_extension = ".out" )
112+ df_expected = module .ctx .read_csv (
113+ get_answer_file (answer_file ),
114+ schema = read_schema ,
115+ delimiter = "|" ,
116+ file_extension = ".out" ,
117+ )
107118
108119 df_expected = df_expected .select (* expected_selections )
109120
0 commit comments