|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import pytest |
15 | 16 | import pandas as pd
|
16 | 17 | import pyarrow as pa
|
17 | 18 |
|
@@ -55,10 +56,26 @@ def test_sql_scalar_on_scalars_null_index(scalars_df_null_index):
|
55 | 56 | assert len(result) == len(scalars_df_null_index)
|
56 | 57 |
|
57 | 58 |
|
58 |
| -def test_sql_scalar_w_bool_series(scalars_df_index): |
59 |
| - series: bpd.Series = scalars_df_index["bool_col"] |
60 |
| - result = bbq.sql_scalar("CAST({0} AS INT64)", [series]) |
61 |
| - expected = series.astype(dtypes.INT_DTYPE) |
| 59 | +@pytest.mark.parametrize( |
| 60 | + ("column_name"), |
| 61 | + [ |
| 62 | + pytest.param("bool_col"), |
| 63 | + pytest.param("bytes_col"), |
| 64 | + pytest.param("date_col"), |
| 65 | + pytest.param("datetime_col"), |
| 66 | + pytest.param("geography_col"), |
| 67 | + pytest.param("int64_col"), |
| 68 | + pytest.param("numeric_col"), |
| 69 | + pytest.param("float64_col"), |
| 70 | + pytest.param("string_col"), |
| 71 | + pytest.param("time_col"), |
| 72 | + pytest.param("timestamp_col"), |
| 73 | + ], |
| 74 | +) |
| 75 | +def test_sql_scalar_w_all_scalar_output(scalars_df_index, column_name): |
| 76 | + series: bpd.Series = scalars_df_index[column_name] |
| 77 | + result = bbq.sql_scalar("{0}", [series]) |
| 78 | + expected = series |
62 | 79 | expected.name = None
|
63 | 80 | pd.testing.assert_series_equal(result.to_pandas(), expected.to_pandas())
|
64 | 81 |
|
|
0 commit comments