Skip to content

Commit aa360f9

Browse files
authored
modified TableDescription formatting (#616)
1 parent 4b938ed commit aa360f9

File tree

4 files changed

+409
-70
lines changed

4 files changed

+409
-70
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# CHANGELOG
22

33
## 0.7.10dev
4+
* [Feature] Modified `TableDescription` to add styling, generate messages and format the calculated outputs (#459)
45
* [Feature] Support flexible spacing `myvar=<<` operator ([#525](https://github.com/ploomber/jupysql/issues/525))
5-
66
* [Doc] Modified integrations content to ensure they're all consistent (#523)
77
* [Doc] Document --persist-replace in API section (#539)
88
* [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631)

src/sql/inspect.py

Lines changed: 207 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import math
99
from sql import util
1010
from IPython.core.display import HTML
11+
import uuid
1112

1213

1314
def _get_inspector(conn):
@@ -77,6 +78,94 @@ def _get_row_with_most_keys(rows):
7778
return list(rows[max_idx])
7879

7980

81+
def _is_numeric(value):
    """
    Tell whether a single value can be interpreted as a number.

    Parameters
    ----------
    value : object
        The value to test.

    Returns
    -------
    bool
        True when ``float(value)`` succeeds. False for ``bool`` instances and
        for anything ``float`` rejects with ``TypeError`` or ``ValueError``.
    """
    # bool is rejected up front: float(True) would otherwise succeed
    if isinstance(value, bool):
        return False

    try:
        float(value)
    except (TypeError, ValueError):
        return False
    else:
        return True
90+
91+
92+
def _is_numeric_as_str(column, value):
    """
    Tell whether a value is numerical data stored as a ``str``.

    Parameters
    ----------
    column : str
        Name of the column the value was read from. The check does not use
        it; the parameter is kept so existing callers keep working.
    value : object
        The value to test.

    Returns
    -------
    bool
        True only when ``value`` is a ``str`` that ``_is_numeric`` accepts.
    """
    # _is_numeric handles conversion failures itself, so no try/except is
    # needed here and the result is always a bool (never an implicit None).
    return isinstance(value, str) and _is_numeric(value)
100+
101+
102+
def _generate_column_styles(
    column_indices, unique_id, background_color="#FFFFCC", text_color="black"
):
    """
    Build a ``<style>`` tag that recolors the given table columns.

    Used to highlight the columns that have a data-type mismatch.

    Parameters
    ----------
    column_indices : list of int
        Indices of the columns to highlight. Each one produces a rule for
        ``td:nth-child(index + 1)``.
    unique_id : str
        Unique ID of the current table; rules are scoped to the element with
        id ``profile-table-<unique_id>``.
    background_color : str, optional
        Background color applied to the selected cells.
    text_color : str, optional
        Text color applied to the selected cells.

    Returns
    -------
    str
        A ``<style>...</style>`` string holding one CSS rule per index, or an
        empty ``<style></style>`` tag when ``column_indices`` is empty.
    """
    # Collect the rules and join once instead of rebuilding the accumulated
    # string on every iteration.
    rules = [
        (
            f"\n#profile-table-{unique_id} td:nth-child({index + 1}) {{\n"
            f"    background-color: {background_color};\n"
            f"    color: {text_color};\n"
            "}\n"
        )
        for index in column_indices
    ]
    styles = "".join(rules)
    return f"<style>{styles}</style>"
129+
130+
131+
def _generate_message(column_indices, columns):
    """
    Generate a message listing all columns with a datatype mismatch.

    Parameters
    ----------
    column_indices : list of int
        1-based positions into ``columns``; position ``c`` is looked up as
        ``columns[c - 1]``.
    columns : list
        Column names of the table.

    Returns
    -------
    str
        HTML snippet naming each selected column inside ``<code>`` tags,
        followed by the explanation of the mismatch.
    """
    # Separate the names with ", " so that several columns do not run together
    # as "<code>a</code><code>b</code>".
    names = ", ".join(f"<code>{columns[c - 1]}</code>" for c in column_indices)
    return (
        f"Columns {names} have a datatype mismatch -> "
        "numeric values stored as a string."
        " <br> Cannot calculate mean/min/max/std/percentiles"
    )
142+
143+
144+
def _assign_column_specific_stats(col_stats, is_numeric):
    """
    Blank out the statistics that do not apply to a column's kind.

    Parameters
    ----------
    col_stats : dict
        Statistics collected for one column. Modified in place.
    is_numeric : bool
        Whether the column is numeric.

    Returns
    -------
    dict
        The same ``col_stats`` object. For a numeric column ``top`` and
        ``freq`` are set to NaN; otherwise ``mean``, ``min``, ``max``,
        ``std``, ``25%``, ``50%`` and ``75%`` are set to NaN.
    """
    # Pick the group of statistics that is meaningless for this column kind.
    if is_numeric:
        not_applicable = ("top", "freq")
    else:
        not_applicable = ("mean", "min", "max", "std", "25%", "50%", "75%")

    for stat_name in not_applicable:
        col_stats[stat_name] = math.nan

    return col_stats
167+
168+
80169
@modify_exceptions
81170
class Columns(DatabaseInspection):
82171
"""
@@ -108,27 +197,36 @@ def __init__(self, name, schema, conn=None) -> None:
108197
@modify_exceptions
109198
class TableDescription(DatabaseInspection):
110199
"""
111-
Generates descriptive statistics.
200+
Generates descriptive statistics.
201+
202+
--------------------------------------
203+
Descriptive statistics are:
204+
205+
Count - Number of all not None values
112206
113-
Descriptive statistics are:
207+
Mean - Mean of the values
114208
115-
Count - Number of all not None values
209+
Max - Maximum of the values in the object.
116210
117-
Mean - Mean of the values
211+
Min - Minimum of the values in the object.
118212
119-
Max - Maximum of the values in the object.
213+
STD - Standard deviation of the observations
120214
121-
Min - Minimum of the values in the object.
215+
25th, 50th and 75th percentiles
122216
123-
STD - Standard deviation of the observations
217+
Unique - Number of not None unique values
124218
125-
25h, 50h and 75h percentiles
219+
Top - The most frequent value
126220
127-
Unique - Number of not None unique values
221+
Freq - Frequency of the top value
128222
129-
Top - The most frequent value
223+
------------------------------------------
224+
The following statistics are calculated for each column type:
130225
131-
Freq - Frequency of the top value
226+
Categorical columns - [Count, Unique, Top, Freq]
227+
228+
Numerical columns - [Count, Unique, Mean, Max, Min,
229+
STD, 25th, 50th and 75th percentiles]
132230
133231
"""
134232

@@ -141,18 +239,35 @@ def __init__(self, table_name, schema=None) -> None:
141239
columns_query_result = sql.run.raw_run(
142240
Connection.current, f"SELECT * FROM {table_name} WHERE 1=0"
143241
)
144-
145242
if Connection.is_custom_connection():
146243
columns = [i[0] for i in columns_query_result.description]
147244
else:
148245
columns = columns_query_result.keys()
149246

150247
table_stats = dict({})
151248
columns_to_include_in_report = set()
249+
columns_with_styles = []
250+
message_check = False
152251

153-
for column in columns:
252+
for i, column in enumerate(columns):
154253
table_stats[column] = dict()
155254

255+
# check the datatype of a column
256+
try:
257+
result = sql.run.raw_run(
258+
Connection.current, f"""SELECT {column} FROM {table_name} LIMIT 1"""
259+
).fetchone()
260+
261+
value = result[0]
262+
is_numeric = isinstance(value, (int, float)) or (
263+
isinstance(value, str) and _is_numeric(value)
264+
)
265+
except ValueError:
266+
is_numeric = True
267+
268+
if _is_numeric_as_str(column, value):
269+
columns_with_styles.append(i + 1)
270+
message_check = True
156271
# Note: index is reserved word in sqlite
157272
try:
158273
result_col_freq_values = sql.run.raw_run(
@@ -183,10 +298,12 @@ def __init__(self, table_name, schema=None) -> None:
183298
""",
184299
).fetchall()
185300

186-
table_stats[column]["min"] = result_value_values[0][0]
187-
table_stats[column]["max"] = result_value_values[0][1]
301+
columns_to_include_in_report.update(["count", "min", "max"])
188302
table_stats[column]["count"] = result_value_values[0][2]
189303

304+
table_stats[column]["min"] = round(result_value_values[0][0], 4)
305+
table_stats[column]["max"] = round(result_value_values[0][1], 4)
306+
190307
columns_to_include_in_report.update(["count", "min", "max"])
191308

192309
except Exception:
@@ -204,9 +321,7 @@ def __init__(self, table_name, schema=None) -> None:
204321
""",
205322
).fetchall()
206323
table_stats[column]["unique"] = result_value_values[0][0]
207-
208324
columns_to_include_in_report.update(["unique"])
209-
210325
except Exception:
211326
pass
212327

@@ -220,8 +335,8 @@ def __init__(self, table_name, schema=None) -> None:
220335
""",
221336
).fetchall()
222337

223-
table_stats[column]["mean"] = float(results_avg[0][0])
224338
columns_to_include_in_report.update(["mean"])
339+
table_stats[column]["mean"] = format(float(results_avg[0][0]), ".4f")
225340

226341
except Exception:
227342
table_stats[column]["mean"] = math.nan
@@ -246,11 +361,10 @@ def __init__(self, table_name, schema=None) -> None:
246361
""",
247362
).fetchall()
248363

364+
columns_to_include_in_report.update(special_numeric_keys)
249365
for i, key in enumerate(special_numeric_keys):
250366
# r_key = f'key_{key.replace("%", "")}'
251-
table_stats[column][key] = float(result[0][i])
252-
253-
columns_to_include_in_report.update(special_numeric_keys)
367+
table_stats[column][key] = format(float(result[0][i]), ".4f")
254368

255369
except TypeError:
256370
# for non numeric values
@@ -268,39 +382,97 @@ def __init__(self, table_name, schema=None) -> None:
268382
# We ignore the cell stats for such case.
269383
pass
270384

385+
table_stats[column] = _assign_column_specific_stats(
386+
table_stats[column], is_numeric
387+
)
388+
271389
self._table = PrettyTable()
272390
self._table.field_names = [" "] + list(table_stats.keys())
273391

274-
rows = list(columns_to_include_in_report)
275-
rows.sort(reverse=True)
276-
for row in rows:
277-
values = [row]
278-
for column in table_stats:
279-
if row in table_stats[column]:
280-
value = table_stats[column][row]
281-
else:
282-
value = ""
283-
value = util.convert_to_scientific(value)
284-
values.append(value)
392+
custom_order = [
393+
"count",
394+
"unique",
395+
"top",
396+
"freq",
397+
"mean",
398+
"std",
399+
"min",
400+
"25%",
401+
"50%",
402+
"75%",
403+
"max",
404+
]
405+
406+
for row in custom_order:
407+
if row.lower() in [r.lower() for r in columns_to_include_in_report]:
408+
values = [row]
409+
for column in table_stats:
410+
if row in table_stats[column]:
411+
value = table_stats[column][row]
412+
else:
413+
value = ""
414+
# value = util.convert_to_scientific(value)
415+
values.append(value)
416+
417+
self._table.add_row(values)
418+
419+
unique_id = str(uuid.uuid4()).replace("-", "")
420+
column_styles = _generate_column_styles(columns_with_styles, unique_id)
421+
422+
if message_check:
423+
message_content = _generate_message(columns_with_styles, list(columns))
424+
warning_background = "#FFFFCC"
425+
warning_title = "Warning: "
426+
else:
427+
message_content = ""
428+
warning_background = "white"
429+
warning_title = ""
430+
431+
database = Connection.current.url
432+
db_driver = Connection.current._get_curr_sqlalchemy_connection_info()["driver"]
433+
if "duckdb" in database:
434+
db_message = ""
435+
else:
436+
db_message = f"""Following statistics are not available in
437+
{db_driver}: STD, 25%, 50%, 75%"""
438+
439+
db_html = (
440+
f"<div style='position: sticky; left: 0; padding: 10px; "
441+
f"font-size: 12px; color: #FFA500'>"
442+
f"<strong></strong> {db_message}"
443+
"</div>"
444+
)
285445

286-
self._table.add_row(values)
446+
message_html = (
447+
f"<div style='position: sticky; left: 0; padding: 10px; "
448+
f"font-size: 12px; color: black; background-color: {warning_background};'>"
449+
f"<strong>{warning_title}</strong> {message_content}"
450+
"</div>"
451+
)
287452

288453
# Inject css to html to make first column sticky
289454
sticky_column_css = """<style>
290455
#profile-table td:first-child {
291456
position: sticky;
292457
left: 0;
293458
background-color: var(--jp-cell-editor-background);
459+
font-weight: bold;
294460
}
295461
#profile-table thead tr th:first-child {
296462
position: sticky;
297463
left: 0;
298464
background-color: var(--jp-cell-editor-background);
465+
font-weight: bold; /* Adding bold text */
299466
}
300467
</style>"""
301468
self._table_html = HTML(
302-
sticky_column_css
303-
+ self._table.get_html_string(attributes={"id": "profile-table"})
469+
db_html
470+
+ sticky_column_css
471+
+ column_styles
472+
+ self._table.get_html_string(
473+
attributes={"id": f"profile-table-{unique_id}"}
474+
)
475+
+ message_html
304476
).__html__()
305477

306478
self._table_txt = self._table.get_string()

0 commit comments

Comments
 (0)