diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index ea901c3a..1810c2a7 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -223,7 +223,8 @@ def read(self) -> Optional[OAuthToken]:
         self.port = kwargs.get("_port", 443)
         self.disable_pandas = kwargs.get("_disable_pandas", False)
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
-
+        self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)
+
         auth_provider = get_python_sql_connector_auth_provider(
             server_hostname, **kwargs
         )
@@ -1254,6 +1255,7 @@ def __init__(
         self._arrow_schema_bytes = execute_response.arrow_schema_bytes
         self._next_row_index = 0
         self._use_cloud_fetch = use_cloud_fetch
+        self.is_staging_operation = execute_response.is_staging_operation
 
         if execute_response.arrow_queue:
             # In this case the server has taken the fast path and returned an initial batch of
diff --git a/src/databricks/sql/utils/string_util.py b/src/databricks/sql/utils/string_util.py
new file mode 100644
index 00000000..f42fb7b8
--- /dev/null
+++ b/src/databricks/sql/utils/string_util.py
@@ -0,0 +1,94 @@
+from typing import Optional
+
+
+class StringUtil:
+    """Static helpers that build SQL text for volume staging operations."""
+
+    @staticmethod
+    def escape_string_literal(s: Optional[str]) -> Optional[str]:
+        """Escape single quotes in a string for safe SQL usage; None passes through."""
+        if s is None:
+            return None
+        return s.replace("'", "''")
+
+    @staticmethod
+    def get_volume_path(catalog: str, schema: str, volume: str) -> str:
+        """Construct and escape the volume path: /Volumes/<catalog>/<schema>/<volume>/."""
+        path = f"/Volumes/{catalog}/{schema}/{volume}/"
+        return StringUtil.escape_string_literal(path)
+
+    @staticmethod
+    def get_object_full_path(
+        catalog: str, schema: str, volume: str, object_path: str
+    ) -> str:
+        """Return the full escaped object path (volume path + escaped object name)."""
+        return StringUtil.get_volume_path(
+            catalog, schema, volume
+        ) + StringUtil.escape_string_literal(object_path)
+
+    @staticmethod
+    def create_get_object_query(
+        catalog: str, schema: str, volume: str, object_path: str, local_path: str
+    ) -> str:
+        """Return the SQL GET command that downloads an object to a local path."""
+        full_path = StringUtil.get_object_full_path(catalog, schema, volume, object_path)
+        local = StringUtil.escape_string_literal(local_path)
+        return f"GET '{full_path}' TO '{local}'"
+
+    @staticmethod
+    def create_get_object_query_for_input_stream(
+        catalog: str, schema: str, volume: str, object_path: str
+    ) -> str:
+        """Return a GET query that targets the input-stream placeholder."""
+        full_path = StringUtil.get_object_full_path(
+            catalog, schema, volume, object_path
+        )
+        return f"GET '{full_path}' TO '__input_stream__'"
+
+    @staticmethod
+    def create_put_object_query(
+        catalog: str,
+        schema: str,
+        volume: str,
+        object_path: str,
+        local_path: str,
+        overwrite: bool,
+    ) -> str:
+        """Return the SQL PUT command that uploads a local file into a volume."""
+        escaped_local_path = StringUtil.escape_string_literal(local_path)
+        full_remote_path = StringUtil.get_object_full_path(
+            catalog, schema, volume, object_path
+        )
+        overwrite_clause = " OVERWRITE" if overwrite else ""
+        return f"PUT '{escaped_local_path}' INTO '{full_remote_path}'{overwrite_clause}"
+
+    @staticmethod
+    def create_put_object_query_for_input_stream(
+        catalog: str, schema: str, volume: str, object_path: str, to_overwrite: bool
+    ) -> str:
+        """Return a PUT query that uploads from the input-stream placeholder.
+
+        Appends ' OVERWRITE' when to_overwrite is True.
+        """
+        full_remote_path = StringUtil.get_object_full_path(
+            catalog, schema, volume, object_path
+        )
+        overwrite_clause = " OVERWRITE" if to_overwrite else ""
+        return f"PUT '__input_stream__' INTO '{full_remote_path}'{overwrite_clause}"
+
+    @staticmethod
+    def get_object_query(
+        catalog: str, schema: str, volume: str, object_path: str, local_path: str
+    ) -> str:
+        """Convenience alias for create_get_object_query (kept for API parity)."""
+        return StringUtil.create_get_object_query(
+            catalog, schema, volume, object_path, local_path
+        )
+
+    @staticmethod
+    def create_delete_object_query(
+        catalog: str, schema: str, volume: str, object_path: str
+    ) -> str:
+        """Return the SQL REMOVE command that deletes an object from a volume."""
+        full_path = StringUtil.get_object_full_path(catalog, schema, volume, object_path)
+        return f"REMOVE '{full_path}'"
diff --git a/src/databricks/sql/volume/volume_client.py b/src/databricks/sql/volume/volume_client.py
new file mode 100644
index 00000000..83fa425d
--- /dev/null
+++ b/src/databricks/sql/volume/volume_client.py
@@ -0,0 +1,101 @@
+from typing import BinaryIO
+
+from databricks.sql.utils.string_util import StringUtil
+
+
+class VolumeClient:
+    """
+    Databricks Volume Client: runs GET/PUT/REMOVE staging queries over a connection.
+    """
+
+    def __init__(self, conn):
+        """
+        :param conn: Connection object to Databricks; must expose cursor().
+        """
+        self.conn = conn
+
+    def is_staging_operation_allowed(self, condition) -> None:
+        """Raise ValueError unless the server flagged the result set as a staging op."""
+        if not condition:
+            raise ValueError("Staging operation is not allowed")
+
+    def _execute_staging_query(self, query: str) -> None:
+        # Shared execution path: run the query and verify it is a staging operation.
+        with self.conn.cursor() as cursor:
+            cursor.execute(query)
+            self.is_staging_operation_allowed(
+                cursor.active_result_set.is_staging_operation
+            )
+
+    def get_object(
+        self, catalog: str, schema: str, volume: str, object_path: str, local_path: str
+    ) -> bool:
+        """Download an object to local_path via SQL GET; True on success."""
+        # TODO(review): validate local_path against the connection's
+        # staging_allowed_local_path allow-list (VolumeProcessor) before executing.
+        query = StringUtil.create_get_object_query(
+            catalog, schema, volume, object_path, local_path
+        )
+        self._execute_staging_query(query)
+        return True
+
+    def get_object_as_stream(
+        self, catalog: str, schema: str, volume: str, object_path: str
+    ) -> BinaryIO:
+        """Download an object via the '__input_stream__' placeholder.
+
+        TODO(review): wire up the transport so the downloaded bytes are returned
+        as a BinaryIO; the earlier draft shadowed get_object and returned True.
+        """
+        query = StringUtil.create_get_object_query_for_input_stream(
+            catalog, schema, volume, object_path
+        )
+        self._execute_staging_query(query)
+        return True
+
+    def put_object(
+        self,
+        catalog: str,
+        schema: str,
+        volume: str,
+        object_path: str,
+        local_path: str,
+        to_overwrite: bool,
+    ) -> bool:
+        """Upload a local file into a volume via SQL PUT; True on success."""
+        query = StringUtil.create_put_object_query(
+            catalog, schema, volume, object_path, local_path, to_overwrite
+        )
+        self._execute_staging_query(query)
+        return True
+
+    def put_object_from_stream(
+        self,
+        catalog: str,
+        schema: str,
+        volume: str,
+        object_path: str,
+        input_stream: BinaryIO,
+        content_length: int,
+        to_overwrite: bool,
+    ) -> bool:
+        """Upload from an in-memory stream via the '__input_stream__' placeholder.
+
+        TODO(review): input_stream/content_length must be handed to the transport
+        layer; confirm the cursor API used to attach the stream.
+        """
+        query = StringUtil.create_put_object_query_for_input_stream(
+            catalog, schema, volume, object_path, to_overwrite
+        )
+        self._execute_staging_query(query)
+        return True
+
+    def delete_object(
+        self, catalog: str, schema: str, volume: str, object_path: str
+    ) -> bool:
+        """Delete an object from a volume via SQL REMOVE; True on success."""
+        query = StringUtil.create_delete_object_query(
+            catalog, schema, volume, object_path
+        )
+        self._execute_staging_query(query)
+        return True
diff --git a/src/databricks/sql/volume/volume_processor.py b/src/databricks/sql/volume/volume_processor.py
new file mode 100644
index 00000000..95c14be8
--- /dev/null
+++ b/src/databricks/sql/volume/volume_processor.py
@@ -0,0 +1,41 @@
+import os
+from typing import List, Optional
+
+from databricks.sql import Connection
+
+
+class VolumeProcessor:
+    """Holds the local-side state for a staging operation and the allow-list of
+    local paths that GET/PUT are permitted to touch."""
+
+    def __init__(self, **kwargs):
+        # Optional open connection; None when only local validation is needed.
+        self.connection: Optional[Connection] = kwargs.get("connection")
+        self.local_path = kwargs.get("local_path")
+        self.input_stream = kwargs.get("input_stream")
+        self.content_length = kwargs.get("content_length")
+        self.abs_staging_allowed_local_paths = self.get_abs_staging_allowed_local_paths(
+            kwargs.get("staging_allowed_local_path")
+        )
+
+    @staticmethod
+    def get_abs_staging_allowed_local_paths(staging_allowed_local_path) -> List[str]:
+        """Normalize the allow-list (str or list of str) to absolute paths.
+
+        The earlier draft called this helper without defining it, which would
+        raise AttributeError on construction.
+        """
+        if staging_allowed_local_path is None:
+            return []
+        if isinstance(staging_allowed_local_path, str):
+            paths = [staging_allowed_local_path]
+        elif isinstance(staging_allowed_local_path, list):
+            paths = staging_allowed_local_path
+        else:
+            raise TypeError(
+                "staging_allowed_local_path must be a str or a list of str"
+            )
+        return [os.path.abspath(p) for p in paths]
+
+    def process_volume(self, catalog: str, schema: str, volume: str):
+        # TODO: implement the local-side processing for the staging operation.
+        pass