diff --git a/src/datacustomcode/io/reader/query_api.py b/src/datacustomcode/io/reader/query_api.py
index 775cb29..29c5c25 100644
--- a/src/datacustomcode/io/reader/query_api.py
+++ b/src/datacustomcode/io/reader/query_api.py
@@ -43,7 +43,7 @@
 
 logger = logging.getLogger(__name__)
 
-SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {}"
+SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {} LIMIT {}"
 PANDAS_TYPE_MAPPING = {
     "object": StringType(),
     "int64": LongType(),
@@ -85,19 +85,25 @@ def __init__(self, spark: SparkSession) -> None:
         )
 
     def read_dlo(
-        self, name: str, schema: Union[AtomicType, StructType, str, None] = None
+        self,
+        name: str,
+        schema: Union[AtomicType, StructType, str, None] = None,
+        row_limit: int = 1000,
     ) -> PySparkDataFrame:
         """
-        Read a Data Lake Object (DLO) from the Data Cloud.
+        Read a Data Lake Object (DLO) from the Data Cloud, up to row_limit rows.
 
         Args:
             name (str): The name of the DLO.
             schema (Optional[Union[AtomicType, StructType, str]]): Schema of the DLO.
+            row_limit (int): Maximum number of rows to fetch.
 
         Returns:
             PySparkDataFrame: The PySpark DataFrame.
         """
-        pandas_df = self._conn.get_pandas_dataframe(SQL_QUERY_TEMPLATE.format(name))
+        pandas_df = self._conn.get_pandas_dataframe(
+            SQL_QUERY_TEMPLATE.format(name, row_limit)
+        )
         if not schema:
             # auto infer schema
             schema = _pandas_to_spark_schema(pandas_df)
@@ -105,9 +111,14 @@ def read_dlo(
         return spark_dataframe
 
     def read_dmo(
-        self, name: str, schema: Union[AtomicType, StructType, str, None] = None
+        self,
+        name: str,
+        schema: Union[AtomicType, StructType, str, None] = None,
+        row_limit: int = 1000,
     ) -> PySparkDataFrame:
-        pandas_df = self._conn.get_pandas_dataframe(SQL_QUERY_TEMPLATE.format(name))
+        pandas_df = self._conn.get_pandas_dataframe(
+            SQL_QUERY_TEMPLATE.format(name, row_limit)
+        )
         if not schema:
             # auto infer schema
             schema = _pandas_to_spark_schema(pandas_df)
diff --git a/src/datacustomcode/io/writer/print.py b/src/datacustomcode/io/writer/print.py
index 0c78ee1..7b9ffd4 100644
--- a/src/datacustomcode/io/writer/print.py
+++ b/src/datacustomcode/io/writer/print.py
@@ -14,20 +14,73 @@
 # limitations under the License.
 
-from pyspark.sql import DataFrame as PySparkDataFrame
+from typing import Optional
+
+from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession
+
+from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
 from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
 
 
 class PrintDataCloudWriter(BaseDataCloudWriter):
     CONFIG_NAME = "PrintDataCloudWriter"
 
+    def __init__(
+        self, spark: SparkSession, reader: Optional[QueryAPIDataCloudReader] = None
+    ) -> None:
+        super().__init__(spark)
+        self.reader = QueryAPIDataCloudReader(self.spark) if reader is None else reader
+
+    def validate_dataframe_columns_against_dlo(
+        self,
+        dataframe: PySparkDataFrame,
+        dlo_name: str,
+    ) -> None:
+        """
+        Validates that all columns in the given dataframe exist in the DLO schema.
+
+        Args:
+            dataframe (PySparkDataFrame): The DataFrame to validate.
+            dlo_name (str): The name of the DLO to check against.
+
+        Raises:
+            ValueError: If any columns in the dataframe are not present in the DLO
+                schema.
+        """
+        # Get DLO schema (no data, just schema)
+        dlo_df = self.reader.read_dlo(dlo_name, row_limit=0)
+        dlo_columns = set(dlo_df.columns)
+        df_columns = set(dataframe.columns)
+
+        # Find columns in dataframe not present in DLO
+        extra_columns = df_columns - dlo_columns
+        if extra_columns:
+            raise ValueError(
+                "The following columns are not present in the\n"
+                f"DLO '{dlo_name}': {sorted(extra_columns)}.\n"
+                "To fix this error, you can either:\n"
+                "  - Drop these columns from your DataFrame before writing, e.g.,\n"
+                f"    dataframe = dataframe.drop({sorted(extra_columns)})\n"
+                "  - Or, add these columns to the DLO schema in Data Cloud."
+            )
+
     def write_to_dlo(
         self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
     ) -> None:
+        # Validate columns before proceeding
+        self.validate_dataframe_columns_against_dlo(dataframe, name)
+
         dataframe.show()
 
     def write_to_dmo(
         self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
     ) -> None:
+        # Column validation against a DLO schema is not applied here because
+        # the DMO may not exist yet, so just show the dataframe.
         dataframe.show()
diff --git a/src/datacustomcode/scan.py b/src/datacustomcode/scan.py
index ac1eb99..ca64c26 100644
--- a/src/datacustomcode/scan.py
+++ b/src/datacustomcode/scan.py
@@ -16,6 +16,7 @@
 
 import ast
 import os
+import sys
 from typing import (
     Any,
     ClassVar,
@@ -40,6 +41,8 @@
     },
 }
 
+STANDARD_LIBS = set(sys.stdlib_module_names)
+
 
 class DataAccessLayerCalls(pydantic.BaseModel):
     read_dlo: frozenset[str]
@@ -137,54 +140,6 @@ def found(self) -> DataAccessLayerCalls:
 class ImportVisitor(ast.NodeVisitor):
     """AST Visitor that extracts external package imports from Python code."""
 
-    # Standard library modules that should be excluded from requirements
-    STANDARD_LIBS: ClassVar[set[str]] = {
-        "abc",
-        "argparse",
-        "ast",
-        "asyncio",
-        "base64",
-        "collections",
-        "configparser",
-        "contextlib",
-        "copy",
-        "csv",
-        "datetime",
-        "enum",
-        "functools",
-        "glob",
-        "hashlib",
-        "http",
-        "importlib",
-        "inspect",
-        "io",
-        "itertools",
-        "json",
-        "logging",
-        "math",
-        "os",
-        "pathlib",
-        "pickle",
-        "random",
-        "re",
-        "shutil",
-        "site",
-        "socket",
-        "sqlite3",
-        "string",
-        "subprocess",
-        "sys",
-        "tempfile",
-        "threading",
-        "time",
-        "traceback",
-        "typing",
-        "uuid",
-        "warnings",
-        "xml",
-        "zipfile",
-    }
-
     # Additional packages to exclude from requirements.txt
     EXCLUDED_PACKAGES: ClassVar[set[str]] = {
         "datacustomcode",  # Internal package
@@ -200,7 +155,7 @@ def visit_Import(self, node: ast.Import) -> None:
             # Get the top-level package name
             package = name.name.split(".")[0]
             if (
-                package not in self.STANDARD_LIBS
+                package not in STANDARD_LIBS
                 and package not in self.EXCLUDED_PACKAGES
                 and not package.startswith("_")
             ):
@@ -213,7 +168,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
             # Get the top-level package
             package = node.module.split(".")[0]
             if (
-                package not in self.STANDARD_LIBS
+                package not in STANDARD_LIBS
                 and package not in self.EXCLUDED_PACKAGES
                 and not package.startswith("_")
             ):
diff --git a/tests/io/reader/test_query_api.py b/tests/io/reader/test_query_api.py
index 9de8bd9..c9a4e86 100644
--- a/tests/io/reader/test_query_api.py
+++ b/tests/io/reader/test_query_api.py
@@ -143,7 +143,7 @@ def test_read_dlo(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dlo")
+            SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
         )
 
         # Verify DataFrame was created with auto-inferred schema
@@ -172,7 +172,7 @@ def test_read_dlo_with_schema(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dlo")
+            SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
         )
 
         # Verify DataFrame was created with provided schema
@@ -192,7 +192,7 @@ def test_read_dmo(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dmo")
+            SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
         )
 
         # Verify DataFrame was created
@@ -220,7 +220,7 @@ def test_read_dmo_with_schema(
 
         # Verify get_pandas_dataframe was called with the right SQL
         mock_connection.get_pandas_dataframe.assert_called_once_with(
-            SQL_QUERY_TEMPLATE.format("test_dmo")
+            SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
        )
 
         # Verify DataFrame was created with provided schema
diff --git a/tests/io/writer/test_print.py b/tests/io/writer/test_print.py
index 48c2f57..9245daa 100644
--- a/tests/io/writer/test_print.py
+++ b/tests/io/writer/test_print.py
@@ -23,18 +23,33 @@ def mock_dataframe(self):
         return df
 
     @pytest.fixture
-    def print_writer(self, mock_spark_session):
+    def mock_reader(self):
+        """Create a mock QueryAPIDataCloudReader."""
+        reader = MagicMock()
+        mock_dlo_df = MagicMock()
+        mock_dlo_df.columns = ["col1", "col2"]
+        reader.read_dlo.return_value = mock_dlo_df
+        return reader
+
+    @pytest.fixture
+    def print_writer(self, mock_spark_session, mock_reader):
         """Create a PrintDataCloudWriter instance."""
-        return PrintDataCloudWriter(mock_spark_session)
+        return PrintDataCloudWriter(mock_spark_session, mock_reader)
 
     def test_write_to_dlo(self, print_writer, mock_dataframe):
         """Test write_to_dlo method calls dataframe.show()."""
+        # Mock the validate_dataframe_columns_against_dlo method
+        print_writer.validate_dataframe_columns_against_dlo = MagicMock()
+
         # Call the method
         print_writer.write_to_dlo("test_dlo", mock_dataframe, WriteMode.OVERWRITE)
 
         # Verify show() was called
         mock_dataframe.show.assert_called_once()
 
+        # Verify validate_dataframe_columns_against_dlo was called
+        print_writer.validate_dataframe_columns_against_dlo.assert_called_once()
+
     def test_write_to_dmo(self, print_writer, mock_dataframe):
         """Test write_to_dmo method calls dataframe.show()."""
         # Call the method
@@ -59,9 +74,31 @@ def test_ignores_name_and_write_mode(self, print_writer, mock_dataframe):
         for name, write_mode in test_cases:
             # Reset mock before each call
             mock_dataframe.show.reset_mock()
-
+            # Mock the validate_dataframe_columns_against_dlo method
+            print_writer.validate_dataframe_columns_against_dlo = MagicMock()
             # Call method
             print_writer.write_to_dlo(name, mock_dataframe, write_mode)
 
             # Verify show() was called with no arguments
             mock_dataframe.show.assert_called_once_with()
+
+            print_writer.validate_dataframe_columns_against_dlo.assert_called_once()
+
+    def test_validate_dataframe_columns_against_dlo(self, print_writer, mock_dataframe):
+        """Test validate_dataframe_columns_against_dlo method."""
+        # The mock_reader fixture returns a DLO schema with only col1 and col2
+
+        # Set up mock dataframe columns
+        mock_dataframe.columns = ["col1", "col2", "col3"]
+
+        # Test that validation raises ValueError for extra columns
+        with pytest.raises(ValueError) as exc_info:
+            print_writer.validate_dataframe_columns_against_dlo(
+                mock_dataframe, "test_dlo"
+            )
+
+        assert "col3" in str(exc_info.value)
+
+        # Test successful validation with matching columns
+        mock_dataframe.columns = ["col1", "col2"]
+        print_writer.validate_dataframe_columns_against_dlo(mock_dataframe, "test_dlo")