@@ -333,6 +333,11 @@ def __init__(
333
333
object .__setattr__ (self , "index_codecs" , index_codecs_parsed )
334
334
object .__setattr__ (self , "index_location" , index_location_parsed )
335
335
336
+ # Use instance-local lru_cache to avoid memory leaks
337
+ object .__setattr__ (self , "_get_chunk_spec" , lru_cache ()(self ._get_chunk_spec ))
338
+ object .__setattr__ (self , "_get_index_chunk_spec" , lru_cache ()(self ._get_index_chunk_spec ))
339
+ object .__setattr__ (self , "_get_chunks_per_shard" , lru_cache ()(self ._get_chunks_per_shard ))
340
+
336
341
@classmethod
337
342
def from_dict (cls , data : Dict [str , JSON ]) -> Self :
338
343
_ , configuration_parsed = parse_named_configuration (data , "sharding_indexed" )
@@ -609,7 +614,6 @@ def _shard_index_size(self, chunks_per_shard: ChunkCoords) -> int:
609
614
16 * product (chunks_per_shard ), self ._get_index_chunk_spec (chunks_per_shard )
610
615
)
611
616
612
- @lru_cache
613
617
def _get_index_chunk_spec (self , chunks_per_shard : ChunkCoords ) -> ArraySpec :
614
618
return ArraySpec (
615
619
shape = chunks_per_shard + (2 ,),
@@ -618,7 +622,6 @@ def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec:
618
622
order = "C" , # Note: this is hard-coded for simplicity -- it is not surfaced into user code
619
623
)
620
624
621
- @lru_cache
622
625
def _get_chunk_spec (self , shard_spec : ArraySpec ) -> ArraySpec :
623
626
return ArraySpec (
624
627
shape = self .chunk_shape ,
@@ -627,7 +630,6 @@ def _get_chunk_spec(self, shard_spec: ArraySpec) -> ArraySpec:
627
630
order = shard_spec .order ,
628
631
)
629
632
630
- @lru_cache
631
633
def _get_chunks_per_shard (self , shard_spec : ArraySpec ) -> ChunkCoords :
632
634
return tuple (
633
635
s // c
0 commit comments