From 466cbe98d2b49bbe2beaa14b1f8adce92824cf63 Mon Sep 17 00:00:00 2001 From: cjench Date: Sat, 17 Apr 2021 17:11:27 -0400 Subject: [PATCH 1/5] BUG: Fix Dataframe constructor called with sql query missing column names --- pandas/compat/_optional.py | 1 + pandas/core/frame.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pandas/compat/_optional.py b/pandas/compat/_optional.py index a26da75d921ef..b5a18a1fd2674 100644 --- a/pandas/compat/_optional.py +++ b/pandas/compat/_optional.py @@ -25,6 +25,7 @@ "s3fs": "0.4.0", "scipy": "1.2.0", "sqlalchemy": "1.2.8", + "sql_metadata": None, "tables": "3.5.1", "tabulate": "0.8.7", "xarray": "0.12.3", diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 2736560def2cb..45a555cecee4d 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -437,6 +437,7 @@ 3 bar 8 """ +# DataFrame helper functions # ----------------------------------------------------------------------- # DataFrame class @@ -566,7 +567,6 @@ def __init__( dtype: Dtype | None = None, copy: bool | None = None, ): - if copy is None: if isinstance(data, dict) or data is None: # retain pre-GH#38939 default behavior @@ -669,6 +669,17 @@ def __init__( # For data is list-like, or Iterable (will consume into list) elif is_list_like(data): if not isinstance(data, (abc.Sequence, ExtensionArray)): + # For data is a sqlalchemy query, extract column names + if str(type(data)) == "": + query = str(data) + sql_mt = import_optional_dependency("sql_metadata") + # Extract column names using sql_metadata + columns = sql_mt.get_query_columns(str(data)) + # Sanitize column names + for i in range(len(columns)): + if columns[i].find('.') != -1: + columns[i] = columns[i][columns[i].find('.')+1:] + data = list(data) if len(data) > 0: if is_dataclass(data[0]): From eff6aa15b17c2228f958e70627e3f6e60f03f2cc Mon Sep 17 00:00:00 2001 From: cjench Date: Sat, 17 Apr 2021 17:24:19 -0400 Subject: [PATCH 2/5] linting issue --- pandas/core/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 45a555cecee4d..eb606bf7ba2d3 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -678,7 +678,7 @@ def __init__( # Sanitize column names for i in range(len(columns)): if columns[i].find('.') != -1: - columns[i] = columns[i][columns[i].find('.')+1:] + columns[i] = columns[i][columns[i].find('.') + 1:] data = list(data) if len(data) > 0: From 33f001376ac4393c138fc8f7bc71461e803d916f Mon Sep 17 00:00:00 2001 From: jnchngc Date: Fri, 23 Apr 2021 15:23:02 -0400 Subject: [PATCH 3/5] fixing CI issue --- pandas/core/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index eb606bf7ba2d3..31f0b58d658b7 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -674,7 +674,7 @@ def __init__( query = str(data) sql_mt = import_optional_dependency("sql_metadata") # Extract column names using sql_metadata - columns = sql_mt.get_query_columns(str(data)) + columns = list(sql_mt.get_query_columns(str(data))) # Sanitize column names for i in range(len(columns)): if columns[i].find('.') != -1: From f88a3030b85d2ee2fd872dd0f08bc0b4ad5af899 Mon Sep 17 00:00:00 2001 From: jnchngc Date: Fri, 23 Apr 2021 16:45:56 -0400 Subject: [PATCH 4/5] unused variable --- pandas/core/frame.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 31f0b58d658b7..6f4c4c26095b3 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -671,14 +671,13 @@ def __init__( if not isinstance(data, (abc.Sequence, ExtensionArray)): # For data is a sqlalchemy query, extract column names if str(type(data)) == "": - query = str(data) sql_mt = import_optional_dependency("sql_metadata") # Extract column names using sql_metadata columns = list(sql_mt.get_query_columns(str(data))) # Sanitize column names for i in range(len(columns)): - if columns[i].find('.') != -1: - columns[i] = columns[i][columns[i].find('.') + 1:] + if columns[i].find(".") != -1: + columns[i] = columns[i][columns[i].find(".") + 1 :] data = list(data) if len(data) > 0: From c6fc88b8e6500cd128004f8cb722c452bfed5691 Mon Sep 17 00:00:00 2001 From: jnchngc Date: Fri, 23 Apr 2021 17:46:08 -0400 Subject: [PATCH 5/5] style issue --- pandas/core/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 6f4c4c26095b3..b2f6893eb0752 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -674,7 +674,7 @@ def __init__( sql_mt = import_optional_dependency("sql_metadata") # Extract column names using sql_metadata columns = list(sql_mt.get_query_columns(str(data))) - # Sanitize column names + # Sanitize column names and remove everything before . character for i in range(len(columns)): if columns[i].find(".") != -1: columns[i] = columns[i][columns[i].find(".") + 1 :]