@@ -312,32 +312,56 @@ def __getitem__(self, indexer: Any):
312
312
return self ._replace (self .index [indexer ])
313
313
314
314
315
- def _create_variables_from_multiindex (index , dim , level_meta = None ):
316
- from .variable import IndexVariable
315
+ def _check_dim_compat (variables : Mapping [Any , "Variable" ]) -> Hashable :
316
+ """Check that all multi-index variable candidates share the same (single) dimension
317
+ and return the name of that dimension.
318
+
319
+ """
320
+ if any ([var .ndim != 1 for var in variables .values ()]):
321
+ raise ValueError ("PandasMultiIndex only accepts 1-dimensional variables" )
317
322
318
- if level_meta is None :
319
- level_meta = {}
323
+ dims = set ([var .dims for var in variables .values ()])
320
324
321
- variables = {}
325
+ if len (dims ) > 1 :
326
+ raise ValueError (
327
+ "unmatched dimensions for variables "
328
+ + ", " .join ([f"{ k !r} { v .dims } " for k , v in variables .items ()])
329
+ )
322
330
323
- dim_coord_adapter = PandasMultiIndexingAdapter (index )
324
- variables [dim ] = IndexVariable (dim , dim_coord_adapter , fastpath = True )
331
+ return next (iter (dims ))[0 ]
325
332
326
- for level in index .names :
327
- meta = level_meta .get (level , {})
333
+
334
+ def _create_variables_from_multiindex (index , dim , var_meta = None ):
335
+ from .variable import IndexVariable
336
+
337
+ if var_meta is None :
338
+ var_meta = {}
339
+
340
+ def create_variable (name ):
341
+ if name == dim :
342
+ level = None
343
+ else :
344
+ level = name
345
+ meta = var_meta .get (name , {})
328
346
data = PandasMultiIndexingAdapter (index , dtype = meta .get ("dtype" ), level = level )
329
- variables [ level ] = IndexVariable (
347
+ return IndexVariable (
330
348
dim ,
331
349
data ,
332
350
attrs = meta .get ("attrs" ),
333
351
encoding = meta .get ("encoding" ),
334
352
fastpath = True ,
335
353
)
336
354
355
+ variables = {}
356
+ variables [dim ] = create_variable (dim )
357
+ for level in index .names :
358
+ variables [level ] = create_variable (level )
359
+
337
360
return variables
338
361
339
362
340
363
class PandasMultiIndex (PandasIndex ):
364
+ """Wrap a pandas.MultiIndex as an xarray compatible index."""
341
365
342
366
level_coords_dtype : Dict [str , Any ]
343
367
@@ -358,51 +382,101 @@ def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiInde
358
382
return type (self )(index , dim , level_coords_dtype )
359
383
360
384
@classmethod
361
- def from_variables (cls , variables : Mapping [Any , "Variable" ]):
362
- if any ([var .ndim != 1 for var in variables .values ()]):
363
- raise ValueError ("PandasMultiIndex only accepts 1-dimensional variables" )
364
-
365
- dims = set ([var .dims for var in variables .values ()])
366
- if len (dims ) != 1 :
367
- raise ValueError (
368
- "unmatched dimensions for variables "
369
- + "," .join ([str (k ) for k in variables ])
370
- )
385
+ def from_variables (
386
+ cls , variables : Mapping [Any , "Variable" ]
387
+ ) -> Tuple ["PandasMultiIndex" , IndexVars ]:
388
+ dim = _check_dim_compat (variables )
371
389
372
- dim = next (iter (dims ))[0 ]
373
390
index = pd .MultiIndex .from_arrays (
374
391
[var .values for var in variables .values ()], names = variables .keys ()
375
392
)
376
393
level_coords_dtype = {name : var .dtype for name , var in variables .items ()}
377
394
obj = cls (index , dim , level_coords_dtype = level_coords_dtype )
378
395
379
- level_meta = {
396
+ var_meta = {
380
397
name : {"dtype" : var .dtype , "attrs" : var .attrs , "encoding" : var .encoding }
381
398
for name , var in variables .items ()
382
399
}
383
- index_vars = _create_variables_from_multiindex (
384
- index , dim , level_meta = level_meta
385
- )
400
+ index_vars = _create_variables_from_multiindex (index , dim , var_meta = var_meta )
401
+
402
+ return obj , index_vars
403
+
404
+ @classmethod
405
+ def from_variables_maybe_expand (
406
+ cls ,
407
+ dim : Hashable ,
408
+ current_variables : Mapping [Any , "Variable" ],
409
+ variables : Mapping [Any , "Variable" ],
410
+ ) -> Tuple ["PandasMultiIndex" , IndexVars ]:
411
+ """Create a new multi-index maybe by expanding an existing one with
412
+ new variables as index levels.
413
+
414
+ the index might be created along a new dimension.
415
+ """
416
+ names : List [Hashable ] = []
417
+ codes : List [List [int ]] = []
418
+ levels : List [List [int ]] = []
419
+ var_meta : Dict [str , Dict ] = {}
420
+ level_coords_dtype : Dict [Hashable , Any ] = {}
421
+
422
+ _check_dim_compat ({** current_variables , ** variables })
423
+
424
+ def add_level_var (name , var ):
425
+ var_meta [name ] = {
426
+ "dtype" : var .dtype ,
427
+ "attrs" : var .attrs ,
428
+ "encoding" : var .encoding ,
429
+ }
430
+ level_coords_dtype [name ] = var .dtype
431
+
432
+ if len (current_variables ) > 1 :
433
+ current_index : pd .MultiIndex = next (
434
+ iter (current_variables .values ())
435
+ )._data .array
436
+ names .extend (current_index .names )
437
+ codes .extend (current_index .codes )
438
+ levels .extend (current_index .levels )
439
+ for name in current_index .names :
440
+ add_level_var (name , current_variables [name ])
441
+
442
+ elif len (current_variables ) == 1 :
443
+ # one 1D variable (no multi-index): convert it to an index level
444
+ var = next (iter (current_variables .values ()))
445
+ new_var_name = f"{ dim } _level_0"
446
+ names .append (new_var_name )
447
+ cat = pd .Categorical (var .values , ordered = True )
448
+ codes .append (cat .codes )
449
+ levels .append (cat .categories )
450
+ add_level_var (new_var_name , var )
451
+
452
+ for name , var in variables .items ():
453
+ names .append (name )
454
+ cat = pd .Categorical (var .values , ordered = True )
455
+ codes .append (cat .codes )
456
+ levels .append (cat .categories )
457
+ add_level_var (name , var )
458
+
459
+ index = pd .MultiIndex (levels , codes , names = names )
460
+ obj = cls (index , dim , level_coords_dtype = level_coords_dtype )
461
+ index_vars = _create_variables_from_multiindex (index , dim , var_meta = var_meta )
386
462
387
463
return obj , index_vars
388
464
389
465
@classmethod
390
466
def from_pandas_index (
391
467
cls , index : pd .MultiIndex , dim : Hashable
392
468
) -> Tuple ["PandasMultiIndex" , IndexVars ]:
393
- level_meta = {}
469
+ var_meta = {}
394
470
for i , idx in enumerate (index .levels ):
395
471
name = idx .name or f"{ dim } _level_{ i } "
396
472
if name == dim :
397
473
raise ValueError (
398
474
f"conflicting multi-index level name { name !r} with dimension { dim !r} "
399
475
)
400
- level_meta [name ] = {"dtype" : idx .dtype }
476
+ var_meta [name ] = {"dtype" : idx .dtype }
401
477
402
- index = index .rename (level_meta .keys ())
403
- index_vars = _create_variables_from_multiindex (
404
- index , dim , level_meta = level_meta
405
- )
478
+ index = index .rename (var_meta .keys ())
479
+ index_vars = _create_variables_from_multiindex (index , dim , var_meta = var_meta )
406
480
return cls (index , dim ), index_vars
407
481
408
482
def query (self , labels , method = None , tolerance = None ) -> QueryResult :
0 commit comments