11
11
from ..access_keys import encode_access_key
12
12
from .._vendor .pydantic import BaseModel
13
13
from ..data_explorer import COMPARE_OPS , DataExplorerService
14
- from ..data_explorer_comm import ColumnSchema , ColumnSortKey , FilterResult
14
+ from ..data_explorer_comm import (
15
+ ColumnFilter ,
16
+ ColumnSchema ,
17
+ ColumnSortKey ,
18
+ FilterResult ,
19
+ )
15
20
16
21
from .conftest import DummyComm , PositronShell
17
22
from .utils import json_rpc_notification , json_rpc_request , json_rpc_response
@@ -58,6 +63,17 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
58
63
# Test basic service functionality
59
64
60
65
66
+ class MyData :
67
+ def __init__ (self , value ):
68
+ self .value = value
69
+
70
+ def __str__ (self ):
71
+ return str (self .value )
72
+
73
+ def __repr__ (self ):
74
+ return repr (self .value )
75
+
76
+
61
77
SIMPLE_PANDAS_DF = pd .DataFrame (
62
78
{
63
79
"a" : [1 , 2 , 3 , 4 , 5 ],
@@ -73,6 +89,7 @@ def get_last_message(de_service: DataExplorerService, comm_id: str):
73
89
"2024-01-05 00:00:00" ,
74
90
]
75
91
),
92
+ "f" : [None , MyData (5 ), MyData (- 1 ), None , None ],
76
93
}
77
94
)
78
95
@@ -216,6 +233,7 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
216
233
217
234
# Do a simple update and make sure that sort keys are preserved
218
235
x_comm_id = list (de_service .path_to_comm_ids [path_x ])[0 ]
236
+ x_sort_keys = [{"column_index" : 0 , "ascending" : True }]
219
237
msg = json_rpc_request (
220
238
"set_sort_columns" ,
221
239
params = {"sort_keys" : [{"column_index" : 0 , "ascending" : True }]},
@@ -227,9 +245,15 @@ def _check_update_variable(name, update_type="schema", discard_state=True):
227
245
_check_update_variable ("x" , update_type = "data" )
228
246
229
247
tv = de_service .table_views [x_comm_id ]
230
- assert tv .sort_keys == [ColumnSortKey (column_index = 0 , ascending = True ) ]
248
+ assert tv .sort_keys == [ColumnSortKey (** k ) for k in x_sort_keys ]
231
249
assert tv ._need_recompute
232
250
251
+ pf = PandasFixture (de_service )
252
+ new_state = pf .get_state ("x" )
253
+ assert new_state ["table_shape" ]["num_rows" ] == 5
254
+ assert new_state ["table_shape" ]["num_columns" ] == 1
255
+ assert new_state ["sort_keys" ] == [ColumnSortKey (** k ) for k in x_sort_keys ]
256
+
233
257
# Execute code that triggers an update event for big_x because it's large
234
258
shell .run_cell ("print('hello world')" )
235
259
_check_update_variable ("big_x" , update_type = "data" )
@@ -281,17 +305,30 @@ def test_shutdown(de_service: DataExplorerService):
281
305
class PandasFixture :
282
306
def __init__ (self , de_service : DataExplorerService ):
283
307
self .de_service = de_service
284
- self ._table_ids = {}
285
308
286
309
self .register_table ("simple" , SIMPLE_PANDAS_DF )
287
310
288
311
def register_table (self , table_name : str , table ):
289
312
comm_id = guid ()
290
- self .de_service .register_table (table , table_name , comm_id = comm_id )
291
- self ._table_ids [table_name ] = comm_id
313
+
314
+ paths = self .de_service .get_paths_for_variable (table_name )
315
+ for path in paths :
316
+ for old_comm_id in list (self .de_service .path_to_comm_ids [path ]):
317
+ self .de_service ._close_explorer (old_comm_id )
318
+
319
+ self .de_service .register_table (
320
+ table ,
321
+ table_name ,
322
+ comm_id = comm_id ,
323
+ variable_path = [encode_access_key (table_name )],
324
+ )
292
325
293
326
def do_json_rpc (self , table_name , method , ** params ):
294
- comm_id = self ._table_ids [table_name ]
327
+ paths = self .de_service .get_paths_for_variable (table_name )
328
+ assert len (paths ) == 1
329
+
330
+ comm_id = list (self .de_service .path_to_comm_ids [paths [0 ]])[0 ]
331
+
295
332
request = json_rpc_request (
296
333
method ,
297
334
params = params ,
@@ -313,6 +350,9 @@ def get_schema(self, table_name, start_index, num_columns):
313
350
num_columns = num_columns ,
314
351
)
315
352
353
+ def get_state (self , table_name ):
354
+ return self .do_json_rpc (table_name , "get_state" )
355
+
316
356
def get_data_values (self , table_name , ** params ):
317
357
return self .do_json_rpc (table_name , "get_data_values" , ** params )
318
358
@@ -372,10 +412,26 @@ def _wrap_json(model: Type[BaseModel], data: JsonRecords):
372
412
return [model (** d ).dict () for d in data ]
373
413
374
414
415
+ def test_pandas_get_state (pandas_fixture : PandasFixture ):
416
+ result = pandas_fixture .get_state ("simple" )
417
+ assert result ["table_shape" ]["num_rows" ] == 5
418
+ assert result ["table_shape" ]["num_columns" ] == 6
419
+
420
+ sort_keys = [
421
+ {"column_index" : 0 , "ascending" : True },
422
+ {"column_index" : 1 , "ascending" : False },
423
+ ]
424
+ filters = [_compare_filter (0 , ">" , 0 ), _compare_filter (0 , "<" , 5 )]
425
+ pandas_fixture .set_sort_columns ("simple" , sort_keys = sort_keys )
426
+ pandas_fixture .set_column_filters ("simple" , filters = filters )
427
+
428
+ result = pandas_fixture .get_state ("simple" )
429
+ assert result ["sort_keys" ] == sort_keys
430
+ assert result ["filters" ] == [ColumnFilter (** f ) for f in filters ]
431
+
432
+
375
433
def test_pandas_get_schema (pandas_fixture : PandasFixture ):
376
434
result = pandas_fixture .get_schema ("simple" , 0 , 100 )
377
- assert result ["num_rows" ] == 5
378
- assert result ["total_num_columns" ] == 5
379
435
380
436
full_schema = [
381
437
{
@@ -403,20 +459,15 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
403
459
"type_name" : "datetime64[ns]" ,
404
460
"type_display" : "datetime" ,
405
461
},
462
+ {"column_name" : "f" , "type_name" : "mixed" , "type_display" : "unknown" },
406
463
]
407
464
408
465
assert result ["columns" ] == _wrap_json (ColumnSchema , full_schema )
409
466
410
467
result = pandas_fixture .get_schema ("simple" , 2 , 100 )
411
- assert result ["num_rows" ] == 5
412
- assert result ["total_num_columns" ] == 5
413
-
414
468
assert result ["columns" ] == _wrap_json (ColumnSchema , full_schema [2 :])
415
469
416
- result = pandas_fixture .get_schema ("simple" , 5 , 100 )
417
- assert result ["num_rows" ] == 5
418
- assert result ["total_num_columns" ] == 5
419
-
470
+ result = pandas_fixture .get_schema ("simple" , 6 , 100 )
420
471
assert result ["columns" ] == []
421
472
422
473
# Make a really big schema
@@ -426,13 +477,9 @@ def test_pandas_get_schema(pandas_fixture: PandasFixture):
426
477
pandas_fixture .register_table (bigger_name , bigger_df )
427
478
428
479
result = pandas_fixture .get_schema (bigger_name , 0 , 100 )
429
- assert result ["num_rows" ] == 5
430
- assert result ["total_num_columns" ] == 500
431
480
assert result ["columns" ] == _wrap_json (ColumnSchema , bigger_schema [:100 ])
432
481
433
482
result = pandas_fixture .get_schema (bigger_name , 10 , 10 )
434
- assert result ["num_rows" ] == 5
435
- assert result ["total_num_columns" ] == 500
436
483
assert result ["columns" ] == _wrap_json (ColumnSchema , bigger_schema [10 :20 ])
437
484
438
485
@@ -466,7 +513,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
466
513
"simple" ,
467
514
row_start_index = 0 ,
468
515
num_rows = 20 ,
469
- column_indices = list (range (5 )),
516
+ column_indices = list (range (6 )),
470
517
)
471
518
472
519
# TODO: pandas pads all values to fixed width, do we want to do
@@ -483,6 +530,7 @@ def test_pandas_get_data_values(pandas_fixture: PandasFixture):
483
530
"2024-01-04 00:00:00" ,
484
531
"2024-01-05 00:00:00" ,
485
532
],
533
+ ["None" , "5" , "-1" , "None" , "None" ],
486
534
]
487
535
488
536
assert _trim_whitespace (result ["columns" ]) == expected_columns
0 commit comments