diff --git a/benchmarks/benchmark_coo.py b/benchmarks/benchmark_coo.py index 39088fbc..ca99da9e 100644 --- a/benchmarks/benchmark_coo.py +++ b/benchmarks/benchmark_coo.py @@ -3,6 +3,18 @@ import sparse +class MatrixMultiplySuite: + def setup(self): + np.random.seed(0) + self.x = sparse.random((100, 100), density=0.01) + self.y = sparse.random((100, 100), density=0.01) + + self.x @ self.y # Numba compilation + + def time_matmul(self): + self.x @ self.y + + class ElemwiseSuite: def setup(self): np.random.seed(0) diff --git a/sparse/_coo/common.py b/sparse/_coo/common.py index 98e00fa5..5477dbda 100644 --- a/sparse/_coo/common.py +++ b/sparse/_coo/common.py @@ -1166,44 +1166,45 @@ def _dot_coo_coo(coords1, data1, coords2, data2): # pragma: no cover coords_out = [] data_out = [] didx1 = 0 + data1_end = len(data1) + data2_end = len(data2) - while didx1 < len(data1): + while didx1 < data1_end: oidx1 = coords1[0, didx1] didx2 = 0 didx1_curr = didx1 while ( - didx2 < len(data2) and didx1 < len(data1) and coords1[0, didx1] == oidx1 + didx2 < data2_end and didx1 < data1_end and coords1[0, didx1] == oidx1 ): oidx2 = coords2[0, didx2] data_curr = 0 while ( - didx2 < len(data2) - and didx1 < len(data1) + didx2 < data2_end + and didx1 < data1_end and coords2[0, didx2] == oidx2 and coords1[0, didx1] == oidx1 ): - if coords1[1, didx1] < coords2[1, didx2]: - didx1 += 1 - elif coords1[1, didx1] > coords2[1, didx2]: - didx2 += 1 - else: + c1 = coords1[1, didx1] + c2 = coords2[1, didx2] + k = min(c1, c2) + if c1 == k and c2 == k: data_curr += data1[didx1] * data2[didx2] - didx1 += 1 - didx2 += 1 + didx1 += c1 == k + didx2 += c2 == k - while didx2 < len(data2) and coords2[0, didx2] == oidx2: + while didx2 < data2_end and coords2[0, didx2] == oidx2: didx2 += 1 - if didx2 < len(data2): + if didx2 < data2_end: didx1 = didx1_curr if data_curr != 0: coords_out.append((oidx1, oidx2)) data_out.append(data_curr) - while didx1 < len(data1) and coords1[0, didx1] == oidx1: + while didx1 < data1_end and coords1[0, didx1] == oidx1: didx1 += 1 if len(data_out) == 0: