Skip to content

Commit 3637f58

Browse files
committed
Adding functions to create and manipulate sparse matrices
1 parent 85465c2 commit 3637f58

File tree

4 files changed

+305
-0
lines changed

4 files changed

+305
-0
lines changed

arrayfire/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from .interop import *
7474
from .timer import *
7575
from .random import *
76+
from .sparse import *
7677

7778
# do not export default modules as part of arrayfire
7879
del ct

arrayfire/sparse.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
"""
11+
Functions to create and manipulate sparse matrices.
12+
"""
13+
14+
from .library import *
15+
from .array import *
16+
import numbers
17+
from .interop import to_array
18+
19+
__to_sparse_enum = [STORAGE.DENSE,
20+
STORAGE.CSR,
21+
STORAGE.CSC,
22+
STORAGE.COO]
23+
24+
25+
def sparse(values, row_idx, col_idx, nrows, ncols, storage = STORAGE.CSR):
26+
"""
27+
Create a sparse matrix from it's constituent parts.
28+
29+
Parameters
30+
----------
31+
32+
values : af.Array.
33+
- Contains the non zero elements of the sparse array.
34+
35+
row_idx : af.Array.
36+
- Contains row indices of the sparse array.
37+
38+
col_idx : af.Array.
39+
- Contains column indices of the sparse array.
40+
41+
nrows : int.
42+
- specifies the number of rows in sparse matrix.
43+
44+
ncols : int.
45+
- specifies the number of columns in sparse matrix.
46+
47+
storage : optional: arrayfire.STORAGE. default: arrayfire.STORAGE.CSR.
48+
- Can be one of arrayfire.STORAGE.CSR, arrayfire.STORAGE.COO.
49+
50+
Returns
51+
-------
52+
53+
A sparse matrix.
54+
"""
55+
assert(isinstance(values, Array))
56+
assert(isinstance(row_idx, Array))
57+
assert(isinstance(col_idx, Array))
58+
out = Array()
59+
safe_call(backend.get().af_create_sparse_array(ct.pointer(out.arr), c_dim_t(nrows), c_dim_t(ncols),
60+
values.arr, row_idx.arr, col_idx.arr, storage.value))
61+
return out
62+
63+
def sparse_from_host(values, row_idx, col_idx, nrows, ncols, storage = STORAGE.CSR):
64+
"""
65+
Create a sparse matrix from it's constituent parts.
66+
67+
Parameters
68+
----------
69+
70+
values : Any datatype that can be converted to array.
71+
- Contains the non zero elements of the sparse array.
72+
73+
row_idx : Any datatype that can be converted to array.
74+
- Contains row indices of the sparse array.
75+
76+
col_idx : Any datatype that can be converted to array.
77+
- Contains column indices of the sparse array.
78+
79+
nrows : int.
80+
- specifies the number of rows in sparse matrix.
81+
82+
ncols : int.
83+
- specifies the number of columns in sparse matrix.
84+
85+
storage : optional: arrayfire.STORAGE. default: arrayfire.STORAGE.CSR.
86+
- Can be one of arrayfire.STORAGE.CSR, arrayfire.STORAGE.COO.
87+
88+
Returns
89+
-------
90+
91+
A sparse matrix.
92+
"""
93+
return sparse(to_array(values), to_array(row_idx), to_array(col_idx), nrows, ncols, storage)
94+
95+
def sparse_from_dense(dense, storage = STORAGE.CSR):
96+
"""
97+
Create a sparse matrix from a dense matrix.
98+
99+
Parameters
100+
----------
101+
102+
dense : af.Array.
103+
- A dense matrix.
104+
105+
storage : optional: arrayfire.STORAGE. default: arrayfire.STORAGE.CSR.
106+
- Can be one of arrayfire.STORAGE.CSR, arrayfire.STORAGE.COO.
107+
108+
Returns
109+
-------
110+
111+
A sparse matrix.
112+
"""
113+
assert(isinstance(dense, Array))
114+
out = Array()
115+
safe_call(backend.get().af_create_sparse_array_from_dense(ct.pointer(out.arr), dense.arr, storage.value))
116+
return out
117+
118+
def sparse_to_dense(sparse):
119+
"""
120+
Create a dense matrix from a sparse matrix.
121+
122+
Parameters
123+
----------
124+
125+
sparse : af.Array.
126+
- A sparse matrix.
127+
128+
Returns
129+
-------
130+
131+
A dense matrix.
132+
"""
133+
out = Array()
134+
safe_call(backend.get().af_sparse_to_dense(ct.pointer(out.arr), sparse.arr))
135+
return out
136+
137+
def sparse_get_info(sparse):
138+
"""
139+
Get the constituent arrays and storage info from a sparse matrix.
140+
141+
Parameters
142+
----------
143+
144+
sparse : af.Array.
145+
- A sparse matrix.
146+
147+
Returns
148+
--------
149+
(values, row_idx, col_idx, storage) where
150+
values : arrayfire.Array containing non zero elements from sparse matrix
151+
row_idx : arrayfire.Array containing the row indices
152+
col_idx : arrayfire.Array containing the column indices
153+
storage : sparse storage
154+
"""
155+
values = Array()
156+
row_idx = Array()
157+
col_idx = Array()
158+
stype = ct.c_int(0)
159+
safe_call(backend.get().af_sparse_get_info(ct.pointer(values.arr), ct.pointer(row_idx.arr),
160+
ct.pointer(col_idx.arr), ct.pointer(stype),
161+
sparse.arr))
162+
return (values, row_idx, col_idx, __to_sparse_enum[stype.value])
163+
164+
def sparse_get_values(sparse):
165+
"""
166+
Get the non zero values from sparse matrix.
167+
168+
Parameters
169+
----------
170+
171+
sparse : af.Array.
172+
- A sparse matrix.
173+
174+
Returns
175+
--------
176+
arrayfire array containing the non zero elements.
177+
178+
"""
179+
values = Array()
180+
safe_call(backend.get().af_sparse_get_values(ct.pointer(values.arr), sparse.arr))
181+
return values
182+
183+
def sparse_get_row_idx(sparse):
184+
"""
185+
Get the row indices from sparse matrix.
186+
187+
Parameters
188+
----------
189+
190+
sparse : af.Array.
191+
- A sparse matrix.
192+
193+
Returns
194+
--------
195+
arrayfire array containing the non zero elements.
196+
197+
"""
198+
row_idx = Array()
199+
safe_call(backend.get().af_sparse_get_row_idx(ct.pointer(row_idx.arr), sparse.arr))
200+
return row_idx
201+
202+
def sparse_get_col_idx(sparse):
203+
"""
204+
Get the column indices from sparse matrix.
205+
206+
Parameters
207+
----------
208+
209+
sparse : af.Array.
210+
- A sparse matrix.
211+
212+
Returns
213+
--------
214+
arrayfire array containing the non zero elements.
215+
216+
"""
217+
col_idx = Array()
218+
safe_call(backend.get().af_sparse_get_col_idx(ct.pointer(col_idx.arr), sparse.arr))
219+
return col_idx
220+
221+
def sparse_get_nnz(sparse):
222+
"""
223+
Get the column indices from sparse matrix.
224+
225+
Parameters
226+
----------
227+
228+
sparse : af.Array.
229+
- A sparse matrix.
230+
231+
Returns
232+
--------
233+
Number of non zero elements in the sparse matrix.
234+
235+
"""
236+
nnz = c_dim_t(0)
237+
safe_call(backend.get().af_sparse_get_nnz(ct.pointer(nnz), sparse.arr))
238+
return nnz.value
239+
240+
def sparse_get_storage(sparse):
241+
"""
242+
Get the column indices from sparse matrix.
243+
244+
Parameters
245+
----------
246+
247+
sparse : af.Array.
248+
- A sparse matrix.
249+
250+
Returns
251+
--------
252+
Number of non zero elements in the sparse matrix.
253+
254+
"""
255+
storage = ct.c_int(0)
256+
safe_call(backend.get().af_sparse_get_storage(ct.pointer(storage), sparse.arr))
257+
return __to_sparse_enum[storage.value]
258+
259+
def sparse_convert_to(sparse, storage):
260+
"""
261+
Convert sparse matrix from one format to another.
262+
263+
Parameters
264+
----------
265+
266+
storage : arrayfire.STORAGE.
267+
268+
Returns
269+
-------
270+
271+
Sparse matrix converted to the appropriate type.
272+
"""
273+
out = Array()
274+
safe_call(backend.get().af_sparse_convert_to(ct.pointer(out.arr), sparse.arr, storage.value))
275+
return out

arrayfire/tests/simple/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
from .signal import *
2020
from .statistics import *
2121
from .random import *
22+
from .sparse import *
2223
from ._util import tests

arrayfire/tests/simple/sparse.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/usr/bin/python
2+
#######################################################
3+
# Copyright (c) 2015, ArrayFire
4+
# All rights reserved.
5+
#
6+
# This file is distributed under 3-clause BSD license.
7+
# The complete license agreement can be obtained at:
8+
# http://arrayfire.com/licenses/BSD-3-Clause
9+
########################################################
10+
11+
import arrayfire as af
12+
from . import _util
13+
14+
def simple_sparse(verbose=False):
15+
display_func = _util.display_func(verbose)
16+
print_func = _util.print_func(verbose)
17+
18+
dd = af.randu(5, 5)
19+
ds = dd * (dd > 0.5)
20+
sp = af.sparse_from_dense(ds)
21+
display_func(af.sparse_get_info(sp))
22+
display_func(af.sparse_get_values(sp))
23+
display_func(af.sparse_get_row_idx(sp))
24+
display_func(af.sparse_get_col_idx(sp))
25+
print_func(af.sparse_get_nnz(sp))
26+
print_func(af.sparse_get_storage(sp))
27+
28+
_util.tests['sparse'] = simple_sparse

0 commit comments

Comments
 (0)