```diff
@@ -16,7 +16,7 @@
 )
 from zarr.core.common import ChunkCoords, concurrent_map
 from zarr.core.config import config
-from zarr.core.indexing import SelectorTuple, is_scalar
+from zarr.core.indexing import SelectorTuple, is_scalar, is_total_slice
 from zarr.core.metadata.v2 import _default_fill_value
 from zarr.registry import register_pipeline
 
@@ -243,18 +243,18 @@ async def encode_partial_batch(
 
     async def read_batch(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
         if self.supports_partial_decode:
             chunk_array_batch = await self.decode_partial_batch(
                 [
                     (byte_getter, chunk_selection, chunk_spec)
-                    for byte_getter, chunk_spec, chunk_selection, *_ in batch_info
+                    for byte_getter, chunk_spec, chunk_selection, _ in batch_info
                 ]
             )
-            for chunk_array, (_, chunk_spec, _, out_selection, _) in zip(
+            for chunk_array, (_, chunk_spec, _, out_selection) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:
@@ -263,19 +263,22 @@ async def read_batch(
                     out[out_selection] = fill_value_or_default(chunk_spec)
         else:
             chunk_bytes_batch = await concurrent_map(
-                [(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
+                [
+                    (byte_getter, array_spec.prototype)
+                    for byte_getter, array_spec, _, _ in batch_info
+                ],
                 lambda byte_getter, prototype: byte_getter.get(prototype),
                 config.get("async.concurrency"),
             )
             chunk_array_batch = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, *_) in zip(
+                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )
-            for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
+            for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
                 chunk_array_batch, batch_info, strict=False
             ):
                 if chunk_array is not None:
@@ -293,10 +296,9 @@ def _merge_chunk_array(
         out_selection: SelectorTuple,
         chunk_spec: ArraySpec,
         chunk_selection: SelectorTuple,
-        is_complete_chunk: bool,
         drop_axes: tuple[int, ...],
     ) -> NDBuffer:
-        if is_complete_chunk and value.shape == chunk_spec.shape:
+        if is_total_slice(chunk_selection, chunk_spec.shape) and value.shape == chunk_spec.shape:
             return value
         if existing_chunk_array is None:
             chunk_array = chunk_spec.prototype.nd_buffer.create(
@@ -325,7 +327,7 @@ def _merge_chunk_array(
 
     async def write_batch(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -335,14 +337,14 @@ async def write_batch(
             await self.encode_partial_batch(
                 [
                     (byte_setter, value, chunk_selection, chunk_spec)
-                    for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
+                    for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
                 ],
             )
         else:
             await self.encode_partial_batch(
                 [
                     (byte_setter, value[out_selection], chunk_selection, chunk_spec)
-                    for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
+                    for byte_setter, chunk_spec, chunk_selection, out_selection in batch_info
                 ],
             )
 
@@ -359,43 +361,33 @@ async def _read_key(
             chunk_bytes_batch = await concurrent_map(
                 [
                     (
-                        None if is_complete_chunk else byte_setter,
+                        None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
                         chunk_spec.prototype,
                     )
-                    for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
+                    for byte_setter, chunk_spec, chunk_selection, _ in batch_info
                 ],
                 _read_key,
                 config.get("async.concurrency"),
             )
             chunk_array_decoded = await self.decode_batch(
                 [
                     (chunk_bytes, chunk_spec)
-                    for chunk_bytes, (_, chunk_spec, *_) in zip(
+                    for chunk_bytes, (_, chunk_spec, _, _) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
             )
 
             chunk_array_merged = [
                 self._merge_chunk_array(
-                    chunk_array,
-                    value,
-                    out_selection,
-                    chunk_spec,
-                    chunk_selection,
-                    is_complete_chunk,
-                    drop_axes,
+                    chunk_array, value, out_selection, chunk_spec, chunk_selection, drop_axes
+                )
+                for chunk_array, (_, chunk_spec, chunk_selection, out_selection) in zip(
+                    chunk_array_decoded, batch_info, strict=False
                 )
-                for chunk_array, (
-                    _,
-                    chunk_spec,
-                    chunk_selection,
-                    out_selection,
-                    is_complete_chunk,
-                ) in zip(chunk_array_decoded, batch_info, strict=False)
             ]
             chunk_array_batch: list[NDBuffer | None] = []
-            for chunk_array, (_, chunk_spec, *_) in zip(
+            for chunk_array, (_, chunk_spec, _, _) in zip(
                 chunk_array_merged, batch_info, strict=False
             ):
                 if chunk_array is None:
@@ -411,7 +403,7 @@ async def _read_key(
             chunk_bytes_batch = await self.encode_batch(
                 [
                     (chunk_array, chunk_spec)
-                    for chunk_array, (_, chunk_spec, *_) in zip(
+                    for chunk_array, (_, chunk_spec, _, _) in zip(
                         chunk_array_batch, batch_info, strict=False
                     )
                 ],
@@ -426,7 +418,7 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> None:
             await concurrent_map(
                 [
                     (byte_setter, chunk_bytes)
-                    for chunk_bytes, (byte_setter, *_) in zip(
+                    for chunk_bytes, (byte_setter, _, _, _) in zip(
                         chunk_bytes_batch, batch_info, strict=False
                     )
                 ],
@@ -454,7 +446,7 @@ async def encode(
 
     async def read(
         self,
-        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
@@ -469,7 +461,7 @@ async def read(
 
     async def write(
         self,
-        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
+        batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]],
        value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
     ) -> None:
```
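
The diff removes the precomputed `is_complete_chunk` flag from every `batch_info` tuple and instead derives chunk completeness where it is actually needed, by calling `is_total_slice(chunk_selection, chunk_spec.shape)` in `_merge_chunk_array` and in `write_batch`'s read step. The sketch below illustrates the semantics such a predicate plausibly has; it is a minimal stand-in for illustration, not `zarr.core.indexing.is_total_slice` itself, and its corner cases (Ellipsis handling, open-ended slices) are assumptions.

```python
# Minimal sketch of a "total slice" predicate -- an illustration of the
# semantics the diff relies on, not zarr's actual implementation.
# Assumption: a selection covers a whole chunk when every axis is addressed
# by a slice spanning [0, dim) with unit step, or by an Ellipsis.
from typing import Any


def is_total_slice(selection: Any, shape: tuple[int, ...]) -> bool:
    """Return True if `selection` addresses every element of a chunk of `shape`."""
    if selection == Ellipsis:
        return True
    if isinstance(selection, slice):
        selection = (selection,)
    if not isinstance(selection, tuple) or len(selection) != len(shape):
        return False
    return all(
        isinstance(s, slice)
        and (s.start is None or s.start == 0)
        and (s.stop is None or s.stop >= dim)
        and (s.step is None or s.step == 1)
        for s, dim in zip(selection, shape)
    )


# Whole chunk selected: write_batch can skip reading the existing chunk
# (the diff passes None instead of the byte_setter in exactly this case).
assert is_total_slice((slice(0, 4), slice(0, 4)), (4, 4))
# Partial selection: the chunk must be read, merged, and re-encoded.
assert not is_total_slice((slice(0, 2), slice(0, 4)), (4, 4))
```

The practical effect is a read-modify-write optimization: in `write_batch`, a chunk whose selection is total is handed to `_read_key` as `None`, so no existing bytes are fetched, and in `_merge_chunk_array` a full-shape value for a total selection is returned as-is without allocating a merge buffer. Trading the cached boolean for an on-demand check shortens the `batch_info` tuples at the cost of re-evaluating the predicate at each use site.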