@@ -96,8 +96,10 @@ class COO(object):
96
96
__array_priority__ = 12
97
97
98
98
def __init__ (self , coords , data = None , shape = None , has_duplicates = True ,
99
- sorted = False ):
100
- self ._cache = defaultdict (lambda : deque (maxlen = 3 ))
99
+ sorted = False , cache = False ):
100
+ self ._cache = None
101
+ if cache :
102
+ self .enable_caching ()
101
103
if data is None :
102
104
# {(i, j, k): x, (i, j, k): y, ...}
103
105
if isinstance (coords , dict ):
@@ -154,7 +156,30 @@ def __init__(self, coords, data=None, shape=None, has_duplicates=True,
154
156
assert not self .shape or len (data ) == self .coords .shape [1 ]
155
157
self .has_duplicates = has_duplicates
156
158
self .sorted = sorted
159
+
160
+ def enable_caching (self ):
161
+ """ Enable caching of reshape, transpose, and tocsr/csc operations
162
+
163
+ This enables efficient iterative workflows that make heavy use of
164
+ csr/csc operations, such as tensordot. This maintains a cache of
165
+ recent results of reshape and transpose so that operations like
166
+ tensordot (which uses both internally) store efficiently stored
167
+ representations for repeated use. This can significantly cut down on
168
+ computational costs in common numeric algorithms.
169
+
170
+ However, this also assumes that neither this object, nor the downstream
171
+ objects will have their data mutated.
172
+
173
+ Examples
174
+ --------
175
+ >>> x.enable_caching() # doctest: +SKIP
176
+ >>> csr1 = x.transpose((2, 0, 1)).reshape((100, 120)).tocsr() # doctest: +SKIP
177
+ >>> csr2 = x.transpose((2, 0, 1)).reshape((100, 120)).tocsr() # doctest: +SKIP
178
+ >>> csr1 is csr2 # doctest: +SKIP
179
+ True
180
+ """
157
181
self ._cache = defaultdict (lambda : deque (maxlen = 3 ))
182
+ return self
158
183
159
184
@classmethod
160
185
def from_numpy (cls , x ):
@@ -329,15 +354,18 @@ def transpose(self, axes=None):
329
354
if axes == tuple (range (self .ndim )):
330
355
return self
331
356
332
- for ax , value in self ._cache ['transpose' ]:
333
- if ax == axes :
334
- return value
357
+ if self ._cache is not None :
358
+ for ax , value in self ._cache ['transpose' ]:
359
+ if ax == axes :
360
+ return value
335
361
336
362
shape = tuple (self .shape [ax ] for ax in axes )
337
363
result = COO (self .coords [axes , :], self .data , shape ,
338
- has_duplicates = self .has_duplicates )
364
+ has_duplicates = self .has_duplicates ,
365
+ cache = self ._cache is not None )
339
366
340
- self ._cache ['transpose' ].append ((axes , result ))
367
+ if self ._cache is not None :
368
+ self ._cache ['transpose' ].append ((axes , result ))
341
369
return result
342
370
343
371
@property
@@ -378,9 +406,10 @@ def reshape(self, shape):
378
406
if self .shape == shape :
379
407
return self
380
408
381
- for sh , value in self ._cache ['reshape' ]:
382
- if sh == shape :
383
- return value
409
+ if self ._cache is not None :
410
+ for sh , value in self ._cache ['reshape' ]:
411
+ if sh == shape :
412
+ return value
384
413
385
414
# TODO: this np.prod(self.shape) enforces a 2**64 limit to array size
386
415
linear_loc = self .linear_loc ()
@@ -393,9 +422,10 @@ def reshape(self, shape):
393
422
394
423
result = COO (coords , self .data , shape ,
395
424
has_duplicates = self .has_duplicates ,
396
- sorted = self .sorted )
425
+ sorted = self .sorted , cache = self . _cache is not None )
397
426
398
- self ._cache ['reshape' ].append ((shape , result ))
427
+ if self ._cache is not None :
428
+ self ._cache ['reshape' ].append ((shape , result ))
399
429
return result
400
430
401
431
def to_scipy_sparse (self ):
@@ -424,32 +454,39 @@ def _tocsr(self):
424
454
return scipy .sparse .csr_matrix ((self .data , col , indptr ), shape = self .shape )
425
455
426
456
def tocsr (self ):
427
- try :
428
- return self ._csr
429
- except AttributeError :
430
- pass
431
- try :
432
- self ._csr = self ._csc .tocsr ()
433
- return self ._csr
434
- except AttributeError :
435
- pass
436
-
437
- self ._csr = self ._tocsr ()
438
- return self ._csr
457
+ if self ._cache is not None :
458
+ try :
459
+ return self ._csr
460
+ except AttributeError :
461
+ pass
462
+ try :
463
+ self ._csr = self ._csc .tocsr ()
464
+ return self ._csr
465
+ except AttributeError :
466
+ pass
467
+
468
+ self ._csr = csr = self ._tocsr ()
469
+ else :
470
+ csr = self ._tocsr ()
471
+ return csr
439
472
440
473
def tocsc (self ):
441
- try :
442
- return self ._csc
443
- except AttributeError :
444
- pass
445
- try :
446
- self ._csc = self ._csr .tocsc ()
447
- return self ._csc
448
- except AttributeError :
449
- pass
450
-
451
- self ._csc = self .tocsr ().tocsc ()
452
- return self ._csc
474
+ if self ._cache is not None :
475
+ try :
476
+ return self ._csc
477
+ except AttributeError :
478
+ pass
479
+ try :
480
+ self ._csc = self ._csr .tocsc ()
481
+ return self ._csc
482
+ except AttributeError :
483
+ pass
484
+
485
+ self ._csc = csc = self .tocsr ().tocsc ()
486
+ else :
487
+ csc = self .tocsr ().tocsc ()
488
+
489
+ return csc
453
490
454
491
def sort_indices (self ):
455
492
if self .sorted :
0 commit comments