Skip to content

Commit f5a9c1e

Browse files
committed
Opt in to caching
1 parent 214ff74 commit f5a9c1e

File tree

2 files changed

+77
-38
lines changed

2 files changed

+77
-38
lines changed

sparse/core.py

Lines changed: 73 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ class COO(object):
9696
__array_priority__ = 12
9797

9898
def __init__(self, coords, data=None, shape=None, has_duplicates=True,
99-
sorted=False):
100-
self._cache = defaultdict(lambda: deque(maxlen=3))
99+
sorted=False, cache=False):
100+
self._cache = None
101+
if cache:
102+
self.enable_caching()
101103
if data is None:
102104
# {(i, j, k): x, (i, j, k): y, ...}
103105
if isinstance(coords, dict):
@@ -154,7 +156,30 @@ def __init__(self, coords, data=None, shape=None, has_duplicates=True,
154156
assert not self.shape or len(data) == self.coords.shape[1]
155157
self.has_duplicates = has_duplicates
156158
self.sorted = sorted
159+
160+
def enable_caching(self):
161+
""" Enable caching of reshape, transpose, and tocsr/csc operations
162+
163+
This enables efficient iterative workflows that make heavy use of
164+
csr/csc operations, such as tensordot. This maintains a cache of
165+
recent results of reshape and transpose so that operations like
166+
tensordot (which uses both internally) store efficiently stored
167+
representations for repeated use. This can significantly cut down on
168+
computational costs in common numeric algorithms.
169+
170+
However, this also assumes that neither this object, nor the downstream
171+
objects will have their data mutated.
172+
173+
Examples
174+
--------
175+
>>> x.enable_caching() # doctest: +SKIP
176+
>>> csr1 = x.transpose((2, 0, 1)).reshape((100, 120)).tocsr() # doctest: +SKIP
177+
>>> csr2 = x.transpose((2, 0, 1)).reshape((100, 120)).tocsr() # doctest: +SKIP
178+
>>> csr1 is csr2 # doctest: +SKIP
179+
True
180+
"""
157181
self._cache = defaultdict(lambda: deque(maxlen=3))
182+
return self
158183

159184
@classmethod
160185
def from_numpy(cls, x):
@@ -329,15 +354,18 @@ def transpose(self, axes=None):
329354
if axes == tuple(range(self.ndim)):
330355
return self
331356

332-
for ax, value in self._cache['transpose']:
333-
if ax == axes:
334-
return value
357+
if self._cache is not None:
358+
for ax, value in self._cache['transpose']:
359+
if ax == axes:
360+
return value
335361

336362
shape = tuple(self.shape[ax] for ax in axes)
337363
result = COO(self.coords[axes, :], self.data, shape,
338-
has_duplicates=self.has_duplicates)
364+
has_duplicates=self.has_duplicates,
365+
cache=self._cache is not None)
339366

340-
self._cache['transpose'].append((axes, result))
367+
if self._cache is not None:
368+
self._cache['transpose'].append((axes, result))
341369
return result
342370

343371
@property
@@ -378,9 +406,10 @@ def reshape(self, shape):
378406
if self.shape == shape:
379407
return self
380408

381-
for sh, value in self._cache['reshape']:
382-
if sh == shape:
383-
return value
409+
if self._cache is not None:
410+
for sh, value in self._cache['reshape']:
411+
if sh == shape:
412+
return value
384413

385414
# TODO: this np.prod(self.shape) enforces a 2**64 limit to array size
386415
linear_loc = self.linear_loc()
@@ -393,9 +422,10 @@ def reshape(self, shape):
393422

394423
result = COO(coords, self.data, shape,
395424
has_duplicates=self.has_duplicates,
396-
sorted=self.sorted)
425+
sorted=self.sorted, cache=self._cache is not None)
397426

398-
self._cache['reshape'].append((shape, result))
427+
if self._cache is not None:
428+
self._cache['reshape'].append((shape, result))
399429
return result
400430

401431
def to_scipy_sparse(self):
@@ -424,32 +454,39 @@ def _tocsr(self):
424454
return scipy.sparse.csr_matrix((self.data, col, indptr), shape=self.shape)
425455

426456
def tocsr(self):
427-
try:
428-
return self._csr
429-
except AttributeError:
430-
pass
431-
try:
432-
self._csr = self._csc.tocsr()
433-
return self._csr
434-
except AttributeError:
435-
pass
436-
437-
self._csr = self._tocsr()
438-
return self._csr
457+
if self._cache is not None:
458+
try:
459+
return self._csr
460+
except AttributeError:
461+
pass
462+
try:
463+
self._csr = self._csc.tocsr()
464+
return self._csr
465+
except AttributeError:
466+
pass
467+
468+
self._csr = csr = self._tocsr()
469+
else:
470+
csr = self._tocsr()
471+
return csr
439472

440473
def tocsc(self):
441-
try:
442-
return self._csc
443-
except AttributeError:
444-
pass
445-
try:
446-
self._csc = self._csr.tocsc()
447-
return self._csc
448-
except AttributeError:
449-
pass
450-
451-
self._csc = self.tocsr().tocsc()
452-
return self._csc
474+
if self._cache is not None:
475+
try:
476+
return self._csc
477+
except AttributeError:
478+
pass
479+
try:
480+
self._csc = self._csr.tocsc()
481+
return self._csc
482+
except AttributeError:
483+
pass
484+
485+
self._csc = csc = self.tocsr().tocsc()
486+
else:
487+
csc = self.tocsr().tocsc()
488+
489+
return csc
453490

454491
def sort_indices(self):
455492
if self.sorted:

sparse/tests/test_core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def test_scipy_sparse_interface():
411411

412412
def test_cache_csr():
413413
x = random_x((10, 5))
414-
s = COO.from_numpy(x)
414+
s = COO(x, cache=True)
415415

416416
assert isinstance(s.tocsr(), scipy.sparse.csr_matrix)
417417
assert isinstance(s.tocsc(), scipy.sparse.csc_matrix)
@@ -465,10 +465,12 @@ def test_add_many_sparse_arrays():
465465

466466
def test_caching():
467467
x = COO({(10, 10, 10): 1})
468+
assert x[:].reshape((100, 10)).transpose().tocsr() is not x[:].reshape((100, 10)).transpose().tocsr()
468469

470+
x = COO({(10, 10, 10): 1}, cache=True)
469471
assert x[:].reshape((100, 10)).transpose().tocsr() is x[:].reshape((100, 10)).transpose().tocsr()
470472

471-
x = COO({(1, 1, 1, 1, 1, 1, 1, 2): 1})
473+
x = COO({(1, 1, 1, 1, 1, 1, 1, 2): 1}, cache=True)
472474

473475
for i in range(x.ndim):
474476
x.reshape((1,) * i + (2,) + (1,) * (x.ndim - i - 1))

0 commit comments

Comments
 (0)