diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py index def13cfd..cc9116e3 100644 --- a/tests/unit/test_compiler.py +++ b/tests/unit/test_compiler.py @@ -28,6 +28,23 @@ from sqlalchemy.sql.functions import rollup, cube, grouping_sets +@pytest.fixture +def table(faux_conn, metadata): + # Fixture to create a sample table for testing + + table = setup_table( + faux_conn, + "table1", + metadata, + sqlalchemy.Column("foo", sqlalchemy.Integer), + sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), + ) + + yield table + + table.drop(faux_conn) + + def test_constraints_are_ignored(faux_conn, metadata): sqlalchemy.Table( "ref", @@ -282,85 +299,92 @@ def test_no_implicit_join_for_inner_unnest_no_table2_column(faux_conn, metadata) assert found_outer_sql == expected_outer_sql -def test_grouping_sets(faux_conn, metadata): - table = setup_table( - faux_conn, - "table1", - metadata, - sqlalchemy.Column("foo", sqlalchemy.Integer), - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), +grouping_ops = ( + "grouping_op, grouping_op_func", + [("GROUPING SETS", grouping_sets), ("ROLLUP", rollup), ("CUBE", cube)], +) + + +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_ops_vs_single_column(faux_conn, table, grouping_op, grouping_op_func): + # Tests each of the grouping ops against a single column + + q = sqlalchemy.select(table.c.foo).group_by(grouping_op_func(table.c.foo)) + found_sql = q.compile(faux_conn).string + + expected_sql = ( + f"SELECT `table1`.`foo` \n" + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`)" ) + assert found_sql == expected_sql + + +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_ops_vs_multi_columns(faux_conn, table, grouping_op, grouping_op_func): + # Tests each of the grouping ops against multiple columns + q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( - grouping_sets(table.c.foo, table.c.bar) + grouping_op_func(table.c.foo, table.c.bar) ) + found_sql = q.compile(faux_conn).string expected_sql = ( - "SELECT `table1`.`foo`, `table1`.`bar` \n" - "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`)" + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`)" ) - found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql -def test_rollup(faux_conn, metadata): - table = setup_table( - faux_conn, - "table1", - metadata, - sqlalchemy.Column("foo", sqlalchemy.Integer), - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), - ) +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_op_with_grouping_op(faux_conn, table, grouping_op, grouping_op_func): + # Tests multiple grouping ops in a single statement q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( - rollup(table.c.foo, table.c.bar) + grouping_op_func(table.c.foo, table.c.bar), grouping_op_func(table.c.foo) ) + found_sql = q.compile(faux_conn).string expected_sql = ( - "SELECT `table1`.`foo`, `table1`.`bar` \n" - "FROM `table1` GROUP BY ROLLUP(`table1`.`foo`, `table1`.`bar`)" + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY {grouping_op}(`table1`.`foo`, `table1`.`bar`), {grouping_op}(`table1`.`foo`)" ) - found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql -def test_cube(faux_conn, metadata): - table = setup_table( - faux_conn, - "table1", - metadata, - sqlalchemy.Column("foo", sqlalchemy.Integer), - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), - ) +@pytest.mark.parametrize(*grouping_ops) +def test_grouping_ops_vs_group_by(faux_conn, table, grouping_op, grouping_op_func): + # Tests grouping op against regular group by statement q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( - cube(table.c.foo, table.c.bar) + table.c.foo, grouping_op_func(table.c.bar) ) + found_sql = q.compile(faux_conn).string expected_sql = ( - "SELECT `table1`.`foo`, `table1`.`bar` \n" - "FROM `table1` GROUP BY CUBE(`table1`.`foo`, `table1`.`bar`)" + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY `table1`.`foo`, {grouping_op}(`table1`.`bar`)" ) - found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql -def test_multiple_grouping_sets(faux_conn, metadata): - table = setup_table( - faux_conn, - "table1", - metadata, - sqlalchemy.Column("foo", sqlalchemy.Integer), - sqlalchemy.Column("bar", sqlalchemy.ARRAY(sqlalchemy.Integer)), - ) +@pytest.mark.parametrize(*grouping_ops) +def test_complex_grouping_ops_vs_nested_grouping_ops( + faux_conn, table, grouping_op, grouping_op_func +): + # Tests grouping ops nested within grouping ops q = sqlalchemy.select(table.c.foo, table.c.bar).group_by( - grouping_sets(table.c.foo, table.c.bar), grouping_sets(table.c.foo) + grouping_sets(table.c.foo, grouping_op_func(table.c.bar)) ) + found_sql = q.compile(faux_conn).string expected_sql = ( - "SELECT `table1`.`foo`, `table1`.`bar` \n" - "FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, `table1`.`bar`), GROUPING SETS(`table1`.`foo`)" + f"SELECT `table1`.`foo`, `table1`.`bar` \n" + f"FROM `table1` GROUP BY GROUPING SETS(`table1`.`foo`, {grouping_op}(`table1`.`bar`))" ) - found_sql = q.compile(faux_conn).string + assert found_sql == expected_sql