Skip to content

Commit 34e707f

Browse files
tmrn411nedtwigg
authored andcommitted
Fix multi-dimensional indexing bug, add tests, and some convenience routines
1 parent ba51f01 commit 34e707f

File tree

7 files changed

+321
-0
lines changed

7 files changed

+321
-0
lines changed

src/main/java/com/jmatio/types/MLArray.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public class MLArray {
4141
protected int attributes;
4242
protected int type;
4343
private int dimsFactors[]; // Used for calculating the index in arrays with higher dimensions than 2.
44+
private int dimStrides[]; // Used to convert multidimensional index to linear index
4445

4546
public static final String DEFAULT_NAME = "@";
4647

@@ -62,6 +63,15 @@ public MLArray(String name, int[] dims, int type, int attributes) {
6263
dimsFactors[dimIx] = f;
6364
f *= dims[dimIx];
6465
}
66+
67+
// This essentially does the same thing as the previous block with dimsFactors, except
68+
// the indices are processed in order from 0 to max in order to match Matlab's
69+
// column major ordering of storage in the .mat file
70+
dimStrides = new int[dims.length];
71+
for (int dimIx = 0, f = 1; dimIx < dims.length; dimIx++ ) {
72+
dimStrides[dimIx] = f;
73+
f *= dims[dimIx];
74+
}
6575
}
6676

6777
/**
@@ -80,6 +90,26 @@ public int getIndex(int... indexes) {
8090
}
8191
return ix;
8292
}
93+
94+
/**
95+
* Returns the one-dim index for the multi-dimensional indexes. Compatible with matlab multi-dimensional indexing.
96+
*
97+
* Note: this performs the same logical function as getIndex, but the indices are computed in column major order
98+
* for compatibility with .mat files generated by Matlab.
99+
*
100+
* @param indexes Length must be same as number of dimensions. Element value must be >= 0 and < dimension size for the corresponding dimension.
101+
* @return The linear index
102+
*/
103+
public int getIndexCM(int ... indexes) {
104+
if (indexes.length != dims.length) {
105+
throw new IllegalArgumentException("Cannot use " + indexes.length + " indexes for " + dims.length + " dimensions.");
106+
}
107+
int ix = 0;
108+
for (int dimIx = 0; dimIx < indexes.length; dimIx++) {
109+
ix += dimStrides[dimIx] * validateDimSize(dimIx, indexes[dimIx]);
110+
}
111+
return ix;
112+
}
83113

84114
private int validateDimSize(int dimIx, int ixInDim) {
85115
if (dims[dimIx] > ixInDim) {

src/main/java/com/jmatio/types/MLNumericArray.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ public T getReal(int m, int n) {
7878
return getReal(getIndex(m, n));
7979
}
8080

81+
/**
82+
* Get real component of an element from a multidimensional array using indices. This is compatible with .mat files created or used by Matlab.
83+
*
84+
* @param indices Length must be same as number of dimensions. Element value must be >= 0 and < dimension size for the corresponding dimension.
85+
* @return The value.
86+
*/
87+
public T getRealCM(int... indices) {
88+
return getReal(getIndexCM(indices));
89+
}
90+
8191
/**
8292
* @param index
8393
* @return
@@ -153,6 +163,16 @@ public T getImaginary(int m, int n) {
153163
return getImaginary(getIndex(m, n));
154164
}
155165

166+
/**
167+
* Get an imaginary component of element from a multidimensional array using indices. This is compatible with .mat files created or used by Matlab.
168+
*
169+
* @param indices Length must be same as number of dimensions. Element value must be >= 0 and < dimension size for the corresponding dimension.
170+
* @return The value.
171+
*/
172+
public T getImaginaryCM(int... indices) {
173+
return getImaginary(getIndexCM(indices));
174+
}
175+
156176
/**
157177
* @param index
158178
* @return
@@ -248,6 +268,16 @@ public T get(int... indices) {
248268
return get(getIndex(indices));
249269
}
250270

271+
/**
272+
* Get an element from a multidimensional array using indices. This is compatible with .mat files created or used by Matlab.
273+
*
274+
* @param indices Length must be same as number of dimensions. Element value must be >= 0 and < dimension size for the corresponding dimension.
275+
* @return The value.
276+
*/
277+
public T getCM(int... indices) {
278+
return get(getIndexCM(indices));
279+
}
280+
251281
/**
252282
* @param vector
253283
*/

src/test/java/com/jmatio/io/MatIOTest.java

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,174 @@ public void testBenchmarkUInt8() throws Exception {
126126
// test if MLArray objects are equal
127127
assertEquals("Test if value red from file equals value stored", mluint8, mlArrayRetrived);
128128
}
129+
130+
@Test
131+
public void testMultipleDimArrayRealFromMatlabCreatedFile() throws IOException {
132+
int ndims = 5;
133+
int[] dims = new int[]{2, 3, 4, 5, 6};
134+
File file = getTestFile("multiDimMatrix.mat");
135+
MatFileReader reader = new MatFileReader(file);
136+
MLDouble mlArray = (MLDouble) reader.getMLArray("in");
137+
138+
int testNDims = mlArray.getNDimensions();
139+
Assert.assertEquals(ndims, testNDims);
140+
141+
int[] testDims = mlArray.getDimensions();
142+
for (int i = 0; i < ndims; i++){
143+
Assert.assertEquals(dims[i], testDims[i]);
144+
}
145+
146+
Double expectedVal = 0.0;
147+
for (int i = 0; i < dims[4]; i++) {
148+
for (int j = 0; j < dims[3]; j++) {
149+
for (int k = 0; k < dims[2]; k++) {
150+
for (int l = 0; l < dims[1]; l++) {
151+
for (int m = 0; m < dims[0]; m++, expectedVal += 1.0) {
152+
Double actual = mlArray.getReal( mlArray.getIndexCM(m, l, k, j, i) );
153+
Assert.assertEquals( expectedVal, actual );
154+
}
155+
}
156+
}
157+
}
158+
}
159+
}
160+
161+
@Test
162+
public void testMultipleDimArrayRealWIndicesFromMatlabCreatedFile() throws IOException {
163+
int ndims = 5;
164+
int[] dims = new int[]{2, 3, 4, 5, 6};
165+
File file = getTestFile("multiDimMatrix.mat");
166+
MatFileReader reader = new MatFileReader(file);
167+
MLDouble mlArray = (MLDouble) reader.getMLArray("in");
168+
169+
int testNDims = mlArray.getNDimensions();
170+
Assert.assertEquals(ndims, testNDims);
171+
172+
int[] testDims = mlArray.getDimensions();
173+
for (int i = 0; i < ndims; i++){
174+
Assert.assertEquals(dims[i], testDims[i]);
175+
}
176+
177+
Double expectedVal = 0.0;
178+
for (int i = 0; i < dims[4]; i++) {
179+
for (int j = 0; j < dims[3]; j++) {
180+
for (int k = 0; k < dims[2]; k++) {
181+
for (int l = 0; l < dims[1]; l++) {
182+
for (int m = 0; m < dims[0]; m++, expectedVal += 1.0) {
183+
Double actual = mlArray.getRealCM( m, l, k, j, i );
184+
Assert.assertEquals( expectedVal, actual );
185+
}
186+
}
187+
}
188+
}
189+
}
190+
}
191+
192+
@Test
193+
public void testMultipleDimArrayGetFromMatlabCreatedFile() throws IOException {
194+
int ndims = 5;
195+
int[] dims = new int[]{2, 3, 4, 5, 6};
196+
File file = getTestFile("multiDimMatrix.mat");
197+
MatFileReader reader = new MatFileReader(file);
198+
MLDouble mlArray = (MLDouble) reader.getMLArray("in");
199+
200+
int testNDims = mlArray.getNDimensions();
201+
Assert.assertEquals(ndims, testNDims);
202+
203+
int[] testDims = mlArray.getDimensions();
204+
for (int i = 0; i < ndims; i++){
205+
Assert.assertEquals(dims[i], testDims[i]);
206+
}
207+
208+
Double expectedVal = 0.0;
209+
for (int i = 0; i < dims[4]; i++) {
210+
for (int j = 0; j < dims[3]; j++) {
211+
for (int k = 0; k < dims[2]; k++) {
212+
for (int l = 0; l < dims[1]; l++) {
213+
for (int m = 0; m < dims[0]; m++, expectedVal += 1.0) {
214+
Double actual = mlArray.getCM( m, l, k, j, i );
215+
Assert.assertEquals( expectedVal, actual );
216+
}
217+
}
218+
}
219+
}
220+
}
221+
}
222+
223+
@Test
224+
public void testMultipleDimArrayComplexFromMatlabCreatedFile() throws IOException {
225+
int ndims = 5;
226+
int[] dims = new int[]{2, 3, 4, 5, 6};
227+
File file = getTestFile("multiDimComplexMatrix.mat");
228+
MatFileReader reader = new MatFileReader(file);
229+
MLDouble mlArray = (MLDouble) reader.getMLArray("in");
230+
231+
int testNDims = mlArray.getNDimensions();
232+
Assert.assertEquals(ndims, testNDims);
233+
234+
int[] testDims = mlArray.getDimensions();
235+
for (int i = 0; i < ndims; i++){
236+
Assert.assertEquals(dims[i], testDims[i]);
237+
}
238+
239+
Double expectedValRe = 0.0;
240+
for (int i = 0; i < dims[4]; i++) {
241+
for (int j = 0; j < dims[3]; j++) {
242+
for (int k = 0; k < dims[2]; k++) {
243+
for (int l = 0; l < dims[1]; l++) {
244+
for (int m = 0; m < dims[0]; m++, expectedValRe += 1.0) {
245+
Double actualRe = mlArray.getReal( mlArray.getIndexCM(m, l, k, j, i) );
246+
Assert.assertEquals( expectedValRe, actualRe );
247+
Double actualIm = mlArray.getImaginary( mlArray.getIndexCM(m, l, k, j, i) );
248+
Double expectedValIm = 0.0;
249+
if (expectedValRe != 0.0) {
250+
expectedValIm = expectedValRe * -1.0;
251+
}
252+
Assert.assertEquals( expectedValIm, actualIm );
253+
}
254+
}
255+
}
256+
}
257+
}
258+
}
259+
260+
@Test
261+
public void testMultipleDimArrayComplexWIndicesFromMatlabCreatedFile() throws IOException {
262+
int ndims = 5;
263+
int[] dims = new int[]{2, 3, 4, 5, 6};
264+
File file = getTestFile("multiDimComplexMatrix.mat");
265+
MatFileReader reader = new MatFileReader(file);
266+
MLDouble mlArray = (MLDouble) reader.getMLArray("in");
267+
268+
int testNDims = mlArray.getNDimensions();
269+
Assert.assertEquals(ndims, testNDims);
270+
271+
int[] testDims = mlArray.getDimensions();
272+
for (int i = 0; i < ndims; i++){
273+
Assert.assertEquals(dims[i], testDims[i]);
274+
}
275+
276+
Double expectedValRe = 0.0;
277+
for (int i = 0; i < dims[4]; i++) {
278+
for (int j = 0; j < dims[3]; j++) {
279+
for (int k = 0; k < dims[2]; k++) {
280+
for (int l = 0; l < dims[1]; l++) {
281+
for (int m = 0; m < dims[0]; m++, expectedValRe += 1.0) {
282+
Double actualRe = mlArray.getRealCM( m, l, k, j, i );
283+
Assert.assertEquals( expectedValRe, actualRe );
284+
Double actualIm = mlArray.getImaginaryCM( m, l, k, j, i );
285+
Double expectedValIm = 0.0;
286+
if (expectedValRe != 0.0) {
287+
expectedValIm = expectedValRe * -1.0;
288+
}
289+
Assert.assertEquals( expectedValIm, actualIm );
290+
}
291+
}
292+
}
293+
}
294+
}
295+
}
296+
129297

130298
@Test
131299
public void testCellFromMatlabCreatedFile() throws IOException {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
2+
3+
% determine if script is being run from maven style project directory
4+
% if so, set root dir approriately, otherwise just use current directory
5+
% in maven directory structure this file should be located in directory
6+
% project_root/src/test/matlab/com/jmatio/TestMultiDimArray.m
7+
proj_root = '../../../../../'
8+
test_rsrc_dir = 'src/test/resources/'
9+
10+
rsrc_dir = strcat( proj_root, test_rsrc_dir )
11+
12+
fnMultiDimMatrix = 'multiDimMatrix.mat'
13+
14+
% by default, use current directory as root
15+
test_dir = './';
16+
if exist( rsrc_dir, 'dir' )
17+
test_dir = rsrc_dir;
18+
end
19+
20+
genMultiDimMatrix(strcat( test_dir, fnMultiDimMatrix ))
21+
22+
23+
24+
25+
function op = genMultiDimMatrix( filePath )
26+
in = zeros(2, 3, 4, 5, 6);
27+
e = 0;
28+
for i =1:6
29+
for j = 1:5
30+
for k = 1:4
31+
for l = 1:3
32+
for m = 1:2
33+
in(m, l, k, j, i) = e;
34+
e = e + 1;
35+
end
36+
end
37+
end
38+
end
39+
end
40+
41+
42+
43+
44+
save( filePath, 'in' )
45+
end
46+
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
3+
% determine if script is being run from maven style project directory
4+
% if so, set root dir approriately, otherwise just use current directory
5+
% in maven directory structure this file should be located in directory
6+
% project_root/src/test/matlab/com/jmatio/TestMultiDimComplexArray.m
7+
proj_root = '../../../../../'
8+
test_rsrc_dir = 'src/test/resources/'
9+
10+
rsrc_dir = strcat( proj_root, test_rsrc_dir )
11+
12+
fnMultiDimComplexMatrix = 'multiDimComplexMatrix.mat'
13+
14+
% by default, use current directory as root
15+
test_dir = './';
16+
if exist( rsrc_dir, 'dir' )
17+
test_dir = rsrc_dir;
18+
end
19+
20+
genMultiDimComplexMatrix(strcat( test_dir, fnMultiDimComplexMatrix ))
21+
22+
23+
24+
25+
function op = genMultiDimComplexMatrix( filePath )
26+
a = zeros(2, 3, 4, 5, 6);
27+
in = complex(a, 0);
28+
e = 0;
29+
for i =1:6
30+
for j = 1:5
31+
for k = 1:4
32+
for l = 1:3
33+
for m = 1:2
34+
in(m, l, k, j, i) = complex(e, -e);
35+
e = e + 1;
36+
end
37+
end
38+
end
39+
end
40+
end
41+
42+
43+
44+
45+
save( filePath, 'in' );
46+
end
47+
2.42 KB
Binary file not shown.
1.25 KB
Binary file not shown.

0 commit comments

Comments
 (0)