diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 386ba5da0..0ee6d92ea 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -112,6 +112,7 @@ export access_token="" There are several e2e test suites available: - `PySQLCoreTestSuite` - `PySQLLargeQueriesSuite` +- `PySQLStagingIngestionTestSuite` - `PySQLRetryTestSuite.HTTP503Suite` **[not documented]** - `PySQLRetryTestSuite.HTTP429Suite` **[not documented]** - `PySQLUnityCatalogTestSuite` **[not documented]** @@ -122,6 +123,12 @@ To execute the core test suite: poetry run python -m pytest tests/e2e/driver_tests.py::PySQLCoreTestSuite ``` +The `PySQLCoreTestSuite` namespace contains tests for all of the connector's basic features and behaviours. This is the default namespace where tests should be written unless they require specially configured clusters or take an especially long-time to execute by design. + +The `PySQLLargeQueriesSuite` namespace contains long-running query tests and is kept separate. In general, if the `PySQLCoreTestSuite` passes then these tests will as well. + +The `PySQLStagingIngestionTestSuite` namespace requires a cluster running DBR version > 12.x which supports staging ingestion commands. + The suites marked `[not documented]` require additional configuration which will be documented at a later time. ### Code formatting diff --git a/examples/README.md b/examples/README.md index 74446adeb..c4fe8ad68 100644 --- a/examples/README.md +++ b/examples/README.md @@ -35,4 +35,5 @@ To run all of these examples you can clone the entire repository to your disk. O - **`interactive_oauth.py`** shows the simplest example of authenticating by OAuth (no need for a PAT generated in the DBSQL UI) while Bring Your Own IDP is in public preview. When you run the script it will open a browser window so you can authenticate. Afterward, the script fetches some sample data from Databricks and prints it to the screen. For this script, the OAuth token is not persisted which means you need to authenticate every time you run the script. - **`persistent_oauth.py`** shows a more advanced example of authenticating by OAuth while Bring Your Own IDP is in public preview. In this case, it shows how to use a sublcass of `OAuthPersistence` to reuse an OAuth token across script executions. - **`set_user_agent.py`** shows how to customize the user agent header used for Thrift commands. In -this example the string `ExamplePartnerTag` will be added to the the user agent on every request. \ No newline at end of file +this example the string `ExamplePartnerTag` will be added to the the user agent on every request. +- **`staging_ingestion.py`** shows how the connector handles Databricks' experimental staging ingestion commands `GET`, `PUT`, and `REMOVE`. \ No newline at end of file diff --git a/examples/staging_ingestion.py b/examples/staging_ingestion.py new file mode 100644 index 000000000..2980506d0 --- /dev/null +++ b/examples/staging_ingestion.py @@ -0,0 +1,87 @@ +from databricks import sql +import os + +""" +Databricks experimentally supports data ingestion of local files via a cloud staging location. +Ingestion commands will work on DBR >12. And you must include a staging_allowed_local_path kwarg when +calling sql.connect(). 
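+For example (placeholder values only; the staging path is illustrative):
+
+    connection = sql.connect(
+        server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
+        http_path=os.getenv("DATABRICKS_HTTP_PATH"),
+        access_token=os.getenv("DATABRICKS_TOKEN"),
+        staging_allowed_local_path="/home/some.user/data",
+    )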
+ +Use databricks-sql-connector to PUT files into the staging location where Databricks can access them: + + PUT '/path/to/local/data.csv' INTO 'stage://tmp/some.user@databricks.com/salesdata/september.csv' OVERWRITE + +Files in a staging location can also be retrieved with a GET command + + GET 'stage://tmp/some.user@databricks.com/salesdata/september.csv' TO 'data.csv' + +and deleted with a REMOVE command: + + REMOVE 'stage://tmp/some.user@databricks.com/salesdata/september.csv' + +Ingestion queries are passed to cursor.execute() like any other query. For GET and PUT commands, a local file +will be read or written. For security, this local file must be contained within, or descended from, a +staging_allowed_local_path of the connection. + +Additionally, the connection can only manipulate files within the cloud storage location of the authenticated user. + +To run this script: + +1. Set the INGESTION_USER constant to the account email address of the authenticated user +2. Set the FILEPATH constant to the path of a file that will be uploaded (this example assumes its a CSV file) +3. Run this file + +Note: staging_allowed_local_path can be either a Pathlike object or a list of Pathlike objects. +""" + +INGESTION_USER = "some.user@example.com" +FILEPATH = "example.csv" + +# FILEPATH can be relative to the current directory. +# Resolve it into an absolute path +_complete_path = os.path.realpath(FILEPATH) + +if not os.path.exists(_complete_path): + + # It's easiest to save a file in the same directory as this script. But any path to a file will work. + raise Exception( + "You need to set FILEPATH in this script to a file that actually exists." + ) + +# Set staging_allowed_local_path equal to the directory that contains FILEPATH +staging_allowed_local_path = os.path.split(_complete_path)[0] + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + staging_allowed_local_path=staging_allowed_local_path, +) as connection: + + with connection.cursor() as cursor: + + # Ingestion commands are executed like any other SQL. + # Here's a sample PUT query. You can remove OVERWRITE at the end to avoid silently overwriting data. + query = f"PUT '{_complete_path}' INTO 'stage://tmp/{INGESTION_USER}/pysql_examples/demo.csv' OVERWRITE" + + print(f"Uploading {FILEPATH} to staging location") + cursor.execute(query) + print("Upload was successful") + + temp_fp = os.path.realpath("temp.csv") + + # Here's a sample GET query. Note that `temp_fp` must also be contained within, or descended from, + # the staging_allowed_local_path. + query = ( + f"GET 'stage://tmp/{INGESTION_USER}/pysql_examples/demo.csv' TO '{temp_fp}'" + ) + + print(f"Fetching from staging location into new file called temp.csv") + cursor.execute(query) + print("Download was successful") + + # Here's a sample REMOVE query. 
It cleans up the the demo.csv created in our first query + query = f"REMOVE 'stage://tmp/{INGESTION_USER}/pysql_examples/demo.csv'" + + print("Removing demo.csv from staging location") + cursor.execute(query) + print("Remove was successful") diff --git a/poetry.lock b/poetry.lock index e0d197995..9fef1e5a6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -33,7 +33,7 @@ python-versions = ">=3.5" dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"] docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] -tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] [[package]] name = "black" @@ -75,7 +75,7 @@ optional = false python-versions = ">=3.6.0" [package.extras] -unicode_backport = ["unicodedata2"] +unicode-backport = ["unicodedata2"] [[package]] name = "click" @@ -151,9 +151,9 @@ python-versions = ">=3.6.1,<4.0" [package.extras] colors = ["colorama (>=0.4.3,<0.5.0)"] -pipfile_deprecated_finder = ["pipreqs", "requirementslib"] +pipfile-deprecated-finder = ["pipreqs", "requirementslib"] plugins = ["setuptools"] -requirements_deprecated_finder = ["pip-api", "pipreqs"] +requirements-deprecated-finder = ["pip-api", "pipreqs"] [[package]] name = "lazy-object-proxy" @@ -219,6 +219,14 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "numpy" +version = "1.23.4" +description = "NumPy is the fundamental package for array computing with Python." 
+category = "main" +optional = false +python-versions = ">=3.8" + [[package]] name = "oauthlib" version = "3.2.0" @@ -407,7 +415,7 @@ urllib3 = ">=1.21.1,<1.27" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] -use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] +use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "setuptools" @@ -506,7 +514,7 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>= [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "5de07f9b2c9a2f80ca0411f0f99b6b529b00b034f2ad13199cf29c862e125a57" +content-hash = "40ffbb9e4aa38da3f1169ab074b63a9e5b45461018c78e9b6d1fa784d2d8c4d1" [metadata.files] astroid = [ @@ -705,6 +713,34 @@ numpy = [ {file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"}, {file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"}, {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"}, + {file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"}, + {file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"}, + {file = "numpy-1.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c237129f0e732885c9a6076a537e974160482eab8f10db6292e92154d4c67d71"}, + {file = "numpy-1.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8365b942f9c1a7d0f0dc974747d99dd0a0cdfc5949a33119caf05cb314682d3"}, + {file = "numpy-1.23.4-cp310-cp310-win32.whl", hash = "sha256:2341f4ab6dba0834b685cce16dad5f9b6606ea8a00e6da154f5dbded70fdc4dd"}, + {file = "numpy-1.23.4-cp310-cp310-win_amd64.whl", hash = "sha256:d331afac87c92373826af83d2b2b435f57b17a5c74e6268b79355b970626e329"}, + {file = "numpy-1.23.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:488a66cb667359534bc70028d653ba1cf307bae88eab5929cd707c761ff037db"}, + {file = "numpy-1.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce03305dd694c4873b9429274fd41fc7eb4e0e4dea07e0af97a933b079a5814f"}, + {file = "numpy-1.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8981d9b5619569899666170c7c9748920f4a5005bf79c72c07d08c8a035757b0"}, + {file = "numpy-1.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a70a7d3ce4c0e9284e92285cba91a4a3f5214d87ee0e95928f3614a256a1488"}, + {file = "numpy-1.23.4-cp311-cp311-win32.whl", hash = "sha256:5e13030f8793e9ee42f9c7d5777465a560eb78fa7e11b1c053427f2ccab90c79"}, + {file = "numpy-1.23.4-cp311-cp311-win_amd64.whl", hash = "sha256:7607b598217745cc40f751da38ffd03512d33ec06f3523fb0b5f82e09f6f676d"}, + {file = "numpy-1.23.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7ab46e4e7ec63c8a5e6dbf5c1b9e1c92ba23a7ebecc86c336cb7bf3bd2fb10e5"}, + {file = "numpy-1.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8aae2fb3180940011b4862b2dd3756616841c53db9734b27bb93813cd79fce6"}, + {file = "numpy-1.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c053d7557a8f022ec823196d242464b6955a7e7e5015b719e76003f63f82d0f"}, + {file = "numpy-1.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0882323e0ca4245eb0a3d0a74f88ce581cc33aedcfa396e415e5bba7bf05f68"}, + {file = 
"numpy-1.23.4-cp38-cp38-win32.whl", hash = "sha256:dada341ebb79619fe00a291185bba370c9803b1e1d7051610e01ed809ef3a4ba"}, + {file = "numpy-1.23.4-cp38-cp38-win_amd64.whl", hash = "sha256:0fe563fc8ed9dc4474cbf70742673fc4391d70f4363f917599a7fa99f042d5a8"}, + {file = "numpy-1.23.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c67b833dbccefe97cdd3f52798d430b9d3430396af7cdb2a0c32954c3ef73894"}, + {file = "numpy-1.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f76025acc8e2114bb664294a07ede0727aa75d63a06d2fae96bf29a81747e4a7"}, + {file = "numpy-1.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12ac457b63ec8ded85d85c1e17d85efd3c2b0967ca39560b307a35a6703a4735"}, + {file = "numpy-1.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95de7dc7dc47a312f6feddd3da2500826defdccbc41608d0031276a24181a2c0"}, + {file = "numpy-1.23.4-cp39-cp39-win32.whl", hash = "sha256:f2f390aa4da44454db40a1f0201401f9036e8d578a25f01a6e237cea238337ef"}, + {file = "numpy-1.23.4-cp39-cp39-win_amd64.whl", hash = "sha256:f260da502d7441a45695199b4e7fd8ca87db659ba1c78f2bbf31f934fe76ae0e"}, + {file = "numpy-1.23.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:61be02e3bf810b60ab74e81d6d0d36246dbfb644a462458bb53b595791251911"}, + {file = "numpy-1.23.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:296d17aed51161dbad3c67ed6d164e51fcd18dbcd5dd4f9d0a9c6055dce30810"}, + {file = "numpy-1.23.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4d52914c88b4930dafb6c48ba5115a96cbab40f45740239d9f4159c4ba779962"}, + {file = "numpy-1.23.4.tar.gz", hash = "sha256:ed2cc92af0efad20198638c69bb0fc2870a58dabfba6eb722c933b48556c686c"}, ] oauthlib = [ {file = "oauthlib-3.2.0-py3-none-any.whl", hash = "sha256:6db33440354787f9b7f3a6dbd4febf5d0f93758354060e802f6c06cb493022fe"}, diff --git a/pyproject.toml b/pyproject.toml index 9bc589599..8ee88ab03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,10 @@ pyarrow = "^9.0.0" lz4 = "^4.0.2" requests=">2.18.1" oauthlib=">=3.1.0" +numpy = [ + {version = "1.21.1", python = ">=3.7,<3.8"}, + {version = "1.23.4", python = ">=3.8"} +] [tool.poetry.dev-dependencies] pytest = "^7.1.2" diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 53b0c9715..863a67491 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -2,6 +2,9 @@ import pandas import pyarrow +import requests +import json +import os from databricks.sql import __version__ from databricks.sql import * @@ -28,7 +31,7 @@ def __init__( session_configuration: Dict[str, Any] = None, catalog: Optional[str] = None, schema: Optional[str] = None, - **kwargs + **kwargs, ) -> None: """ Connect to a Databricks SQL endpoint or a Databricks cluster. @@ -173,7 +176,7 @@ def read(self) -> Optional[OAuthToken]: http_path, (http_headers or []) + base_headers, auth_provider, - **kwargs + **kwargs, ) self._session_handle = self.thrift_backend.open_session( @@ -297,6 +300,149 @@ def _check_not_closed(self): if not self.open: raise Error("Attempting operation on closed cursor") + def _handle_staging_operation( + self, staging_allowed_local_path: Union[None, str, List[str]] + ): + """Fetch the HTTP request instruction from a staging ingestion command + and call the designated handler. + + Raise an exception if localFile is specified by the server but the localFile + is not descended from staging_allowed_local_path. 
+ """ + + if isinstance(staging_allowed_local_path, type(str())): + _staging_allowed_local_paths = [staging_allowed_local_path] + elif isinstance(staging_allowed_local_path, type(list())): + _staging_allowed_local_paths = staging_allowed_local_path + else: + raise Error( + "You must provide at least one staging_allowed_local_path when initialising a connection to perform ingestion commands" + ) + + abs_staging_allowed_local_paths = [ + os.path.abspath(i) for i in _staging_allowed_local_paths + ] + + assert self.active_result_set is not None + row = self.active_result_set.fetchone() + assert row is not None + + # Must set to None in cases where server response does not include localFile + abs_localFile = None + + # Default to not allow staging operations + allow_operation = False + if getattr(row, "localFile", None): + abs_localFile = os.path.abspath(row.localFile) + for abs_staging_allowed_local_path in abs_staging_allowed_local_paths: + # If the indicated local file matches at least one allowed base path, allow the operation + if ( + os.path.commonpath([abs_localFile, abs_staging_allowed_local_path]) + == abs_staging_allowed_local_path + ): + allow_operation = True + else: + continue + if not allow_operation: + raise Error( + "Local file operations are restricted to paths within the configured staging_allowed_local_path" + ) + + # TODO: Experiment with DBR sending real headers. + # The specification says headers will be in JSON format but the current null value is actually an empty list [] + handler_args = { + "presigned_url": row.presignedUrl, + "local_file": abs_localFile, + "headers": json.loads(row.headers or "{}"), + } + + logger.debug( + f"Attempting staging operation indicated by server: {row.operation} - {getattr(row, 'localFile', '')}" + ) + + # TODO: Create a retry loop here to re-attempt if the request times out or fails + if row.operation == "GET": + return self._handle_staging_get(**handler_args) + elif row.operation == "PUT": + return self._handle_staging_put(**handler_args) + elif row.operation == "REMOVE": + # Local file isn't needed to remove a remote resource + handler_args.pop("local_file") + return self._handle_staging_remove(**handler_args) + else: + raise Error( + f"Operation {row.operation} is not supported. " + + "Supported operations are GET, PUT, and REMOVE" + ) + + def _handle_staging_put( + self, presigned_url: str, local_file: str, headers: dict = None + ): + """Make an HTTP PUT request + + Raise an exception if request fails. Returns no data. + """ + + if local_file is None: + raise Error("Cannot perform PUT without specifying a local_file") + + with open(local_file, "rb") as fh: + r = requests.put(url=presigned_url, data=fh, headers=headers) + + # fmt: off + # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 + + OK = requests.codes.ok # 200 + CREATED = requests.codes.created # 201 + ACCEPTED = requests.codes.accepted # 202 + NO_CONTENT = requests.codes.no_content # 204 + + # fmt: on + + if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: + raise Error( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + ) + + if r.status_code == ACCEPTED: + logger.debug( + f"Response code {ACCEPTED} from server indicates ingestion command was accepted " + + "but not yet applied on the server. It's possible this command may fail later." 
+ ) + + def _handle_staging_get( + self, local_file: str, presigned_url: str, headers: dict = None + ): + """Make an HTTP GET request, create a local file with the received data + + Raise an exception if request fails. Returns no data. + """ + + if local_file is None: + raise Error("Cannot perform GET without specifying a local_file") + + r = requests.get(url=presigned_url, headers=headers) + + # response.ok verifies the status code is not between 400-600. + # Any 2xx or 3xx will evaluate r.ok == True + if not r.ok: + raise Error( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + ) + + with open(local_file, "wb") as fp: + fp.write(r.content) + + def _handle_staging_remove(self, presigned_url: str, headers: dict = None): + """Make an HTTP DELETE request to the presigned_url""" + + r = requests.delete(url=presigned_url, headers=headers) + + if not r.ok: + raise Error( + f"Staging operation over HTTP was unsuccessful: {r.status_code}-{r.text}" + ) + def execute( self, operation: str, parameters: Optional[Dict[str, str]] = None ) -> "Cursor": @@ -331,6 +477,12 @@ def execute( self.buffer_size_bytes, self.arraysize, ) + + if execute_response.is_staging_operation: + self._handle_staging_operation( + staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path + ) + return self def executemany(self, operation, seq_of_parameters): diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 48d7c2012..de505a8a0 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -6,6 +6,7 @@ import threading import lz4.frame from ssl import CERT_NONE, CERT_REQUIRED, create_default_context +from typing import List, Union import pyarrow import thrift.transport.THttpClient @@ -61,6 +62,7 @@ def __init__( http_path: str, http_headers, auth_provider: AuthProvider, + staging_allowed_local_path: Union[None, str, List[str]] = None, **kwargs, ): # Internal arguments in **kwargs: @@ -110,6 +112,7 @@ def __init__( else: raise ValueError("No valid connection settings.") + self.staging_allowed_local_path = staging_allowed_local_path self._initialize_retry_args(kwargs) self._use_arrow_native_complex_types = kwargs.get( "_use_arrow_native_complex_types", True @@ -452,7 +455,7 @@ def open_session(self, session_configuration, catalog, schema): initial_namespace = None open_session_req = ttypes.TOpenSessionReq( - client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6, + client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7, client_protocol=None, initialNamespace=initial_namespace, canUseMultipleCatalogs=True, @@ -733,6 +736,7 @@ def _results_message_to_execute_response(self, resp, operation_state): .to_pybytes() ) lz4_compressed = t_result_set_metadata_resp.lz4Compressed + is_staging_operation = t_result_set_metadata_resp.isStagingOperation if direct_results and direct_results.resultSet: assert direct_results.resultSet.results.startRowOffset == 0 assert direct_results.resultSetMetadata @@ -752,6 +756,7 @@ def _results_message_to_execute_response(self, resp, operation_state): has_been_closed_server_side=has_been_closed_server_side, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, + is_staging_operation=is_staging_operation, command_handle=resp.operationHandle, description=description, arrow_schema_bytes=schema_bytes, diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9c466886b..ae411c7dc 100644 --- a/src/databricks/sql/utils.py +++ 
b/src/databricks/sql/utils.py @@ -40,7 +40,7 @@ def remaining_rows(self) -> pyarrow.Table: ExecuteResponse = namedtuple( "ExecuteResponse", - "status has_been_closed_server_side has_more_rows description lz4_compressed " + "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " "command_handle arrow_queue arrow_schema_bytes", ) diff --git a/tests/e2e/driver_tests.py b/tests/e2e/driver_tests.py index c29ebabc9..1c09d70e9 100644 --- a/tests/e2e/driver_tests.py +++ b/tests/e2e/driver_tests.py @@ -5,6 +5,7 @@ import logging import os import sys +import tempfile import threading import time from unittest import loader, skipIf, skipUnless, TestCase @@ -14,6 +15,7 @@ import pyarrow import pytz import thrift +import pytest import databricks.sql as sql from databricks.sql import STRING, BINARY, NUMBER, DATETIME, DATE, DatabaseError, Error, OperationalError @@ -630,6 +632,278 @@ def test_initial_namespace(self): cursor.execute("select current_database()") self.assertEqual(cursor.fetchone()[0], table_name) +class PySQLStagingIngestionTestSuite(PySQLTestCase): + """Simple namespace for ingestion tests. These should be run against DBR >12.x + + In addition to connection credentials (host, path, token) this suite requires an env var + named staging_ingestion_user""" + + staging_ingestion_user = os.getenv("staging_ingestion_user") + + if staging_ingestion_user is None: + raise ValueError( + "To run these tests you must designate a `staging_ingestion_user` environment variable. This will the user associated with the personal access token." + ) + + def test_staging_ingestion_life_cycle(self): + """PUT a file into the staging location + GET the file from the staging location + REMOVE the file from the staging location + Try to GET the file again expecting to raise an exception + """ + + # PUT should succeed + + fh, temp_path = tempfile.mkstemp() + + original_text = "hello world!".encode("utf-8") + + with open(fh, "wb") as fp: + fp.write(original_text) + + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + + cursor = conn.cursor() + query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + cursor.execute(query) + + # GET should succeed + + new_fh, new_temp_path = tempfile.mkstemp() + + with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: + cursor = conn.cursor() + query = f"GET 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" + cursor.execute(query) + + with open(new_fh, "rb") as fp: + fetched_text = fp.read() + + assert fetched_text == original_text + + # REMOVE should succeed + + remove_query = ( + f"REMOVE 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv'" + ) + + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + cursor = conn.cursor() + cursor.execute(remove_query) + + # GET after REMOVE should fail + + with pytest.raises(Error): + cursor = conn.cursor() + query = f"GET 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" + cursor.execute(query) + + os.remove(temp_path) + os.remove(new_temp_path) + + + def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self): + """PUT operations are not supported unless the connection was built with + a parameter called staging_allowed_local_path + """ + + fh, temp_path = tempfile.mkstemp() + + original_text = "hello world!".encode("utf-8") + + with open(fh, "wb") as 
fp: + fp.write(original_text) + + with pytest.raises(Error): + with self.connection() as conn: + cursor = conn.cursor() + query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + cursor.execute(query) + + def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_path(self): + + + fh, temp_path = tempfile.mkstemp() + + original_text = "hello world!".encode("utf-8") + + with open(fh, "wb") as fp: + fp.write(original_text) + + base_path, filename = os.path.split(temp_path) + + # Add junk to base_path + base_path = os.path.join(base_path, "temp") + + with pytest.raises(Error): + with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: + cursor = conn.cursor() + query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + cursor.execute(query) + + def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self): + """PUT a file into the staging location twice. First command should succeed. Second should fail. + """ + + fh, temp_path = tempfile.mkstemp() + + original_text = "hello world!".encode("utf-8") + + with open(fh, "wb") as fp: + fp.write(original_text) + + def perform_put(): + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + cursor = conn.cursor() + query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/12/15/file1.csv'" + cursor.execute(query) + + def perform_remove(): + remove_query = ( + f"REMOVE 'stage://tmp/{self.staging_ingestion_user}/tmp/12/15/file1.csv'" + ) + + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + cursor = conn.cursor() + cursor.execute(remove_query) + + + # Make sure file does not exist + perform_remove() + + # Put the file + perform_put() + + # Try to put it again + with pytest.raises(sql.exc.ServerOperationError, match="FILE_IN_STAGING_PATH_ALREADY_EXISTS"): + perform_put() + + # Clean up after ourselves + perform_remove() + + def test_staging_ingestion_fails_to_modify_another_staging_user(self): + """The server should only allow modification of the staging_ingestion_user's files + """ + + some_other_user = "mary.poppins@databricks.com" + + fh, temp_path = tempfile.mkstemp() + + original_text = "hello world!".encode("utf-8") + + with open(fh, "wb") as fp: + fp.write(original_text) + + def perform_put(): + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + cursor = conn.cursor() + query = f"PUT '{temp_path}' INTO 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv' OVERWRITE" + cursor.execute(query) + + def perform_remove(): + remove_query = ( + f"REMOVE 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv'" + ) + + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + cursor = conn.cursor() + cursor.execute(remove_query) + + def perform_get(): + with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + cursor = conn.cursor() + query = f"GET 'stage://tmp/{some_other_user}/tmp/11/15/file1.csv' TO '{temp_path}'" + cursor.execute(query) + + # PUT should fail with permissions error + with pytest.raises(sql.exc.ServerOperationError, match="PERMISSION_DENIED"): + perform_put() + + # REMOVE should fail with permissions error + with pytest.raises(sql.exc.ServerOperationError, match="PERMISSION_DENIED"): + perform_remove() + + # GET should fail with permissions error + with pytest.raises(sql.exc.ServerOperationError, 
match="PERMISSION_DENIED"): + perform_get() + + def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowed_local_path(self): + """ + This test confirms that staging_allowed_local_path and target_file are resolved into absolute paths. + """ + + # If these two paths are not resolved absolutely, they appear to share a common path of /var/www/html + # after resolution their common path is only /var/www which should raise an exception + # Because the common path must always be equal to staging_allowed_local_path + staging_allowed_local_path = "/var/www/html" + target_file = "/var/www/html/../html1/not_allowed.html" + + with pytest.raises(Error): + with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + cursor = conn.cursor() + query = f"PUT '{target_file}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + cursor.execute(query) + + def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self): + staging_allowed_local_path = "/var/www/html" + target_file = "" + + with pytest.raises(Error, match="EMPTY_LOCAL_FILE_IN_STAGING_ACCESS_QUERY"): + with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + cursor = conn.cursor() + query = f"PUT '{target_file}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + cursor.execute(query) + + def test_staging_ingestion_invalid_staging_path_fails_at_server(self): + staging_allowed_local_path = "/var/www/html" + target_file = "index.html" + + with pytest.raises(Error, match="INVALID_STAGING_PATH_IN_STAGING_ACCESS_QUERY"): + with self.connection(extra_params={"staging_allowed_local_path": staging_allowed_local_path}) as conn: + cursor = conn.cursor() + query = f"PUT '{target_file}' INTO 'stageRANDOMSTRINGOFCHARACTERS://tmp/{self.staging_ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" + cursor.execute(query) + + def test_staging_ingestion_supports_multiple_staging_allowed_local_path_values(self): + """staging_allowed_local_path may be either a path-like object or a list of path-like objects. + + This test confirms that two configured base paths: + 1 - doesn't raise an exception + 2 - allows uploads from both paths + 3 - doesn't allow uploads from a third path + """ + + def generate_file_and_path_and_queries(): + """ + 1. Makes a temp file with some contents. + 2. Write a query to PUT it into a staging location + 3. 
Write a query to REMOVE it from that location (for cleanup) + """ + fh, temp_path = tempfile.mkstemp() + with open(fh, "wb") as fp: + original_text = "hello world!".encode("utf-8") + fp.write(original_text) + put_query = f"PUT '{temp_path}' INTO 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/{id(temp_path)}.csv' OVERWRITE" + remove_query = f"REMOVE 'stage://tmp/{self.staging_ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" + return fh, temp_path, put_query, remove_query + + fh1, temp_path1, put_query1, remove_query1 = generate_file_and_path_and_queries() + fh2, temp_path2, put_query2, remove_query2 = generate_file_and_path_and_queries() + fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() + + with self.connection(extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]}) as conn: + cursor = conn.cursor() + + cursor.execute(put_query1) + cursor.execute(put_query2) + + with pytest.raises(Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path"): + cursor.execute(put_query3) + + # Then clean up the files we made + cursor.execute(remove_query1) + cursor.execute(remove_query2) + def main(cli_args): global get_args_from_env diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 33fca0751..7d5686f84 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -41,7 +41,8 @@ def make_dummy_result_set_from_initial_results(initial_results): lz4_compressed=Mock(), command_handle=None, arrow_queue=arrow_queue, - arrow_schema_bytes=schema.serialize().to_pybytes())) + arrow_schema_bytes=schema.serialize().to_pybytes(), + is_staging_operation=False)) num_cols = len(initial_results[0]) if initial_results else 0 rs.description = [(f'col{col_id}', 'integer', None, None, None, None, None) for col_id in range(num_cols)] @@ -75,7 +76,8 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4 lz4_compressed=Mock(), command_handle=None, arrow_queue=None, - arrow_schema_bytes=None)) + arrow_schema_bytes=None, + is_staging_operation=False)) return rs def assertEqualRowValues(self, actual, expected): diff --git a/tests/unit/tests.py b/tests/unit/tests.py index d5ca23877..74274373f 100644 --- a/tests/unit/tests.py +++ b/tests/unit/tests.py @@ -12,9 +12,9 @@ from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError from databricks.sql.types import Row -from test_fetches import FetchTests -from test_thrift_backend import ThriftBackendTestSuite -from test_arrow_queue import ArrowQueueSuite +from tests.unit.test_fetches import FetchTests +from tests.unit.test_thrift_backend import ThriftBackendTestSuite +from tests.unit.test_arrow_queue import ArrowQueueSuite class ClientTestSuite(unittest.TestCase): @@ -534,6 +534,21 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): self.assertEqual(instance.close_session.call_count, 0) cursor.close() + @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) + @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME) + def test_staging_operation_response_is_handled(self, mock_client_class, mock_handle_staging_operation, mock_execute_response): + # If server sets ExecuteResponse.is_staging_operation True then _handle_staging_operation should be called + + mock_execute_response.is_staging_operation = True + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + cursor = connection.cursor() + cursor.execute("Text of 
some staging operation command;") + connection.close() + + mock_handle_staging_operation.assert_called_once_with() + if __name__ == '__main__': suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])