Skip to content

Commit a64b0e1

Browse files
committed
Move table shape into data explorer get_state request and test pandas state requests (posit-dev/positron-python#393)
* Move table shape into get_state request and test pandas state requests
* Handle parameter change, cleaning
* fix pyright
* Rename shape again
* Dictify state result
1 parent 374bcc3 commit a64b0e1

File tree

3 files changed: +114 −43 lines changed (114 additions, 43 deletions)

extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer.py

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323

2424
from .access_keys import decode_access_key
2525
from .data_explorer_comm import (
26-
BackendState,
2726
ColumnFilter,
2827
ColumnFilterCompareOp,
2928
ColumnSchema,
@@ -42,6 +41,8 @@
4241
SetSortColumnsRequest,
4342
TableData,
4443
TableSchema,
44+
TableShape,
45+
TableState,
4546
)
4647
from .positron_comm import CommMessage, PositronComm
4748
from .third_party import pd_
@@ -128,7 +129,7 @@ def get_column_profile(self, request: GetColumnProfileRequest):
128129
return self._get_column_profile(request.params.profile_type, request.params.column_index)
129130

130131
def get_state(self, request: GetStateRequest):
131-
return self._get_state()
132+
return self._get_state().dict()
132133

133134
def _get_schema(self, column_start: int, num_columns: int) -> TableSchema:
134135
raise NotImplementedError
@@ -154,7 +155,7 @@ def _get_column_profile(
154155
) -> None:
155156
raise NotImplementedError
156157

157-
def _get_state(self) -> BackendState:
158+
def _get_state(self) -> TableState:
158159
raise NotImplementedError
159160

160161

@@ -185,6 +186,7 @@ class PandasView(DataExplorerTableView):
185186
"float64": "number",
186187
"mixed-integer": "number",
187188
"mixed-integer-float": "number",
189+
"mixed": "unknown",
188190
"decimal": "number",
189191
"complex": "number",
190192
"categorical": "categorical",
@@ -291,11 +293,7 @@ def _get_schema(self, column_start: int, num_columns: int) -> TableSchema:
291293
)
292294
column_schemas.append(col_schema)
293295

294-
return TableSchema(
295-
columns=column_schemas,
296-
num_rows=self.table.shape[0],
297-
total_num_columns=self.table.shape[1],
298-
)
296+
return TableSchema(columns=column_schemas)
299297

300298
def _get_data_values(
301299
self, row_start: int, num_rows: int, column_indices: Sequence[int]
@@ -420,8 +418,12 @@ def _get_column_profile(
420418
) -> None:
421419
pass
422420

423-
def _get_state(self) -> BackendState:
424-
return BackendState(filters=self.filters, sort_keys=self.sort_keys)
421+
def _get_state(self) -> TableState:
422+
return TableState(
423+
table_shape=TableShape(num_rows=self.table.shape[0], num_columns=self.table.shape[1]),
424+
filters=self.filters,
425+
sort_keys=self.sort_keys,
426+
)
425427

426428

427429
COMPARE_OPS = {
@@ -503,7 +505,13 @@ def shutdown(self) -> None:
503505
for comm_id in list(self.comms.keys()):
504506
self._close_explorer(comm_id)
505507

506-
def register_table(self, table, title, variable_path=None, comm_id=None):
508+
def register_table(
509+
self,
510+
table,
511+
title,
512+
variable_path: Optional[List[str]] = None,
513+
comm_id=None,
514+
):
507515
"""
508516
Set up a new comm and data explorer table query wrapper to
509517
handle requests and manage state.
@@ -552,6 +560,9 @@ def close_callback(msg):
552560
base_comm.on_close(close_callback)
553561

554562
if variable_path is not None:
563+
if not isinstance(variable_path, list):
564+
raise ValueError(variable_path)
565+
555566
key = tuple(variable_path)
556567
self.comm_id_to_path[comm_id] = key
557568

extensions/positron-python/pythonFiles/positron/positron_ipykernel/data_explorer_comm.py

Lines changed: 24 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -113,14 +113,6 @@ class TableSchema(BaseModel):
113113
description="Schema for each column in the table",
114114
)
115115

116-
num_rows: int = Field(
117-
description="Numbers of rows in the unfiltered dataset",
118-
)
119-
120-
total_num_columns: int = Field(
121-
description="Total number of columns in the unfiltered dataset",
122-
)
123-
124116

125117
class TableData(BaseModel):
126118
"""
@@ -211,11 +203,15 @@ class FreqtableCounts(BaseModel):
211203
)
212204

213205

214-
class BackendState(BaseModel):
206+
class TableState(BaseModel):
215207
"""
216-
The current backend state
208+
The current backend table state
217209
"""
218210

211+
table_shape: TableShape = Field(
212+
description="Provides number of rows and columns in table",
213+
)
214+
219215
filters: List[ColumnFilter] = Field(
220216
description="The set of currently applied filters",
221217
)
@@ -225,6 +221,20 @@ class BackendState(BaseModel):
225221
)
226222

227223

224+
class TableShape(BaseModel):
225+
"""
226+
Provides number of rows and columns in table
227+
"""
228+
229+
num_rows: int = Field(
230+
description="Numbers of rows in the unfiltered dataset",
231+
)
232+
233+
num_columns: int = Field(
234+
description="Number of columns in the unfiltered dataset",
235+
)
236+
237+
228238
class ColumnSchema(BaseModel):
229239
"""
230240
Schema for a column in a table
@@ -548,7 +558,7 @@ class GetColumnProfileRequest(BaseModel):
548558

549559
class GetStateRequest(BaseModel):
550560
"""
551-
Request the current backend state (applied filters and sort columns)
561+
Request the current table state (applied filters and sort columns)
552562
"""
553563

554564
method: Literal[DataExplorerBackendRequest.GetState] = Field(
@@ -606,7 +616,9 @@ class SchemaUpdateParams(BaseModel):
606616

607617
FreqtableCounts.update_forward_refs()
608618

609-
BackendState.update_forward_refs()
619+
TableState.update_forward_refs()
620+
621+
TableShape.update_forward_refs()
610622

611623
ColumnSchema.update_forward_refs()
612624

extensions/positron-python/pythonFiles/positron/positron_ipykernel/tests/test_data_explorer.py

Lines changed: 68 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,12 @@
1111
from ..access_keys import encode_access_key
1212
from .._vendor.pydantic import BaseModel
1313
from ..data_explorer import COMPARE_OPS, DataExplorerService
14-
from ..data_explorer_comm import ColumnSchema, ColumnSortKey, FilterResult
14+
from ..data_explorer_comm import (
15+
ColumnFilter,
16+
ColumnSchema,
17+
ColumnSortKey,
18+
FilterResult,
19+
)
1520

1621
from .conftest import DummyComm, PositronShell
1722
from .utils import json_rpc_notification, json_rpc_request, json_rpc_response
@@ -58,6 +63,17 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
5863
# Test basic service functionality
5964

6065

66+
class MyData:
67+
def __init__(self, value):
68+
self.value = value
69+
70+
def __str__(self):
71+
return str(self.value)
72+
73+
def __repr__(self):
74+
return repr(self.value)
75+
76+
6177
SIMPLE_PANDAS_DF = pd.DataFrame(
6278
{
6379
"a": [1, 2, 3, 4, 5],
@@ -73,6 +89,7 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
7389
"2024-01-05 00:00:00",
7490
]
7591
),
92+
"f": [None, MyData(5), MyData(-1), None, None],
7693
}
7794
)
7895

@@ -216,6 +233,7 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
216233

217234
# Do a simple update and make sure that sort keys are preserved
218235
x_comm_id = list(de_service.path_to_comm_ids[path_x])[0]
236+
x_sort_keys = [{"column_index": 0, "ascending": True}]
219237
msg = json_rpc_request(
220238
"set_sort_columns",
221239
params={"sort_keys": [{"column_index": 0, "ascending": True}]},
@@ -227,9 +245,15 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
227245
_check_update_variable("x", update_type="data")
228246

229247
tv = de_service.table_views[x_comm_id]
230-
assert tv.sort_keys == [ColumnSortKey(column_index=0, ascending=True)]
248+
assert tv.sort_keys == [ColumnSortKey(**k) for k in x_sort_keys]
231249
assert tv._need_recompute
232250

251+
pf = PandasFixture(de_service)
252+
new_state = pf.get_state("x")
253+
assert new_state["table_shape"]["num_rows"] == 5
254+
assert new_state["table_shape"]["num_columns"] == 1
255+
assert new_state["sort_keys"] == [ColumnSortKey(**k) for k in x_sort_keys]
256+
233257
# Execute code that triggers an update event for big_x because it's large
234258
shell.run_cell("print('hello world')")
235259
_check_update_variable("big_x", update_type="data")
@@ -281,17 +305,30 @@ def test_shutdown(de_service: DataExplorerService):
281305
class PandasFixture:
282306
def __init__(self, de_service: DataExplorerService):
283307
self.de_service = de_service
284-
self._table_ids = {}
285308

286309
self.register_table("simple", SIMPLE_PANDAS_DF)
287310

288311
def register_table(self, table_name: str, table):
289312
comm_id = guid()
290-
self.de_service.register_table(table, table_name, comm_id=comm_id)
291-
self._table_ids[table_name] = comm_id
313+
314+
paths = self.de_service.get_paths_for_variable(table_name)
315+
for path in paths:
316+
for old_comm_id in list(self.de_service.path_to_comm_ids[path]):
317+
self.de_service._close_explorer(old_comm_id)
318+
319+
self.de_service.register_table(
320+
table,
321+
table_name,
322+
comm_id=comm_id,
323+
variable_path=[encode_access_key(table_name)],
324+
)
292325

293326
def do_json_rpc(self, table_name, method, **params):
294-
comm_id = self._table_ids[table_name]
327+
paths = self.de_service.get_paths_for_variable(table_name)
328+
assert len(paths) == 1
329+
330+
comm_id = list(self.de_service.path_to_comm_ids[paths[0]])[0]
331+
295332
request = json_rpc_request(
296333
method,
297334
params=params,
@@ -313,6 +350,9 @@ def get_schema(self, table_name, start_index, num_columns):
313350
num_columns=num_columns,
314351
)
315352

353+
def get_state(self, table_name):
354+
return self.do_json_rpc(table_name, "get_state")
355+
316356
def get_data_values(self, table_name, **params):
317357
return self.do_json_rpc(table_name, "get_data_values", **params)
318358

@@ -372,10 +412,26 @@ def _wrap_json(model: Type[BaseModel], data: JsonRecords):
372412
return [model(**d).dict() for d in data]
373413

374414

415+
def test_pandas_get_state(pandas_fixture: PandasFixture):
416+
result = pandas_fixture.get_state("simple")
417+
assert result["table_shape"]["num_rows"] == 5
418+
assert result["table_shape"]["num_columns"] == 6
419+
420+
sort_keys = [
421+
{"column_index": 0, "ascending": True},
422+
{"column_index": 1, "ascending": False},
423+
]
424+
filters = [_compare_filter(0, ">", 0), _compare_filter(0, "<", 5)]
425+
pandas_fixture.set_sort_columns("simple", sort_keys=sort_keys)
426+
pandas_fixture.set_column_filters("simple", filters=filters)
427+
428+
result = pandas_fixture.get_state("simple")
429+
assert result["sort_keys"] == sort_keys
430+
assert result["filters"] == [ColumnFilter(**f) for f in filters]
431+
432+
375433
def test_pandas_get_schema(pandas_fixture: PandasFixture):
376434
result = pandas_fixture.get_schema("simple", 0, 100)
377-
assert result["num_rows"] == 5
378-
assert result["total_num_columns"] == 5
379435

380436
full_schema = [
381437
{
@@ -403,20 +459,15 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
403459
"type_name": "datetime64[ns]",
404460
"type_display": "datetime",
405461
},
462+
{"column_name": "f", "type_name": "mixed", "type_display": "unknown"},
406463
]
407464

408465
assert result["columns"] == _wrap_json(ColumnSchema, full_schema)
409466

410467
result = pandas_fixture.get_schema("simple", 2, 100)
411-
assert result["num_rows"] == 5
412-
assert result["total_num_columns"] == 5
413-
414468
assert result["columns"] == _wrap_json(ColumnSchema, full_schema[2:])
415469

416-
result = pandas_fixture.get_schema("simple", 5, 100)
417-
assert result["num_rows"] == 5
418-
assert result["total_num_columns"] == 5
419-
470+
result = pandas_fixture.get_schema("simple", 6, 100)
420471
assert result["columns"] == []
421472

422473
# Make a really big schema
@@ -426,13 +477,9 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
426477
pandas_fixture.register_table(bigger_name, bigger_df)
427478

428479
result = pandas_fixture.get_schema(bigger_name, 0, 100)
429-
assert result["num_rows"] == 5
430-
assert result["total_num_columns"] == 500
431480
assert result["columns"] == _wrap_json(ColumnSchema, bigger_schema[:100])
432481

433482
result = pandas_fixture.get_schema(bigger_name, 10, 10)
434-
assert result["num_rows"] == 5
435-
assert result["total_num_columns"] == 500
436483
assert result["columns"] == _wrap_json(ColumnSchema, bigger_schema[10:20])
437484

438485

@@ -466,7 +513,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
466513
"simple",
467514
row_start_index=0,
468515
num_rows=20,
469-
column_indices=list(range(5)),
516+
column_indices=list(range(6)),
470517
)
471518

472519
# TODO: pandas pads all values to fixed width, do we want to do
@@ -483,6 +530,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
483530
"2024-01-04 00:00:00",
484531
"2024-01-05 00:00:00",
485532
],
533+
["None", "5", "-1", "None", "None"],
486534
]
487535

488536
assert _trim_whitespace(result["columns"]) == expected_columns

0 commit comments

Comments
 (0)