diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi
index 0fb8eed2a..6bdb6f91e 100644
--- a/pandas-stubs/_typing.pyi
+++ b/pandas-stubs/_typing.pyi
@@ -77,6 +77,7 @@ class FulldatetimeDict(YearMonthDayDict, total=False):
 NpDtype: TypeAlias = str | np.dtype[np.generic] | type[str | complex | bool | object]
 Dtype: TypeAlias = ExtensionDtype | NpDtype
 DtypeArg: TypeAlias = Dtype | dict[Any, Dtype]
+DtypeBackend: TypeAlias = Literal["pyarrow", "numpy_nullable"]
 BooleanDtypeArg: TypeAlias = (
     # Builtin bool type and its string alias
     type[bool]  # noqa: Y030
diff --git a/pandas-stubs/io/sql.pyi b/pandas-stubs/io/sql.pyi
index 589dc03ac..f26f9bece 100644
--- a/pandas-stubs/io/sql.pyi
+++ b/pandas-stubs/io/sql.pyi
@@ -16,8 +16,10 @@ import sqlalchemy.engine
 import sqlalchemy.sql.expression
 from typing_extensions import TypeAlias
 
+from pandas._libs.lib import NoDefault
 from pandas._typing import (
     DtypeArg,
+    DtypeBackend,
     npt,
 )
 
@@ -84,6 +86,8 @@ def read_sql(
     columns: list[str] = ...,
     *,
     chunksize: int,
+    dtype: DtypeArg | None = ...,
+    dtype_backend: DtypeBackend | NoDefault = ...,
 ) -> Generator[DataFrame, None, None]: ...
 @overload
 def read_sql(
@@ -95,6 +99,8 @@ def read_sql(
     parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
     columns: list[str] = ...,
     chunksize: None = ...,
+    dtype: DtypeArg | None = ...,
+    dtype_backend: DtypeBackend | NoDefault = ...,
 ) -> DataFrame: ...
 
 class PandasSQL(PandasObject):
diff --git a/tests/test_io.py b/tests/test_io.py
index e259b82ec..24ff74262 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -1204,3 +1204,68 @@ def test_sqlalchemy_text() -> None:
             assert_type(read_sql(sql_select, con=conn), DataFrame),
             DataFrame,
         )
+
+
+def test_read_sql_dtype() -> None:
+    with ensure_clean() as path:
+        conn = sqlite3.connect(path)
+        df = pd.DataFrame(
+            data=[[0, "10/11/12"], [1, "12/11/10"]],
+            columns=["int_column", "date_column"],
+        )
+        check(assert_type(df.to_sql("test_data", con=conn), Union[int, None]), int)
+        check(
+            assert_type(
+                pd.read_sql(
+                    "SELECT int_column, date_column FROM test_data",
+                    con=conn,
+                    dtype=None,
+                ),
+                pd.DataFrame,
+            ),
+            pd.DataFrame,
+        )
+        check(
+            assert_type(
+                pd.read_sql(
+                    "SELECT int_column, date_column FROM test_data",
+                    con=conn,
+                    dtype={"int_column": int},
+                ),
+                pd.DataFrame,
+            ),
+            pd.DataFrame,
+        )
+        check(assert_type(DF.to_sql("test", con=conn), Union[int, None]), int)
+
+        check(
+            assert_type(
+                read_sql("select * from test", con=conn, dtype=int),
+                pd.DataFrame,
+            ),
+            pd.DataFrame,
+        )
+        conn.close()
+
+
+def test_read_sql_dtype_backend() -> None:
+    with ensure_clean() as path:
+        conn2 = sqlite3.connect(path)
+        check(assert_type(DF.to_sql("test", con=conn2), Union[int, None]), int)
+        check(
+            assert_type(
+                read_sql("select * from test", con=conn2, dtype_backend="pyarrow"),
+                pd.DataFrame,
+            ),
+            pd.DataFrame,
+        )
+        check(
+            assert_type(
+                read_sql(
+                    "select * from test", con=conn2, dtype_backend="numpy_nullable"
+                ),
+                pd.DataFrame,
+            ),
+            pd.DataFrame,
+        )
+        conn2.close()