Skip to content

Add Catalog handling in database name #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dbsqlcli/dbsqlclirc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ syntax_style = default
key_bindings = emacs

# DBSQL prompt
# \c - Catalog name
# \d - Database name
# \h - Hostname
# \D - The full current date
Expand All @@ -51,7 +52,7 @@ key_bindings = emacs
# \P - AM/PM
# \R - The current time, in 24-hour military time (0–23)
# \s - Seconds of the current time
prompt = '\h:\d> '
prompt = '\h:\c.\d> '
prompt_continuation = '-> '

# enable pager on startup
Expand Down Expand Up @@ -96,4 +97,4 @@ output.even-row = ""
# [credentials]
# host_name = ""
# http_path = ""
# access_token = ""
# access_token = ""
5 changes: 4 additions & 1 deletion dbsqlcli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ def change_db(self, arg, **_):
None,
None,
None,
'You are now connected to database "%s"' % self.sqlexecute.database,
'You are now connected to database "%s.%s"'
% (self.sqlexecute.catalog, self.sqlexecute.database),
)

def change_prompt_format(self, arg, **_):
Expand Down Expand Up @@ -599,6 +600,7 @@ def get_prompt(self, string):
string = string.replace(
"\\h", sqlexecute.hostname.replace(".cloud.databricks.com", "")
)
string = string.replace("\\c", sqlexecute.catalog or "(none)")
string = string.replace("\\d", sqlexecute.database or "(none)")
string = string.replace("\\n", "\n")
string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y"))
Expand Down Expand Up @@ -702,6 +704,7 @@ def cli(
Examples:
- dbsqlcli
- dbsqlcli my_database
- dbsqlcli my_catalog.my_database
"""
if (clirc == DBSQLCLIRC) and (not os.path.exists(os.path.expanduser(clirc))):
err_msg = (
Expand Down
13 changes: 7 additions & 6 deletions dbsqlcli/packages/special/dbcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
"\\l", "\\l", "List databases.", arg_type=RAW_QUERY, case_sensitive=True
)
def list_databases(cur, **_):
_databases = cur.schemas().fetchall()
if _databases:
headers = [x[0] for x in _databases]
return [(None, _databases, headers, "")]
else:
return [(None, None, None, "")]
databases = cur.schemas().fetchall()
if databases:
headers = [
field.title().removeprefix("Table_") for field in databases[0].__fields__
]
return [(None, databases, headers, "")]
return [(None, None, None, "")]
34 changes: 21 additions & 13 deletions dbsqlcli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,32 @@ def read(self, hostname: str) -> Optional[OAuthToken]:
class SQLExecute(object):
DATABASES_QUERY = "SHOW DATABASES"

def __init__(self, hostname, http_path, access_token, database, auth_type=None):
def __init__(
self, hostname, http_path, access_token, database="default", auth_type=None
):
self.hostname = hostname
self.http_path = http_path
self.access_token = access_token
self.database = database or "default"
self.auth_type = auth_type

self.connect(database=self.database)
self._set_catalog_database(database)
self.connect()

def _set_catalog_database(self, database):
"""Sets the catalog and database name if a single dot is supplied"""
if database.count(".") == 1:
component = database.split(".")
self.catalog = component[0]
self.database = component[1]
else:
self.catalog = "hive_metastore"
self.database = database

def connect(self, database=None):
self.close_connection()

if database:
self._set_catalog_database(database)

oauth_params = {}
if self.auth_type == AuthType.DATABRICKS_OAUTH.value:
oauth_params = {
Expand All @@ -63,20 +77,14 @@ def connect(self, database=None):
server_hostname=self.hostname,
http_path=self.http_path,
access_token=self.access_token,
schema=database,
catalog=self.catalog,
schema=self.database,
_user_agent_entry=USER_AGENT_STRING,
**oauth_params,
)

self.database = database or self.database

self.conn = conn

def reconnect(self):

self.close_connection()
self.connect(database=self.database)

def close_connection(self):
"""Close any open connection and remove the `conn` attribute"""

Expand Down Expand Up @@ -138,7 +146,7 @@ def run(self, statement):
f"SQL Gateway was timed out. Attempting to reconnect. Attempt {attempts+1}. Error: {e}"
)
attempts += 1
self.reconnect()
self.connect()

def get_result(self, cursor):
"""Get the current result's data from the cursor."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "databricks-sql-cli"
version = "0.1.x"
version = "0.1.5"
description = "A DBCLI client for Databricks SQL"
authors = ["Databricks SQL CLI Maintainers <[email protected]>"]
packages = [{include = "dbsqlcli"}]
Expand Down