Skip to content

Commit 833041d

Browse files
DirkToeweannxingyuan
authored andcommitted
tf.linalg.bandPart (#2226)
FEATURE
1 parent 0f249b3 commit 833041d

File tree

2 files changed

+301
-2
lines changed

2 files changed

+301
-2
lines changed

tfjs-core/src/ops/linalg_ops.ts

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,116 @@
2222
import {ENGINE} from '../engine';
2323
import {dispose} from '../globals';
2424
import {Tensor, Tensor1D, Tensor2D} from '../tensor';
25+
import {convertToTensor} from '../tensor_util_env';
26+
import {TensorLike} from '../types';
2527
import {assert} from '../util';
2628
import {eye, squeeze, stack, unstack} from './array_ops';
29+
import {sub} from './binary_ops';
2730
import {split} from './concat_split';
31+
import {logicalAnd, where} from './logical_ops';
2832
import {norm} from './norm';
2933
import {op} from './operation';
3034
import {sum} from './reduction_ops';
31-
import {tensor2d} from './tensor_ops';
35+
import {range, scalar, tensor2d, zeros} from './tensor_ops';
36+
37+
/**
38+
* Copy a tensor setting everything outside a central band in each innermost
39+
* matrix to zero.
40+
*
41+
* The band part is computed as follows: Assume input has `k` dimensions
42+
* `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
43+
* `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
44+
* The indicator function
45+
* `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower))`
46+
* `&& (num_upper < 0 || (n-m) <= num_upper)`
47+
*
48+
* ```js
49+
* const x = tf.tensor2d([[ 0, 1, 2, 3],
50+
* [-1, 0, 1, 2],
51+
* [-2, -1, 0, 1],
52+
* [-3, -2, -1, 0]]);
53+
* let y = tf.linalg.bandPart(x, 1, -1);
54+
* y.print(); // [[ 0, 1, 2, 3],
55+
* // [-1, 0, 1, 2],
56+
* // [ 0, -1, 0, 1],
57+
* // [ 0, 0 , -1, 0]]
58+
* let z = tf.linalg.bandPart(x, 2, 1);
59+
* z.print(); // [[ 0, 1, 0, 0],
60+
* // [-1, 0, 1, 0],
61+
* // [-2, -1, 0, 1],
62+
* // [ 0, -2, -1, 0]]
63+
* ```
64+
*
65+
* @param x Rank `k` tensor
66+
* @param numLower Number of subdiagonals to keep.
67+
* If negative, keep entire lower triangle.
68+
* @param numUpper Number of subdiagonals to keep.
69+
* If negative, keep entire upper triangle.
70+
* @returns Rank `k` tensor of the same shape as input.
71+
* The extracted banded tensor.
72+
*/
73+
/**
74+
* @doc {heading:'Operations',
75+
* subheading:'Linear Algebra',
76+
* namespace:'linalg'}
77+
*/
78+
function bandPart_<T extends Tensor>(
79+
a: T|TensorLike, numLower: number, numUpper: number
80+
): T
81+
{
82+
if( numLower%1 !== 0 ){
83+
throw new Error(
84+
`bandPart(): numLower must be an integer, got ${numLower}.`
85+
);
86+
}
87+
if( numUpper%1 !== 0 ){
88+
throw new Error(
89+
`bandPart(): numUpper must be an integer, got ${numUpper}.`
90+
);
91+
}
92+
93+
const $a = convertToTensor(a,'a','bandPart');
94+
95+
if( $a.rank < 2 ) {
96+
throw new Error(`bandPart(): Rank must be at least 2, got ${$a.rank}.`);
97+
}
98+
99+
const shape = $a.shape,
100+
[M,N] = $a.shape.slice(-2);
101+
102+
if( !(numLower <= M) ) {
103+
throw new Error(
104+
`bandPart(): numLower (${numLower})` +
105+
` must not be greater than the number of rows (${M}).`
106+
);
107+
}
108+
if( !(numUpper <= N) ) {
109+
throw new Error(
110+
`bandPart(): numUpper (${numUpper})` +
111+
` must not be greater than the number of columns (${N}).`
112+
);
113+
}
114+
115+
if( numLower < 0 ) { numLower = M; }
116+
if( numUpper < 0 ) { numUpper = N; }
117+
118+
const i = range(0,M, 1, 'int32').reshape([-1,1]),
119+
j = range(0,N, 1, 'int32'),
120+
ij = sub(i,j);
121+
122+
const inBand = logicalAnd(
123+
ij. lessEqual( scalar(+numLower,'int32') ),
124+
ij.greaterEqual( scalar(-numUpper,'int32') )
125+
);
126+
127+
const zero = zeros([M,N], $a.dtype);
128+
129+
return stack(
130+
unstack( $a.reshape([-1,M,N]) ).map(
131+
mat => where(inBand, mat, zero)
132+
)
133+
).reshape(shape) as T;
134+
}
32135

33136
/**
34137
* Gram-Schmidt orthogonalization.
@@ -263,5 +366,6 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
263366
}) as [Tensor2D, Tensor2D];
264367
}
265368

369+
export const bandPart = op({bandPart_});
266370
export const gramSchmidt = op({gramSchmidt_});
267371
export const qr = op({qr_});

tfjs-core/src/ops/linalg_ops_test.ts

Lines changed: 196 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,206 @@
1717

1818
import * as tf from '../index';
1919
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20-
import {Tensor1D, Tensor2D} from '../tensor';
20+
import {Tensor1D, Tensor2D, Tensor3D} from '../tensor';
2121
import {expectArraysClose} from '../test_util';
2222

2323
import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';
2424

25+
describeWithFlags('bandPart', ALL_ENVS, () => {
26+
it('keeps tensor unchanged', async () => {
27+
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
28+
expectArraysClose(
29+
await tf.linalg.bandPart(x, -1, -1).array(),
30+
[[1, 1, 1], [1, 1, 1], [1, 1, 1]]);
31+
});
32+
33+
it('upper triangular matrix', async () => {
34+
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
35+
expectArraysClose(
36+
await tf.linalg.bandPart(x, 0, -1).array(),
37+
[[1, 1, 1], [0, 1, 1], [0, 0, 1]]);
38+
});
39+
40+
it('lower triangular matrix', async () => {
41+
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
42+
expectArraysClose(
43+
await tf.linalg.bandPart(x, -1, 0).array(),
44+
[[1, 0, 0], [1, 1, 0], [1, 1, 1]]);
45+
});
46+
47+
it('diagonal elements', async () => {
48+
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
49+
expectArraysClose(
50+
await tf.linalg.bandPart(x, 0, 0).array(),
51+
[[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
52+
});
53+
54+
it('lower triangular elements', async () => {
55+
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
56+
expectArraysClose(
57+
await tf.linalg.bandPart(x, 1, 0).array(),
58+
[[1, 0, 0], [1, 1, 0], [0, 1, 1]]);
59+
});
60+
61+
it('upper triangular elements', async () => {
62+
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
63+
expectArraysClose(
64+
await tf.linalg.bandPart(x, 0, 1).array(),
65+
[[1, 1, 0], [0, 1, 1], [0, 0, 1]]);
66+
});
67+
68+
it('4X4 matrix - tensorflow python examples', async () => {
69+
const x: Tensor2D = tensor2d(
70+
[[0, 1, 2, 3], [-1, 0, 1, 2], [-2, -1, 0, 1], [-3, -2, -1, 0]]);
71+
expectArraysClose(
72+
await tf.linalg.bandPart(x, 1, -1).array(),
73+
[[0, 1, 2, 3], [-1, 0, 1, 2], [0, -1, 0, 1], [0, 0, -1, 0]]);
74+
expectArraysClose(
75+
await tf.linalg.bandPart(x, 2, 1).array(),
76+
[[0, 1, 0, 0], [-1, 0, 1, 0], [-2, -1, 0, 1], [0, -2, -1, 0]]);
77+
});
78+
79+
it('3 dimensional matrix', async () => {
80+
const x: Tensor3D = tensor3d([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]);
81+
expectArraysClose(
82+
await tf.linalg.bandPart(x, 0, 0).array(),
83+
[[[1, 0], [0, 1]], [[1, 0], [0, 1]]]);
84+
});
85+
86+
it('2X3X3 tensor', async () => {
87+
const x: Tensor3D = tensor3d(
88+
[[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]);
89+
expectArraysClose(
90+
await tf.linalg.bandPart(x, 1, 2).array(),
91+
[[[1, 1, 1], [1, 1, 1], [0, 1, 1]], [[1, 1, 1], [1, 1, 1], [0, 1, 1]]]);
92+
});
93+
94+
const la = tf.linalg;
95+
96+
it('fails for scalar', async () => {
97+
const x = scalar(1);
98+
expect( () => la.bandPart(x, 1, 2) ).toThrowError(/bandPart.*rank/i);
99+
});
100+
101+
it('fails for 1D tensor', async () => {
102+
const x = tensor1d([1, 2, 3, 4, 5]);
103+
expect( () => la.bandPart(x, 1, 2) ).toThrowError(/bandPart.*rank/i);
104+
});
105+
106+
it('fails if numLower or numUpper too large', async () => {
107+
const a = tf.tensor2d([[1, 2, 3],
108+
[4, 5, 6]]);
109+
110+
for( const numLower of [ 3,5,8,13] ) {
111+
for( const numUpper of [-1,0,1, 2] ) {
112+
expect( () => tf.linalg.bandPart(a, numLower, numUpper) )
113+
.toThrowError(/bandPart.*numLower/i);
114+
}}
115+
116+
for( const numLower of [-1,0,1] ) {
117+
for( const numUpper of [ 4,5,9] ) {
118+
expect( () => tf.linalg.bandPart(a, numLower, numUpper) )
119+
.toThrowError(/bandPart.*numUpper/i);
120+
}}
121+
122+
for( const numLower of [ 3,5,8,13] ) {
123+
for( const numUpper of [ 4,5, 9] ) {
124+
expect( () => tf.linalg.bandPart(a, numLower, numUpper) )
125+
.toThrowError(/bandPart.*(numLower|numUpper)/i);
126+
}}
127+
});
128+
129+
it('works for 3x4 example', async () => {
130+
const a = tf.tensor2d([[1, 2, 3, 4],
131+
[5, 6, 7, 8],
132+
[9,10,11,12]]);
133+
134+
expectArraysClose(
135+
await la.bandPart(a,0,0).array(),
136+
[[1, 0, 0, 0],
137+
[0, 6, 0, 0],
138+
[0, 0,11, 0]]
139+
);
140+
expectArraysClose(
141+
await la.bandPart(a,0,1).array(),
142+
[[1, 2, 0, 0],
143+
[0, 6, 7, 0],
144+
[0, 0,11,12]]
145+
);
146+
expectArraysClose(
147+
await la.bandPart(a,0,2).array(),
148+
[[1, 2, 3, 0],
149+
[0, 6, 7, 8],
150+
[0, 0,11,12]]
151+
);
152+
for( const numUpper of [3,4,-1,-2] ) {
153+
expectArraysClose(
154+
await la.bandPart(a,0,numUpper).array(),
155+
[[1, 2, 3, 4],
156+
[0, 6, 7, 8],
157+
[0, 0,11,12]]
158+
);
159+
}
160+
161+
expectArraysClose(
162+
await la.bandPart(a,1,0).array(),
163+
[[1, 0, 0, 0],
164+
[5, 6, 0, 0],
165+
[0,10,11, 0]]
166+
);
167+
expectArraysClose(
168+
await la.bandPart(a,1,1).array(),
169+
[[1, 2, 0, 0],
170+
[5, 6, 7, 0],
171+
[0,10,11,12]]
172+
);
173+
expectArraysClose(
174+
await la.bandPart(a,1,2).array(),
175+
[[1, 2, 3, 0],
176+
[5, 6, 7, 8],
177+
[0,10,11,12]]
178+
);
179+
for( const numUpper of [3,4,-1,-2] ) {
180+
expectArraysClose(
181+
await la.bandPart(a,1,numUpper).array(),
182+
[[1, 2, 3, 4],
183+
[5, 6, 7, 8],
184+
[0,10,11,12]]
185+
);
186+
}
187+
188+
for( const numLower of [2,3,-1,-2])
189+
{
190+
expectArraysClose(
191+
await la.bandPart(a,numLower,0).array(),
192+
[[1, 0, 0, 0],
193+
[5, 6, 0, 0],
194+
[9,10,11, 0]]
195+
);
196+
expectArraysClose(
197+
await la.bandPart(a,numLower,1).array(),
198+
[[1, 2, 0, 0],
199+
[5, 6, 7, 0],
200+
[9,10,11,12]]
201+
);
202+
expectArraysClose(
203+
await la.bandPart(a,numLower,2).array(),
204+
[[1, 2, 3, 0],
205+
[5, 6, 7, 8],
206+
[9,10,11,12]]
207+
);
208+
for( const numUpper of [3,4,-1,-2] ) {
209+
expectArraysClose(
210+
await la.bandPart(a,numLower,numUpper).array(),
211+
[[1, 2, 3, 4],
212+
[5, 6, 7, 8],
213+
[9,10,11,12]]
214+
);
215+
}
216+
}
217+
});
218+
}); // end bandPart
219+
25220
describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => {
26221
it('2x2, Array of Tensor1D', async () => {
27222
const xs: Tensor1D[] = [

0 commit comments

Comments
 (0)