Skip to content

Volume client #555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def read(self) -> Optional[OAuthToken]:
self.port = kwargs.get("_port", 443)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)

self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
)
Expand Down Expand Up @@ -1254,6 +1255,7 @@ def __init__(
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
self._next_row_index = 0
self._use_cloud_fetch = use_cloud_fetch
self.is_staging_operation = execute_response.is_staging_operation

if execute_response.arrow_queue:
# In this case the server has taken the fast path and returned an initial batch of
Expand Down
98 changes: 98 additions & 0 deletions src/databricks/sql/utils/string_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
class StringUtil:
@staticmethod
def escape_string_literal(s: str) -> str:
"""Escapes single quotes in a string for safe SQL usage."""
if s is None:
return None
return s.replace("'", "''")

@staticmethod
def get_volume_path(catalog: str, schema: str, volume: str) -> str:
"""
Constructs and escapes the volume path in the form of /Volumes/catalog/schema/volume/
"""
path = f"/Volumes/{catalog}/{schema}/{volume}/"
return StringUtil.escape_string_literal(path)

@staticmethod
def get_object_full_path(
catalog: str, schema: str, volume: str, object_path: str
) -> str:
"""
Returns the full escaped object path by appending the escaped object name to the volume path.
"""
return StringUtil.get_volume_path(
catalog, schema, volume
) + StringUtil.escape_string_literal(object_path)

@staticmethod
def create_get_object_query(
catalog: str, schema: str, volume: str, object_path: str, local_path: str
) -> str:
"""
Returns the SQL GET command for retrieving an object to a local path.
"""
return f"GET '{StringUtil.get_object_full_path(catalog, schema, volume, object_path)}' TO '{StringUtil.escape_string_literal(local_path)}'"

@staticmethod
def create_get_object_query_for_input_stream(
catalog: str, schema: str, volume: str, object_path: str
) -> str:
"""
Constructs a GET query that writes the retrieved object to an input stream placeholder.
"""
full_path = StringUtil.get_object_full_path(
catalog, schema, volume, object_path
)
return f"GET '{full_path}' TO '__input_stream__'"

@staticmethod
def create_put_object_query(
catalog: str,
schema: str,
volume: str,
object_path: str,
local_path: str,
overwrite: bool,
) -> str:
escaped_local_path = StringUtil.escape_string_literal(local_path)
full_remote_path = StringUtil.get_object_full_path(
catalog, schema, volume, object_path
)
overwrite_clause = " OVERWRITE" if overwrite else ""
return f"PUT '{escaped_local_path}' INTO '{full_remote_path}'{overwrite_clause}"

@staticmethod
def create_put_object_query_for_input_stream(
catalog: str, schema: str, volume: str, object_path: str, to_overwrite: bool
) -> str:
"""
Constructs a PUT query that uploads from an input stream to a volume path.
Appends 'OVERWRITE' if to_overwrite is True.
"""
full_remote_path = StringUtil.get_object_full_path(
catalog, schema, volume, object_path
)
overwrite_clause = " OVERWRITE" if to_overwrite else ""
return f"PUT '__input_stream__' INTO '{full_remote_path}'{overwrite_clause}"

@staticmethod
def get_object_query(
catalog: str, schema: str, volume: str, object_path: str, local_path: str
) -> str:
"""
Public entry point to create GET object query.
Equivalent to: String getObjectQuery = createGetObjectQuery(...)
"""
return StringUtil.create_get_object_query(
catalog, schema, volume, object_path, local_path
)

@staticmethod
def create_delete_object_query(
catalog: str, schema: str, volume: str, object_path: str
) -> str:
"""
Returns the SQL REMOVE command for deleting an object from a volume.
"""
return f"REMOVE '{StringUtil.get_object_full_path(catalog, schema, volume, object_path)}'"
81 changes: 81 additions & 0 deletions src/databricks/sql/volume/volume_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import BinaryIO

from databricks.sql.utils.string_util import StringUtil


class VolumeClient:
"""
Databricks Volume Client
"""

def __init__(self, conn):
"""
Initialize the VolumeClient with a connection object.

:param conn: Connection object to Databricks.
"""
self.conn = conn

def is_staging_operation_allowed(self, condition) -> bool:
if not condition:
raise ValueError("Staging operation is not allowed")

def get_object(
self, catalog: str, schema: str, volume: str, object_path: str, local_path: str
) -> bool:
get_object_query = StringUtil.create_get_object_query(
catalog, schema, volume, object_path, local_path
)

with self.conn.cursor() as cursor:
cursor.execute(get_object_query)
self.is_staging_operation_allowed(cursor.active_result_set.is_staging_operation)
volume_processor = VolumeProcessor(local_path=local_path)
return True

def get_object(
self, catalog: str, schema: str, volume: str, object_path: str
) -> BinaryIO:
get_object_query = StringUtil.create_get_object_query_for_input_stream(
catalog, schema, volume, object_path
)
return True

def put_object(
self,
catalog: str,
schema: str,
volume: str,
object_path: str,
local_path: str,
to_overwrite: bool,
) -> bool:
put_object_query = StringUtil.create_put_object_query(
catalog, schema, volume, object_path, local_path, to_overwrite
)
return True

def put_object(
self,
catalog: str,
schema: str,
volume: str,
object_path: str,
input_stream: BinaryIO,
content_length: int,
to_overwrite: bool,
) -> bool:
put_object_query_for_input_stream = (
StringUtil.create_put_object_query_for_input_stream(
catalog, schema, volume, object_path, to_overwrite
)
)
return True

def delete_object(
self, catalog: str, schema: str, volume: str, object_path: str
) -> bool:
delete_object_query = StringUtil.create_delete_object_query(
catalog, schema, volume, object_path
)
return True
12 changes: 12 additions & 0 deletions src/databricks/sql/volume/volume_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from databricks.sql import Connection

class VolumeProcessor:
def __init__(self, **kwargs):
self.connection = kwargs.get("connection")
self.local_path = kwargs.get("local_path")
self.input_stream = kwargs.get("input_stream")
self.content_length = kwargs.get("content_length")
self.abs_staging_allowed_local_paths = self.get_abs_staging_allowed_local_paths(kwargs.get("staging_allowed_local_path"))

def process_volume(self, catalog: str, schema: str, volume: str):
pass
Loading