Skip to content

tf.linalg.bandPart #2226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Nov 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion tfjs-core/src/ops/linalg_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,116 @@
import {ENGINE} from '../engine';
import {dispose} from '../globals';
import {Tensor, Tensor1D, Tensor2D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assert} from '../util';
import {eye, squeeze, stack, unstack} from './array_ops';
import {sub} from './binary_ops';
import {split} from './concat_split';
import {logicalAnd, where} from './logical_ops';
import {norm} from './norm';
import {op} from './operation';
import {sum} from './reduction_ops';
import {tensor2d} from './tensor_ops';
import {range, scalar, tensor2d, zeros} from './tensor_ops';

/**
* Copy a tensor setting everything outside a central band in each innermost
* matrix to zero.
*
* The band part is computed as follows: Assume input has `k` dimensions
* `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
* `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
* The indicator function
* `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower))`
* `&& (num_upper < 0 || (n-m) <= num_upper)`
*
* ```js
* const x = tf.tensor2d([[ 0, 1, 2, 3],
* [-1, 0, 1, 2],
* [-2, -1, 0, 1],
* [-3, -2, -1, 0]]);
* let y = tf.linalg.bandPart(x, 1, -1);
* y.print(); // [[ 0, 1, 2, 3],
* // [-1, 0, 1, 2],
* // [ 0, -1, 0, 1],
* // [ 0, 0 , -1, 0]]
* let z = tf.linalg.bandPart(x, 2, 1);
* z.print(); // [[ 0, 1, 0, 0],
* // [-1, 0, 1, 0],
* // [-2, -1, 0, 1],
* // [ 0, -2, -1, 0]]
* ```
*
* @param x Rank `k` tensor
* @param numLower Number of subdiagonals to keep.
* If negative, keep entire lower triangle.
* @param numUpper Number of subdiagonals to keep.
* If negative, keep entire upper triangle.
* @returns Rank `k` tensor of the same shape as input.
* The extracted banded tensor.
*/
/**
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function bandPart_<T extends Tensor>(
a: T|TensorLike, numLower: number, numUpper: number
): T
{
if( numLower%1 !== 0 ){
throw new Error(
`bandPart(): numLower must be an integer, got ${numLower}.`
);
}
if( numUpper%1 !== 0 ){
throw new Error(
`bandPart(): numUpper must be an integer, got ${numUpper}.`
);
}

const $a = convertToTensor(a,'a','bandPart');

if( $a.rank < 2 ) {
throw new Error(`bandPart(): Rank must be at least 2, got ${$a.rank}.`);
}

const shape = $a.shape,
[M,N] = $a.shape.slice(-2);

if( !(numLower <= M) ) {
throw new Error(
`bandPart(): numLower (${numLower})` +
` must not be greater than the number of rows (${M}).`
);
}
if( !(numUpper <= N) ) {
throw new Error(
`bandPart(): numUpper (${numUpper})` +
` must not be greater than the number of columns (${N}).`
);
}

if( numLower < 0 ) { numLower = M; }
if( numUpper < 0 ) { numUpper = N; }

const i = range(0,M, 1, 'int32').reshape([-1,1]),
j = range(0,N, 1, 'int32'),
ij = sub(i,j);

const inBand = logicalAnd(
ij. lessEqual( scalar(+numLower,'int32') ),
ij.greaterEqual( scalar(-numUpper,'int32') )
);

const zero = zeros([M,N], $a.dtype);

return stack(
unstack( $a.reshape([-1,M,N]) ).map(
mat => where(inBand, mat, zero)
)
).reshape(shape) as T;
}

/**
* Gram-Schmidt orthogonalization.
Expand Down Expand Up @@ -263,5 +366,6 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
}) as [Tensor2D, Tensor2D];
}

export const bandPart = op({bandPart_});
export const gramSchmidt = op({gramSchmidt_});
export const qr = op({qr_});
197 changes: 196 additions & 1 deletion tfjs-core/src/ops/linalg_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,206 @@

import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {Tensor1D, Tensor2D} from '../tensor';
import {Tensor1D, Tensor2D, Tensor3D} from '../tensor';
import {expectArraysClose} from '../test_util';

import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';

describeWithFlags('bandPart', ALL_ENVS, () => {
it('keeps tensor unchanged', async () => {
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
expectArraysClose(
await tf.linalg.bandPart(x, -1, -1).array(),
[[1, 1, 1], [1, 1, 1], [1, 1, 1]]);
});

it('upper triangular matrix', async () => {
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
expectArraysClose(
await tf.linalg.bandPart(x, 0, -1).array(),
[[1, 1, 1], [0, 1, 1], [0, 0, 1]]);
});

it('lower triangular matrix', async () => {
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
expectArraysClose(
await tf.linalg.bandPart(x, -1, 0).array(),
[[1, 0, 0], [1, 1, 0], [1, 1, 1]]);
});

it('diagonal elements', async () => {
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
expectArraysClose(
await tf.linalg.bandPart(x, 0, 0).array(),
[[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
});

it('lower triangular elements', async () => {
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
expectArraysClose(
await tf.linalg.bandPart(x, 1, 0).array(),
[[1, 0, 0], [1, 1, 0], [0, 1, 1]]);
});

it('upper triangular elements', async () => {
const x: Tensor2D = tensor2d([1, 1, 1, 1, 1, 1, 1, 1, 1], [3, 3]);
expectArraysClose(
await tf.linalg.bandPart(x, 0, 1).array(),
[[1, 1, 0], [0, 1, 1], [0, 0, 1]]);
});

it('4X4 matrix - tensorflow python examples', async () => {
const x: Tensor2D = tensor2d(
[[0, 1, 2, 3], [-1, 0, 1, 2], [-2, -1, 0, 1], [-3, -2, -1, 0]]);
expectArraysClose(
await tf.linalg.bandPart(x, 1, -1).array(),
[[0, 1, 2, 3], [-1, 0, 1, 2], [0, -1, 0, 1], [0, 0, -1, 0]]);
expectArraysClose(
await tf.linalg.bandPart(x, 2, 1).array(),
[[0, 1, 0, 0], [-1, 0, 1, 0], [-2, -1, 0, 1], [0, -2, -1, 0]]);
});

it('3 dimensional matrix', async () => {
const x: Tensor3D = tensor3d([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]);
expectArraysClose(
await tf.linalg.bandPart(x, 0, 0).array(),
[[[1, 0], [0, 1]], [[1, 0], [0, 1]]]);
});

it('2X3X3 tensor', async () => {
const x: Tensor3D = tensor3d(
[[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]);
expectArraysClose(
await tf.linalg.bandPart(x, 1, 2).array(),
[[[1, 1, 1], [1, 1, 1], [0, 1, 1]], [[1, 1, 1], [1, 1, 1], [0, 1, 1]]]);
});

const la = tf.linalg;

it('fails for scalar', async () => {
const x = scalar(1);
expect( () => la.bandPart(x, 1, 2) ).toThrowError(/bandPart.*rank/i);
});

it('fails for 1D tensor', async () => {
const x = tensor1d([1, 2, 3, 4, 5]);
expect( () => la.bandPart(x, 1, 2) ).toThrowError(/bandPart.*rank/i);
});

it('fails if numLower or numUpper too large', async () => {
const a = tf.tensor2d([[1, 2, 3],
[4, 5, 6]]);

for( const numLower of [ 3,5,8,13] ) {
for( const numUpper of [-1,0,1, 2] ) {
expect( () => tf.linalg.bandPart(a, numLower, numUpper) )
.toThrowError(/bandPart.*numLower/i);
}}

for( const numLower of [-1,0,1] ) {
for( const numUpper of [ 4,5,9] ) {
expect( () => tf.linalg.bandPart(a, numLower, numUpper) )
.toThrowError(/bandPart.*numUpper/i);
}}

for( const numLower of [ 3,5,8,13] ) {
for( const numUpper of [ 4,5, 9] ) {
expect( () => tf.linalg.bandPart(a, numLower, numUpper) )
.toThrowError(/bandPart.*(numLower|numUpper)/i);
}}
});

it('works for 3x4 example', async () => {
const a = tf.tensor2d([[1, 2, 3, 4],
[5, 6, 7, 8],
[9,10,11,12]]);

expectArraysClose(
await la.bandPart(a,0,0).array(),
[[1, 0, 0, 0],
[0, 6, 0, 0],
[0, 0,11, 0]]
);
expectArraysClose(
await la.bandPart(a,0,1).array(),
[[1, 2, 0, 0],
[0, 6, 7, 0],
[0, 0,11,12]]
);
expectArraysClose(
await la.bandPart(a,0,2).array(),
[[1, 2, 3, 0],
[0, 6, 7, 8],
[0, 0,11,12]]
);
for( const numUpper of [3,4,-1,-2] ) {
expectArraysClose(
await la.bandPart(a,0,numUpper).array(),
[[1, 2, 3, 4],
[0, 6, 7, 8],
[0, 0,11,12]]
);
}

expectArraysClose(
await la.bandPart(a,1,0).array(),
[[1, 0, 0, 0],
[5, 6, 0, 0],
[0,10,11, 0]]
);
expectArraysClose(
await la.bandPart(a,1,1).array(),
[[1, 2, 0, 0],
[5, 6, 7, 0],
[0,10,11,12]]
);
expectArraysClose(
await la.bandPart(a,1,2).array(),
[[1, 2, 3, 0],
[5, 6, 7, 8],
[0,10,11,12]]
);
for( const numUpper of [3,4,-1,-2] ) {
expectArraysClose(
await la.bandPart(a,1,numUpper).array(),
[[1, 2, 3, 4],
[5, 6, 7, 8],
[0,10,11,12]]
);
}

for( const numLower of [2,3,-1,-2])
{
expectArraysClose(
await la.bandPart(a,numLower,0).array(),
[[1, 0, 0, 0],
[5, 6, 0, 0],
[9,10,11, 0]]
);
expectArraysClose(
await la.bandPart(a,numLower,1).array(),
[[1, 2, 0, 0],
[5, 6, 7, 0],
[9,10,11,12]]
);
expectArraysClose(
await la.bandPart(a,numLower,2).array(),
[[1, 2, 3, 0],
[5, 6, 7, 8],
[9,10,11,12]]
);
for( const numUpper of [3,4,-1,-2] ) {
expectArraysClose(
await la.bandPart(a,numLower,numUpper).array(),
[[1, 2, 3, 4],
[5, 6, 7, 8],
[9,10,11,12]]
);
}
}
});
}); // end bandPart

describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => {
it('2x2, Array of Tensor1D', async () => {
const xs: Tensor1D[] = [
Expand Down