Map every pandas datetime64 dtype to Spark TimestampType.

_pandas_to_spark_schema now detects datetime columns with
pandas.api.types.is_datetime64_any_dtype instead of two exact dtype-string
keys, and a new test covers ns and ms resolutions, naive and UTC-aware.

Note: the "datetime_ms_utc" test column was revised to stay timezone-aware,
so the post-image blob id on the tests "index" line predates that revision.

diff --git a/src/datacustomcode/io/reader/query_api.py b/src/datacustomcode/io/reader/query_api.py
index 29c5c25..f41e767 100644
--- a/src/datacustomcode/io/reader/query_api.py
+++ b/src/datacustomcode/io/reader/query_api.py
@@ -21,6 +21,7 @@
     Union,
 )
 
+import pandas.api.types as pd_types
 from pyspark.sql.types import (
     BooleanType,
     DoubleType,
@@ -48,8 +49,6 @@
     "object": StringType(),
     "int64": LongType(),
     "float64": DoubleType(),
-    "datetime64[ns]": TimestampType(),
-    "datetime64[ns, UTC]": TimestampType(),
     "bool": BooleanType(),
 }
 
@@ -59,7 +58,11 @@ def _pandas_to_spark_schema(
 ) -> StructType:
     fields = []
     for column, dtype in pandas_df.dtypes.items():
-        spark_type = PANDAS_TYPE_MAPPING.get(str(dtype), StringType())
+        spark_type: AtomicType
+        if pd_types.is_datetime64_any_dtype(dtype):
+            spark_type = TimestampType()
+        else:
+            spark_type = PANDAS_TYPE_MAPPING.get(str(dtype), StringType())
         fields.append(StructField(column, spark_type, nullable))
     return StructType(fields)
 
diff --git a/tests/io/reader/test_query_api.py b/tests/io/reader/test_query_api.py
index c9a4e86..b0bde69 100644
--- a/tests/io/reader/test_query_api.py
+++ b/tests/io/reader/test_query_api.py
@@ -14,6 +14,7 @@
     StringType,
     StructField,
     StructType,
+    TimestampType,
 )
 
 import pytest
@@ -59,6 +60,55 @@ def test_pandas_to_spark_schema_nullable(self):
         schema = _pandas_to_spark_schema(df, nullable=False)
         assert not schema.fields[0].nullable
 
+    def test_pandas_to_spark_schema_datetime_types(self):
+        """Test conversion of pandas datetime types to Spark TimestampType."""
+
+        # Create test data with different datetime types
+        data = {
+            "datetime_ns": pd.to_datetime(
+                ["2023-01-01 10:00:00", "2023-01-02 11:00:00"]
+            ),
+            "datetime_ns_utc": pd.to_datetime(
+                ["2023-01-01 10:00:00", "2023-01-02 11:00:00"], utc=True
+            ),
+            "datetime_ms": pd.to_datetime(
+                ["2023-01-01 10:00:00", "2023-01-02 11:00:00"]
+            ).astype("datetime64[ms]"),
+            # Keep the timezone while lowering the resolution, so the
+            # tz-aware non-nanosecond dtype is exercised as well.
+            "datetime_ms_utc": pd.to_datetime(
+                ["2023-01-01 10:00:00", "2023-01-02 11:00:00"], utc=True
+            ).as_unit("ms"),
+        }
+        df = pd.DataFrame(data)
+
+        # Convert to Spark schema
+        schema = _pandas_to_spark_schema(df)
+
+        # Verify the schema
+        assert isinstance(schema, StructType)
+        assert len(schema.fields) == 4
+
+        # Check that all datetime columns map to TimestampType
+        field_dict = {field.name: field for field in schema.fields}
+        for field_name in [
+            "datetime_ns",
+            "datetime_ns_utc",
+            "datetime_ms",
+            "datetime_ms_utc",
+        ]:
+            assert isinstance(field_dict[field_name].dataType, TimestampType), (
+                f"Field {field_name} should be TimestampType, "
+                f"got {type(field_dict[field_name].dataType)}"
+            )
+            assert field_dict[field_name].nullable
+
+        # Verify the actual pandas dtypes to ensure our test data has the expected types
+        assert str(df["datetime_ns"].dtype) == "datetime64[ns]"
+        assert str(df["datetime_ns_utc"].dtype) == "datetime64[ns, UTC]"
+        assert str(df["datetime_ms"].dtype) == "datetime64[ms]"
+        assert str(df["datetime_ms_utc"].dtype) == "datetime64[ms, UTC]"
+
 
 # Completely isolated test class for QueryAPIDataCloudReader
 @pytest.mark.usefixtures("patch_all_requests")