Skip to content

Commit 21d3a79

Browse files
author
Jesse Whitehouse
committed
Basic metadata operations are now supported. This requires a non-HMS catalog.
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent aec989a commit 21d3a79

File tree

3 files changed

+67
-8
lines changed

3 files changed

+67
-8
lines changed

src/databricks/sqlalchemy/dialect/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
1+
"""This module's layout loosely follows the example of SQLAlchemy's PostgreSQL dialect
2+
"""
3+
14
from sqlalchemy.engine import default
25
from sqlalchemy.exc import DatabaseError
36

47
from databricks import sql
58

69

10+
from databricks.sqlalchemy.dialect.base import DatabricksIdentifierPreparer
11+
712
class DatabricksDialect(default.DefaultDialect):
13+
"""This dialect implements only those methods required to pass our e2e tests
14+
"""
815

916
# Possible attributes are defined here: https://docs.sqlalchemy.org/en/14/core/internals.html#sqlalchemy.engine.Dialect
1017
name: str = "databricks"
1118
driver: str = "thrift"
1219
default_schema_name: str = "default"
1320

21+
preparer = DatabricksIdentifierPreparer
22+
1423
@classmethod
1524
def dbapi(cls):
1625
return sql
@@ -23,6 +32,8 @@ def create_connect_args(self, url):
2332
"server_hostname": url.host,
2433
"access_token": url.password,
2534
"http_path": url.query.get("http_path"),
35+
"catalog": url.query.get("catalog"),
36+
"schema": url.query.get("schema")
2637
}
2738

2839
return [], kwargs
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import re
2+
from sqlalchemy.sql import compiler
3+
4+
5+
class DatabricksIdentifierPreparer(compiler.IdentifierPreparer):
    """Quote identifiers the way Spark SQL expects: delimited with backticks.

    SparkSQL identifier specification:
    ref: https://spark.apache.org/docs/latest/sql-ref-identifier.html
    """

    # Names made only of ASCII letters, digits and underscores (matched
    # case-insensitively) are the ones this pattern accepts; the base
    # preparer consults it when deciding whether a name needs quoting.
    legal_characters = re.compile(r"^[A-Z0-9_]+$", re.I)

    def __init__(self, dialect):
        """Configure the base preparer to delimit and escape with backticks.

        ``escape_quote`` must be passed along with ``initial_quote``: the base
        class defaults it to a double quote, which would leave a backtick
        embedded in an identifier unescaped (and double any ``"`` instead).
        With both set, an embedded backtick is doubled inside the delimiters.
        """
        super().__init__(dialect, initial_quote="`", escape_quote="`")

tests/e2e/sqlalchemy/test_queries.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os, datetime
22
import pytest
33
from unittest import skipIf
4-
from sqlalchemy import create_engine, select, Column
4+
from sqlalchemy import create_engine, select, insert, Column, MetaData, Table
55
from sqlalchemy.orm import declarative_base, Session
6-
from sqlalchemy.types import SMALLINT, Integer, BigInteger, Float, DECIMAL, BOOLEAN, String
6+
from sqlalchemy.types import SMALLINT, Integer, BOOLEAN, String
77

88

99
@pytest.fixture
@@ -12,8 +12,10 @@ def db_engine():
1212
HOST = os.environ.get("host")
1313
HTTP_PATH = os.environ.get("http_path")
1414
ACCESS_TOKEN = os.environ.get("access_token")
15+
CATALOG = os.environ.get("catalog")
16+
SCHEMA = os.environ.get("schema")
1517

16-
engine = create_engine(f"databricks+thrift://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}")
18+
engine = create_engine(f"databricks+thrift://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}")
1719
return engine
1820

1921

@@ -26,23 +28,57 @@ def base(db_engine):
2628
def session(db_engine):
2729
return Session(bind=db_engine)
2830

31+
@pytest.fixture()
def metadata_obj(db_engine):
    """Provide a ``MetaData`` collection bound to the test engine."""
    bound_metadata = MetaData(bind=db_engine)
    return bound_metadata
34+
2935

3036
def test_can_connect(db_engine):
    """Smoke test: a trivial query through the engine yields exactly one row."""
    rows = db_engine.execute("SELECT 1").fetchall()
    assert len(rows) == 1
3440

35-
@skipIf(True, 'metadata operations not yet supported')
36-
def test_create_insert_drop_table(base, session: Session):
37-
"""Make sure we can automatically create and drop a table defined with SQLAlchemy's MetaData object
38-
while exercising all supported types.
41+
42+
def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
    """Create a table with SQLAlchemy Core, insert one row, read it back, drop it.

    The table is registered on ``metadata_obj`` so that ``create_all()`` and
    ``drop_all()`` emit the DDL. The table name carries an epoch-seconds
    suffix so that separate runs do not collide.
    """

    # ``strftime("%s")`` is a platform extension (unavailable on some systems)
    # and, applied to a naive ``utcnow()`` value, is interpreted as local
    # time. Computing the timestamp from an aware datetime is portable.
    epoch_seconds = int(datetime.datetime.now(datetime.timezone.utc).timestamp())

    SampleTable = Table(
        "PySQLTest_{}".format(epoch_seconds),
        metadata_obj,
        Column("name", String(255)),
        Column("episodes", Integer),
        Column("some_bool", BOOLEAN),
    )

    metadata_obj.create_all()

    try:
        insert_stmt = insert(SampleTable).values(
            name="Bim Adwunmi", episodes=6, some_bool=True
        )

        with db_engine.connect() as conn:
            conn.execute(insert_stmt)

        select_stmt = select(SampleTable)
        result = db_engine.execute(select_stmt).fetchall()

        assert len(result) == 1
    finally:
        # Always drop the table, even when a query or the assertion fails,
        # so failed runs do not leave PySQLTest_* tables behind.
        metadata_obj.drop_all()
69+
70+
71+
@skipIf(True, 'Unity catalog must be supported')
72+
def test_create_insert_drop_table_orm(base, session: Session):
73+
"""ORM classes built on the declarative base class must have a primary key.
74+
Primary key support is restricted to Unity Catalog.
3975
"""
4076

4177
class SampleObject(base):
    """Declarative model with one string primary key, an integer and a boolean column."""

    # Epoch-seconds suffix keeps the table name unique per run.
    # ``strftime("%s")`` is a non-portable platform extension and misreads a
    # naive ``utcnow()`` value as local time, so derive the timestamp from an
    # aware datetime instead.
    __tablename__ = "PySQLTest_{}".format(
        int(datetime.datetime.now(datetime.timezone.utc).timestamp())
    )

    name = Column(String(255), primary_key=True)
    # No trailing comma: ``Column(Integer),`` would bind a one-element tuple
    # to the attribute, so the column would never be mapped.
    episodes = Column(Integer)
    some_bool = Column(BOOLEAN)
4884

0 commit comments

Comments
 (0)