From cd4e63a24b3e6a4c5b6375bb49a6ec80fc18e305 Mon Sep 17 00:00:00 2001 From: Shounak kulkarni Date: Mon, 4 Mar 2024 10:34:56 +0530 Subject: [PATCH 1/5] allow passing database for pinot queries --- examples/pinot_async.py | 5 +++-- examples/pinot_live.py | 5 +++-- examples/pinot_quickstart_auth_zk.py | 3 ++- examples/pinot_quickstart_batch.py | 5 +++-- examples/pinot_quickstart_hybrid.py | 5 +++-- examples/pinot_quickstart_json_batch.py | 5 +++-- examples/pinot_quickstart_multi_stage.py | 3 ++- examples/pinot_quickstart_timeout.py | 9 ++++++--- pinotdb/sqlalchemy.py | 11 ++++++----- 9 files changed, 31 insertions(+), 20 deletions(-) diff --git a/examples/pinot_async.py b/examples/pinot_async.py index 2752417..ed01ad3 100644 --- a/examples/pinot_async.py +++ b/examples/pinot_async.py @@ -7,7 +7,8 @@ async def run_pinot_async_example(): async with connect_async(host='localhost', port=8000, path='/query/sql', - scheme='http', verify_ssl=False, timeout=10.0) as conn: + scheme='http', verify_ssl=False, timeout=10.0, + extra_request_headers="Database=default") as conn: curs = await conn.execute(""" SELECT count(*) FROM baseballStats @@ -20,7 +21,7 @@ async def run_pinot_async_example(): session = httpx.AsyncClient(verify=False) conn = connect_async( host='localhost', port=8000, path='/query/sql', scheme='http', - verify_ssl=False, session=session) + verify_ssl=False, session=session, extra_request_headers="Database=default") # launch 10 requests in parallel spanning a limit/offset range reqs = [] diff --git a/examples/pinot_live.py b/examples/pinot_live.py index 3b1564c..be0e77d 100644 --- a/examples/pinot_live.py +++ b/examples/pinot_live.py @@ -8,7 +8,8 @@ def run_pinot_live_example() -> None: # Query pinot.live with pinotdb connect - conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https") + conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -21,7 +22,7 @@ def run_pinot_live_example() -> None: "pinot+https://pinot-broker.pinot.live:443/query/sql?controller=https://pinot-controller.pinot.live/" ) # uses HTTP by default :( - airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True) + airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query=select([func.count("*")], from_obj=airlineStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_auth_zk.py b/examples/pinot_quickstart_auth_zk.py index e59dd89..b074d12 100644 --- a/examples/pinot_quickstart_auth_zk.py +++ b/examples/pinot_quickstart_auth_zk.py @@ -20,6 +20,7 @@ def run_pinot_quickstart_batch_example() -> None: scheme="http", username="admin", password="verysecret", + extra_request_headers="Database=default", ) curs = conn.cursor() tables = [ @@ -65,7 +66,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True) + baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query = select([func.count("*")], from_obj=baseballStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_batch.py b/examples/pinot_quickstart_batch.py index 4d18cea..af6d5b8 100644 --- a/examples/pinot_quickstart_batch.py +++ b/examples/pinot_quickstart_batch.py @@ -12,7 +12,8 @@ def run_pinot_quickstart_batch_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() tables = [ @@ -52,7 +53,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True) + baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query = select([func.count("*")], from_obj=baseballStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_hybrid.py b/examples/pinot_quickstart_hybrid.py index f7ad687..957de71 100644 --- a/examples/pinot_quickstart_hybrid.py +++ b/examples/pinot_quickstart_hybrid.py @@ -11,7 +11,8 @@ ## -d apachepinot/pinot:latest QuickStart -type hybrid def run_pinot_quickstart_hybrid_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -53,7 +54,7 @@ def run_pinot_quickstart_hybrid_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True) + airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query=select([func.count("*")], from_obj=airlineStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_json_batch.py b/examples/pinot_quickstart_json_batch.py index 74937b5..e7caad5 100644 --- a/examples/pinot_quickstart_json_batch.py +++ b/examples/pinot_quickstart_json_batch.py @@ -13,7 +13,8 @@ def run_quickstart_json_batch_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM githubEvents LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -43,7 +44,7 @@ def run_quickstart_json_batch_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True) + githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot\nResults:") query=select([func.count("*")], from_obj=githubEvents) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_multi_stage.py b/examples/pinot_quickstart_multi_stage.py index 42a1e51..2219b42 100644 --- a/examples/pinot_quickstart_multi_stage.py +++ b/examples/pinot_quickstart_multi_stage.py @@ -11,7 +11,8 @@ ## -d apachepinot/pinot:latest QuickStart -type MULTI_STAGE def run_pinot_quickstart_multi_stage_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True) + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True, + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT a.playerID, a.runs, a.yearID, b.runs, b.yearID FROM baseballStats_OFFLINE AS a JOIN baseballStats_OFFLINE AS b ON a.playerID = b.playerID WHERE a.runs > 160 AND b.runs < 2 LIMIT 10" diff --git a/examples/pinot_quickstart_timeout.py b/examples/pinot_quickstart_timeout.py index 5a687be..677d112 100644 --- a/examples/pinot_quickstart_timeout.py +++ b/examples/pinot_quickstart_timeout.py @@ -11,7 +11,8 @@ def run_pinot_quickstart_timeout_example() -> None: #Test 1 : Try without timeout. The request should succeed. - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -20,7 +21,8 @@ def run_pinot_quickstart_timeout_example() -> None: #Test 2 : Try with timeout=None. The request should succeed. - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None) + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None, + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT count(*) FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -29,7 +31,8 @@ def run_pinot_quickstart_timeout_example() -> None: #Test 3 : Try with a really small timeout. The query should raise an exception. - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001) + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001, + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT AirlineID, sum(Cancelled) FROM airlineStats WHERE Year > 2010 GROUP BY AirlineID LIMIT 5" print(f"Sending SQL to Pinot: {sql}") diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index 8ea44ea..3779e12 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -132,6 +132,7 @@ class PinotDialect(default.DefaultDialect): preparer = PinotIdentifierPareparer statement_compiler = PinotCompiler type_compiler = PinotTypeCompiler + supports_schemas = True supports_statement_cache = False supports_alter = False supports_pk_autoincrement = False @@ -204,9 +205,9 @@ def create_connect_args(self, url): kwargs = self.update_from_kwargs(kwargs) return ([], kwargs) - def get_metadata_from_controller(self, path): + def get_metadata_from_controller(self, path, database=None): url = parse.urljoin(self._controller, path) - r = requests.get(url, headers={"Accept": "application/json"}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password)) + r = requests.get(url, headers={"Accept": "application/json", "Database": database}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password)) try: result = r.json() except ValueError as e: @@ -221,13 +222,13 @@ def get_metadata_from_controller(self, path): return result def get_schema_names(self, connection, **kwargs): - return ["default"] + return self.get_metadata_from_controller("/databases") def has_table(self, connection, table_name, schema=None): return table_name in self.get_table_names(connection, schema) def get_table_names(self, connection, schema=None, **kwargs): - return self.get_metadata_from_controller("/tables")["tables"] + return self.get_metadata_from_controller("/tables", schema)["tables"] def get_view_names(self, connection, schema=None, **kwargs): return [] @@ -236,7 +237,7 @@ def get_table_options(self, connection, table_name, schema=None, **kwargs): return {} def get_columns(self, connection, table_name, schema=None, **kwargs): - payload = self.get_metadata_from_controller(f"/tables/{table_name}/schema") + payload = self.get_metadata_from_controller(f"/tables/{table_name}/schema", schema) logger.info( "Getting columns for %s from %s: %s", table_name, self._controller, payload From 9757db3b19b34ea5986053bc3af78f6191d27204 Mon Sep 17 00:00:00 2001 From: Shounak kulkarni Date: Mon, 4 Mar 2024 11:27:39 +0530 Subject: [PATCH 2/5] test fix --- tests/unit/test_sqlalchemy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_sqlalchemy.py b/tests/unit/test_sqlalchemy.py index 0005008..6314b06 100644 --- a/tests/unit/test_sqlalchemy.py +++ b/tests/unit/test_sqlalchemy.py @@ -103,10 +103,13 @@ def test_cannot_get_metadata_if_broken_json(self): with self.assertRaises(exceptions.DatabaseError): self.dialect.get_metadata_from_controller('some-path') + @responses.activate def test_gets_schema_names(self): + url = f'{self.dialect._controller}/databases' + responses.get(url, json=['default', 'foo', 'bar']) names = self.dialect.get_schema_names('some connection') - self.assertEqual(names, ['default']) + self.assertEqual(names, ['default', 'foo', 'bar']) @responses.activate def test_gets_table_names_from_controller(self): From a7d62198ca9ec6971e331d7612c7ba023cb5e6d7 Mon Sep 17 00:00:00 2001 From: Shounak kulkarni Date: Fri, 15 Mar 2024 18:23:51 +0530 Subject: [PATCH 3/5] response data handling --- pinotdb/sqlalchemy.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index 3779e12..62fc790 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -123,6 +123,11 @@ def __init__( ) +def extract_table_name(fqn): + split = fqn.split(".", 2) + return fqn if len(split) == 1 else split[1] + + class PinotDialect(default.DefaultDialect): name = "pinot" @@ -222,13 +227,17 @@ def get_metadata_from_controller(self, path, database=None): return result def get_schema_names(self, connection, **kwargs): - return self.get_metadata_from_controller("/databases") + schema_names = self.get_metadata_from_controller("/databases") + if isinstance(schema_names, (list, tuple)): + return schema_names + else: + return ['default'] def has_table(self, connection, table_name, schema=None): return table_name in self.get_table_names(connection, schema) def get_table_names(self, connection, schema=None, **kwargs): - return self.get_metadata_from_controller("/tables", schema)["tables"] + return list(map(extract_table_name, self.get_metadata_from_controller("/tables", schema)["tables"])) def get_view_names(self, connection, schema=None, **kwargs): return [] From f22162a84fa122b98c58b03222282199632e3535 Mon Sep 17 00:00:00 2001 From: Shounak kulkarni Date: Thu, 21 Mar 2024 20:32:14 +0530 Subject: [PATCH 4/5] pass database using connection string --- pinotdb/db.py | 7 +++++++ pinotdb/sqlalchemy.py | 21 ++++++++++----------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/pinotdb/db.py b/pinotdb/db.py index 9b39d56..b53166d 100644 --- a/pinotdb/db.py +++ b/pinotdb/db.py @@ -335,6 +335,7 @@ def __init__( k, v = header.split("=", 1) extra_headers[k] = v + extra_headers['database'] = kwargs['database'] self.session.headers.update(extra_headers) @check_closed @@ -382,6 +383,12 @@ def finalize_query_payload( queryOptions += ";useMultistageEngine=true" else: queryOptions = "useMultistageEngine=true" + database = self.session.headers['database'] + if database: + if queryOptions: + queryOptions += f";database={database}" + else: + queryOptions = f"database={database}" if queryOptions: return {"sql": query, "queryOptions": queryOptions} else: diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index 62fc790..41490ad 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -137,7 +137,7 @@ class PinotDialect(default.DefaultDialect): preparer = PinotIdentifierPareparer statement_compiler = PinotCompiler type_compiler = PinotTypeCompiler - supports_schemas = True + supports_schemas = False supports_statement_cache = False supports_alter = False supports_pk_autoincrement = False @@ -160,6 +160,7 @@ def __init__(self, *args, **kwargs): self._password = None self._debug = False self._verify_ssl = True + self._database = None self.update_from_kwargs(kwargs) def update_from_kwargs(self, givenkw): @@ -173,6 +174,8 @@ def update_from_kwargs(self, givenkw): kwargs["username"] = self._username = kwargs.pop("username") if "password" in kwargs: kwargs["password"] = self._password = kwargs.pop("password") + if "database" in kwargs: + kwargs["database"] = self._database = kwargs.pop("database") kwargs["debug"] = self._debug = bool(kwargs.get("debug", False)) kwargs["verify_ssl"] = self._verify_ssl = (str(kwargs.get("verify_ssl", "true")).lower() in ['true']) logger.info( @@ -210,9 +213,9 @@ def create_connect_args(self, url): kwargs = self.update_from_kwargs(kwargs) return ([], kwargs) - def get_metadata_from_controller(self, path, database=None): + def get_metadata_from_controller(self, path): url = parse.urljoin(self._controller, path) - r = requests.get(url, headers={"Accept": "application/json", "Database": database}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password)) + r = requests.get(url, headers={"Accept": "application/json", "Database": self._database}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password)) try: result = r.json() except ValueError as e: @@ -227,17 +230,13 @@ def get_metadata_from_controller(self, path, database=None): return result def get_schema_names(self, connection, **kwargs): - schema_names = self.get_metadata_from_controller("/databases") - if isinstance(schema_names, (list, tuple)): - return schema_names - else: - return ['default'] + return [self._database] def has_table(self, connection, table_name, schema=None): return table_name in self.get_table_names(connection, schema) def get_table_names(self, connection, schema=None, **kwargs): - return list(map(extract_table_name, self.get_metadata_from_controller("/tables", schema)["tables"])) + return list(map(extract_table_name, self.get_metadata_from_controller("/tables")["tables"])) def get_view_names(self, connection, schema=None, **kwargs): return [] @@ -246,7 +245,7 @@ def get_table_options(self, connection, table_name, schema=None, **kwargs): return {} def get_columns(self, connection, table_name, schema=None, **kwargs): - payload = self.get_metadata_from_controller(f"/tables/{table_name}/schema", schema) + payload = self.get_metadata_from_controller(f"/tables/{table_name}/schema") logger.info( "Getting columns for %s from %s: %s", table_name, self._controller, payload @@ -306,7 +305,7 @@ def _check_unicode_returns(self, connection, additional_tests=None): def _check_unicode_description(self, connection): return True - + # Fix for SQL Alchemy error def _json_deserializer(self, content: any): """ From 6d8a0266884361575366832e4dacd412f19bea6f Mon Sep 17 00:00:00 2001 From: Shounak kulkarni Date: Thu, 21 Mar 2024 22:42:22 +0530 Subject: [PATCH 5/5] omit passing database by query params --- pinotdb/db.py | 12 +++--------- pinotdb/sqlalchemy.py | 11 +++++++++-- tests/unit/test_sqlalchemy.py | 9 ++++----- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pinotdb/db.py b/pinotdb/db.py index b53166d..25ce761 100644 --- a/pinotdb/db.py +++ b/pinotdb/db.py @@ -162,7 +162,7 @@ def close(self): except exceptions.Error: pass # already closed # if we're managing the httpx session, attempt to close it - if not self.is_session_external: + if not self.is_session_external and self.session: self.session.close() @check_closed @@ -334,8 +334,8 @@ def __init__( for header in extra_request_headers.split(","): k, v = header.split("=", 1) extra_headers[k] = v - - extra_headers['database'] = kwargs['database'] + if 'database' in kwargs: + extra_headers['database'] = kwargs['database'] self.session.headers.update(extra_headers) @check_closed @@ -383,12 +383,6 @@ def finalize_query_payload( queryOptions += ";useMultistageEngine=true" else: queryOptions = "useMultistageEngine=true" - database = self.session.headers['database'] - if database: - if queryOptions: - queryOptions += f";database={database}" - else: - queryOptions = f"database={database}" if queryOptions: return {"sql": query, "queryOptions": queryOptions} else: diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index 41490ad..bfa8af5 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -230,13 +230,20 @@ def get_metadata_from_controller(self, path): return result def get_schema_names(self, connection, **kwargs): - return [self._database] + if self._database: + return [self._database] + else: + return ['default'] def has_table(self, connection, table_name, schema=None): return table_name in self.get_table_names(connection, schema) def get_table_names(self, connection, schema=None, **kwargs): - return list(map(extract_table_name, self.get_metadata_from_controller("/tables")["tables"])) + resp = self.get_metadata_from_controller("/tables") + if 'tables' in resp: + return list(map(extract_table_name, resp["tables"])) + else: + return [] def get_view_names(self, connection, schema=None, **kwargs): return [] diff --git a/tests/unit/test_sqlalchemy.py b/tests/unit/test_sqlalchemy.py index 6314b06..31b7c4f 100644 --- a/tests/unit/test_sqlalchemy.py +++ b/tests/unit/test_sqlalchemy.py @@ -103,13 +103,12 @@ def test_cannot_get_metadata_if_broken_json(self): with self.assertRaises(exceptions.DatabaseError): self.dialect.get_metadata_from_controller('some-path') - @responses.activate def test_gets_schema_names(self): - url = f'{self.dialect._controller}/databases' - responses.get(url, json=['default', 'foo', 'bar']) names = self.dialect.get_schema_names('some connection') - - self.assertEqual(names, ['default', 'foo', 'bar']) + self.assertEqual(names, ['default']) + self.dialect._database = 'foo' + names = self.dialect.get_schema_names('some connection') + self.assertEqual(names, ['foo']) @responses.activate def test_gets_table_names_from_controller(self):