Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/pinot_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

async def run_pinot_async_example():
async with connect_async(host='localhost', port=8000, path='/query/sql',
scheme='http', verify_ssl=False, timeout=10.0) as conn:
scheme='http', verify_ssl=False, timeout=10.0,
extra_request_headers="Database=default") as conn:
curs = await conn.execute("""
SELECT count(*)
FROM baseballStats
Expand All @@ -20,7 +21,7 @@ async def run_pinot_async_example():
session = httpx.AsyncClient(verify=False)
conn = connect_async(
host='localhost', port=8000, path='/query/sql', scheme='http',
verify_ssl=False, session=session)
verify_ssl=False, session=session, extra_request_headers="Database=default")

# launch 10 requests in parallel spanning a limit/offset range
reqs = []
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

def run_pinot_live_example() -> None:
# Query pinot.live with pinotdb connect
conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https")
conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand All @@ -21,7 +22,7 @@ def run_pinot_live_example() -> None:
"pinot+https://pinot-broker.pinot.live:443/query/sql?controller=https://pinot-controller.pinot.live/"
) # uses HTTP by default :(

airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True)
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query=select([func.count("*")], from_obj=airlineStats)
print(engine.execute(query).scalar())
Expand Down
3 changes: 2 additions & 1 deletion examples/pinot_quickstart_auth_zk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def run_pinot_quickstart_batch_example() -> None:
scheme="http",
username="admin",
password="verysecret",
extra_request_headers="Database=default",
)
curs = conn.cursor()
tables = [
Expand Down Expand Up @@ -65,7 +66,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True)
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query = select([func.count("*")], from_obj=baseballStats)
print(engine.execute(query).scalar())
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_quickstart_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@


def run_pinot_quickstart_batch_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()

tables = [
Expand Down Expand Up @@ -52,7 +53,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True)
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query = select([func.count("*")], from_obj=baseballStats)
print(engine.execute(query).scalar())
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_quickstart_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
## -d apachepinot/pinot:latest QuickStart -type hybrid

def run_pinot_quickstart_hybrid_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand Down Expand Up @@ -53,7 +54,7 @@ def run_pinot_quickstart_hybrid_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True)
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query=select([func.count("*")], from_obj=airlineStats)
print(engine.execute(query).scalar())
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_quickstart_json_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@


def run_quickstart_json_batch_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM githubEvents LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand Down Expand Up @@ -43,7 +44,7 @@ def run_quickstart_json_batch_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True)
githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot\nResults:")
query=select([func.count("*")], from_obj=githubEvents)
print(engine.execute(query).scalar())
Expand Down
3 changes: 2 additions & 1 deletion examples/pinot_quickstart_multi_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
## -d apachepinot/pinot:latest QuickStart -type MULTI_STAGE

def run_pinot_quickstart_multi_stage_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True)
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True,
extra_request_headers="Database=default")
curs = conn.cursor()

sql = "SELECT a.playerID, a.runs, a.yearID, b.runs, b.yearID FROM baseballStats_OFFLINE AS a JOIN baseballStats_OFFLINE AS b ON a.playerID = b.playerID WHERE a.runs > 160 AND b.runs < 2 LIMIT 10"
Expand Down
9 changes: 6 additions & 3 deletions examples/pinot_quickstart_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def run_pinot_quickstart_timeout_example() -> None:

#Test 1 : Try without timeout. The request should succeed.

conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand All @@ -20,7 +21,8 @@ def run_pinot_quickstart_timeout_example() -> None:

#Test 2 : Try with timeout=None. The request should succeed.

conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None)
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None,
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT count(*) FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand All @@ -29,7 +31,8 @@ def run_pinot_quickstart_timeout_example() -> None:

#Test 3 : Try with a really small timeout. The query should raise an exception.

conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001)
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001,
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT AirlineID, sum(Cancelled) FROM airlineStats WHERE Year > 2010 GROUP BY AirlineID LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand Down
5 changes: 3 additions & 2 deletions pinotdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def close(self):
except exceptions.Error:
pass # already closed
# if we're managing the httpx session, attempt to close it
if not self.is_session_external:
if not self.is_session_external and self.session:
self.session.close()

@check_closed
Expand Down Expand Up @@ -334,7 +334,8 @@ def __init__(
for header in extra_request_headers.split(","):
k, v = header.split("=", 1)
extra_headers[k] = v

if 'database' in kwargs:
extra_headers['database'] = kwargs['database']
self.session.headers.update(extra_headers)

@check_closed
Expand Down
24 changes: 20 additions & 4 deletions pinotdb/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def __init__(
)


def extract_table_name(fqn):
split = fqn.split(".", 2)
return fqn if len(split) == 1 else split[1]


class PinotDialect(default.DefaultDialect):

name = "pinot"
Expand All @@ -132,6 +137,7 @@ class PinotDialect(default.DefaultDialect):
preparer = PinotIdentifierPareparer
statement_compiler = PinotCompiler
type_compiler = PinotTypeCompiler
supports_schemas = False
supports_statement_cache = False
supports_alter = False
supports_pk_autoincrement = False
Expand All @@ -154,6 +160,7 @@ def __init__(self, *args, **kwargs):
self._password = None
self._debug = False
self._verify_ssl = True
self._database = None
self.update_from_kwargs(kwargs)

def update_from_kwargs(self, givenkw):
Expand All @@ -167,6 +174,8 @@ def update_from_kwargs(self, givenkw):
kwargs["username"] = self._username = kwargs.pop("username")
if "password" in kwargs:
kwargs["password"] = self._password = kwargs.pop("password")
if "database" in kwargs:
kwargs["database"] = self._database = kwargs.pop("database")
kwargs["debug"] = self._debug = bool(kwargs.get("debug", False))
kwargs["verify_ssl"] = self._verify_ssl = (str(kwargs.get("verify_ssl", "true")).lower() in ['true'])
logger.info(
Expand Down Expand Up @@ -206,7 +215,7 @@ def create_connect_args(self, url):

def get_metadata_from_controller(self, path):
url = parse.urljoin(self._controller, path)
r = requests.get(url, headers={"Accept": "application/json"}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password))
r = requests.get(url, headers={"Accept": "application/json", "Database": self._database}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password))
try:
result = r.json()
except ValueError as e:
Expand All @@ -221,13 +230,20 @@ def get_metadata_from_controller(self, path):
return result

def get_schema_names(self, connection, **kwargs):
return ["default"]
if self._database:
return [self._database]
else:
return ['default']

def has_table(self, connection, table_name, schema=None):
return table_name in self.get_table_names(connection, schema)

def get_table_names(self, connection, schema=None, **kwargs):
return self.get_metadata_from_controller("/tables")["tables"]
resp = self.get_metadata_from_controller("/tables")
if 'tables' in resp:
return list(map(extract_table_name, resp["tables"]))
else:
return []

def get_view_names(self, connection, schema=None, **kwargs):
return []
Expand Down Expand Up @@ -296,7 +312,7 @@ def _check_unicode_returns(self, connection, additional_tests=None):

def _check_unicode_description(self, connection):
return True

# Fix for SQL Alchemy error
def _json_deserializer(self, content: any):
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ def test_cannot_get_metadata_if_broken_json(self):

def test_gets_schema_names(self):
names = self.dialect.get_schema_names('some connection')

self.assertEqual(names, ['default'])
self.dialect._database = 'foo'
names = self.dialect.get_schema_names('some connection')
self.assertEqual(names, ['foo'])

@responses.activate
def test_gets_table_names_from_controller(self):
Expand Down