Skip to content

Commit 9752734

Browse files
authored
Resizing Layer (#6879)
Implement the resizing layer. The branch and commit histories have been cleaned up for the PR. IMPORTANT NOTE: The lower-level op implementation of nearest-neighbor image resizing in TensorFlow.js differs from the implementation of the comparable op in Keras-Python. While the Python version of the op function always selects the bottom right cell of the sub-matrix to be used as the representative value of that region in the downscaled matrix, the JavaScript implementation defaults to the top left cell of the sub-matrix, and then preferentially shifts to the right side of the sub-matrix in all sub-matrices past the lateral halfway point (calculated by floor((length-1)/2)), and the bottom side of the sub-matrix in all sub-matrices past the vertical halfway point, when considering the top-left side of the parent matrix as the origin. This causes a slight variation in the output values from nearest neighbor downscaling between the Python and JavaScript versions of the code as it currently stands, and the unit tests for the resizing layer have been implemented to reflect this difference in op-function behavior. Co-authored-by: Adam Lang (@AdamLang96) [email protected] Co-authored-by: Brian Zheng (@Brianzheng123) [email protected]
1 parent 18be40c commit 9752734

File tree

4 files changed

+267
-2
lines changed

4 files changed

+267
-2
lines changed

tfjs-layers/src/exports_layers.ts

100755100644
Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ import {ZeroPadding2D, ZeroPadding2DLayerArgs} from './layers/padding';
2424
import {AveragePooling1D, AveragePooling2D, AveragePooling3D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerArgs, MaxPooling1D, MaxPooling2D, MaxPooling3D, Pooling1DLayerArgs, Pooling2DLayerArgs, Pooling3DLayerArgs} from './layers/pooling';
2525
import {GRU, GRUCell, GRUCellLayerArgs, GRULayerArgs, LSTM, LSTMCell, LSTMCellLayerArgs, LSTMLayerArgs, RNN, RNNCell, RNNLayerArgs, SimpleRNN, SimpleRNNCell, SimpleRNNCellLayerArgs, SimpleRNNLayerArgs, StackedRNNCells, StackedRNNCellsArgs} from './layers/recurrent';
2626
import {Bidirectional, BidirectionalLayerArgs, TimeDistributed, WrapperLayerArgs} from './layers/wrappers';
27-
import { Rescaling, RescalingArgs } from './layers/preprocessing/image_preprocessing';
28-
import { CategoryEncoding, CategoryEncodingArgs } from './layers/preprocessing/category_encoding';
27+
import {Rescaling, RescalingArgs} from './layers/preprocessing/image_preprocessing';
28+
import {Resizing, ResizingArgs} from './layers/preprocessing/image_resizing';
29+
import {CategoryEncoding, CategoryEncodingArgs} from './layers/preprocessing/category_encoding';
30+
2931
// TODO(cais): Add doc string to all the public static functions in this
3032
// class; include executable JavaScript code snippets where applicable
3133
// (b/74074458).
@@ -1730,6 +1732,32 @@ export function rescaling(args?: RescalingArgs) {
17301732
return new Rescaling(args);
17311733
}
17321734

1735+
/**
 * A preprocessing layer which resizes images.
 *
 * This layer resizes an image input to a target height and width. The input
 * should be a 4D (batched) or 3D (unbatched) tensor in `"channels_last"`
 * format. Input pixel values can be of any range (e.g. `[0., 1.)` or `[0,
 * 255]`) and of integer or floating point dtype. By default, the layer will
 * output floats.
 *
 * Arguments:
 *   - `height`: number, the height for the output tensor.
 *   - `width`: number, the width for the output tensor.
 *   - `interpolation`: string, the method for image resizing interpolation.
 *       Supported values are `'bilinear'` (used when omitted) and
 *       `'nearest'`.
 *   - `cropToAspectRatio`: boolean, whether to keep image aspect ratio.
 *       Treated as `false` when omitted.
 *
 * Input shape:
 *   3D (unbatched) or 4D (batched) image tensor, channels last.
 *
 * Output shape:
 *   height, width, num channels.
 *
 * @doc {heading: 'Layers', subheading: 'Resizing', namespace: 'layers'}
 */
export function resizing(args?: ResizingArgs) {
  const layer = new Resizing(args);
  return layer;
}
1760+
17331761
/**
17341762
* A preprocessing layer which encodes integer features.
17351763
*

tfjs-layers/src/layers/preprocessing/image_preprocessing_test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
/**
2+
* @license
3+
* Copyright 2022 CodeSmith LLC
4+
*
5+
* Use of this source code is governed by an MIT-style
6+
* license that can be found in the LICENSE file or at
7+
* https://opensource.org/licenses/MIT.
8+
* =============================================================================
9+
*/
10+
111
import { Tensor, randomNormal, mul, add} from '@tensorflow/tfjs-core';
212
import { Rescaling } from './image_preprocessing';
313
import { describeMathCPUAndGPU, expectTensorsClose } from '../../utils/test_utils';
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/**
2+
* @license
3+
* Copyright 2022 CodeSmith LLC
4+
*
5+
* Use of this source code is governed by an MIT-style
6+
* license that can be found in the LICENSE file or at
7+
* https://opensource.org/licenses/MIT.
8+
* =============================================================================
9+
*/
10+
11+
import {image, Rank, serialization, Tensor, tidy} from '@tensorflow/tfjs-core'; // mul, add
12+
13+
import {Layer, LayerArgs} from '../../engine/topology';
14+
import {ValueError} from '../../errors';
15+
import {Shape} from '../../keras_format/common';
16+
import {Kwargs} from '../../types';
17+
import {getExactlyOneShape} from '../../utils/types_utils'; //, getExactlyOneTensor
18+
19+
// Interpolation methods offered by TensorFlow (Python) that are not
// implemented in tfjs: 'bicubic', 'area', 'lanczos3', 'lanczos5', 'gaussian',
// 'mitchellcubic'.
const INTERPOLATION_KEYS = ['bilinear', 'nearest'] as const;
// Lookup set used to validate the `interpolation` argument at construction.
const INTERPOLATION_METHODS = new Set(INTERPOLATION_KEYS);
type InterpolationType = typeof INTERPOLATION_KEYS[number];

/** Constructor arguments accepted by the `Resizing` layer. */
export declare interface ResizingArgs extends LayerArgs {
  /** Height of the resized output. */
  height: number;
  /** Width of the resized output. */
  width: number;
  /** Interpolation method; 'bilinear' is used when omitted. */
  interpolation?: InterpolationType;
  /** Coerced with `Boolean(...)`, so an omitted value behaves as false. */
  cropToAspectRatio?: boolean;
}

/**
 * Preprocessing Resizing Layer
 *
 * Resizes image tensors to a fixed `height` x `width` using bilinear or
 * nearest-neighbor interpolation.
 */
export class Resizing extends Layer {
  /** @nocollapse */
  static className = 'Resizing';
  private readonly height: number;
  private readonly width: number;
  // method of interpolation to be used; default = "bilinear";
  private readonly interpolation: InterpolationType;
  // toggle whether the aspect ratio should be preserved; default = false;
  private readonly cropToAspectRatio: boolean;

  /**
   * @param args Layer configuration; `height` and `width` are stored as given.
   * @throws ValueError if `args.interpolation` is set to a value other than
   *     'bilinear' or 'nearest'.
   */
  constructor(args: ResizingArgs) {
    super(args);

    this.height = args.height;
    this.width = args.width;

    if (args.interpolation) {
      if (INTERPOLATION_METHODS.has(args.interpolation)) {
        this.interpolation = args.interpolation;
      } else {
        throw new ValueError(`Invalid interpolation parameter: ${
            args.interpolation} is not implemented`);
      }
    } else {
      this.interpolation = 'bilinear';
    }
    this.cropToAspectRatio = Boolean(args.cropToAspectRatio);
  }

  /**
   * Computes the shape produced by this layer for a given input shape.
   *
   * A rank-4 (batched) input shape `[batch, h, w, channels]` yields
   * `[batch, height, width, channels]`; any other input shape yields
   * `[height, width, inputShape[2]]`.
   */
  computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
    inputShape = getExactlyOneShape(inputShape);
    if (inputShape.length === 4) {
      // Batched input: keep the batch dimension and read the channel count
      // from the last axis instead of mistaking the input width for it.
      const [batchSize, , , numChannels] = inputShape;
      return [batchSize, this.height, this.width, numChannels];
    }
    const numChannels = inputShape[2];
    return [this.height, this.width, numChannels];
  }

  /** Returns the layer's own settings merged with the base layer config. */
  getConfig(): serialization.ConfigDict {
    const config: serialization.ConfigDict = {
      'height': this.height,
      'width': this.width,
      'interpolation': this.interpolation,
      'cropToAspectRatio': this.cropToAspectRatio
    };

    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  /**
   * Resizes `inputs` to `[height, width]` with the configured interpolation.
   *
   * NOTE(review): the negated `cropToAspectRatio` flag is forwarded as the
   * third positional argument of the tfjs-core resize ops (`alignCorners` in
   * the tfjs-core API); no cropping is performed here. Confirm this mapping is
   * intended.
   */
  call(inputs: Tensor<Rank.R3>|Tensor<Rank.R4>, kwargs: Kwargs):
      Tensor[]|Tensor {
    return tidy(() => {
      const size: [number, number] = [this.height, this.width];
      if (this.interpolation === 'bilinear') {
        return image.resizeBilinear(inputs, size, !this.cropToAspectRatio);
      } else if (this.interpolation === 'nearest') {
        return image.resizeNearestNeighbor(
            inputs, size, !this.cropToAspectRatio);
      } else {
        // Defensive: the constructor only admits the two methods above.
        throw new Error(`Interpolation is ${
            this.interpolation} but only ${
            [...INTERPOLATION_METHODS]} are supported`);
      }
    });
  }
}

serialization.registerClass(Resizing);
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/**
2+
* @license
3+
* Copyright 2022 CodeSmith LLC
4+
*
5+
* Use of this source code is governed by an MIT-style
6+
* license that can be found in the LICENSE file or at
7+
* https://opensource.org/licenses/MIT.
8+
* =============================================================================
9+
*/
10+
11+
/**
12+
* Unit Tests for image resizing layer.
13+
*/
14+
15+
import {image, Rank, Tensor, tensor, zeros, range, reshape} from '@tensorflow/tfjs-core';
16+
17+
// import {Shape} from '../../keras_format/common';
18+
import {describeMathCPUAndGPU, expectTensorsClose} from '../../utils/test_utils';
19+
20+
import {Resizing, ResizingArgs} from './image_resizing';
21+
22+
describeMathCPUAndGPU('Resizing Layer', () => {
  it('Check if output shape matches specifications', () => {
    // Resize and check output shape. Target dimensions are drawn from
    // [1, max] so that a zero-sized target never makes the check degenerate.
    const maxHeight = 40;
    const height = 1 + Math.floor(Math.random() * maxHeight);
    const maxWidth = 60;
    const width = 1 + Math.floor(Math.random() * maxWidth);
    const numChannels = 3;
    const inputTensor = zeros([height * 2, width * 2, numChannels]);
    const expectedOutputShape = [height, width, numChannels];
    const resizingLayer = new Resizing({height, width});
    const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
    expect(layerOutputTensor.shape).toEqual(expectedOutputShape);
  });

  it('Returns correctly downscaled tensor (unbatched)', () => {
    // Resize and check output content (not batched).
    const rangeTensor = range(0, 16);
    const inputTensor = reshape(rangeTensor, [4, 4, 1]);
    const height = 2;
    const width = 2;
    const interpolation = 'nearest';
    const resizingLayer = new Resizing({height, width, interpolation});
    const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
    const expectedArr = [[0, 3], [12, 15]];
    const expectedOutput = tensor(expectedArr, [2, 2, 1]);
    expectTensorsClose(layerOutputTensor, expectedOutput);
  });

  it('Returns correctly downscaled tensor (batched)', () => {
    // Resize and check output content (batched).
    const rangeTensor = range(0, 36);
    const inputTensor = reshape(rangeTensor, [1, 6, 6, 1]);
    const height = 3;
    const width = 3;
    const interpolation = 'nearest';
    const resizingLayer = new Resizing({height, width, interpolation});
    const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
    const expectedArr = [[0, 3, 5], [18, 21, 23], [30, 33, 35]];
    const expectedOutput = tensor([expectedArr], [1, 3, 3, 1]);
    expectTensorsClose(layerOutputTensor, expectedOutput);
  });

  it('Returns correctly upscaled tensor', () => {
    // Each source pixel of the 2x2 input is expected to fill a 2x2 patch.
    const rangeTensor = range(0, 4);
    const inputTensor = reshape(rangeTensor, [1, 2, 2, 1]);
    const height = 4;
    const width = 4;
    const interpolation = 'nearest';
    const resizingLayer = new Resizing({height, width, interpolation});
    const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
    const expectedArr =
        [[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]];
    const expectedOutput = tensor([expectedArr], [1, 4, 4, 1]);
    expectTensorsClose(layerOutputTensor, expectedOutput);
  });

  it('Returns the same tensor when given same shape as input', () => {
    // Create a resizing layer with same shape as input.
    const height = 64;
    const width = 32;
    const numChannels = 1;
    const rangeTensor = range(0, height * width);
    const inputTensor = reshape(rangeTensor, [height, width, numChannels]);
    const resizingLayer = new Resizing({height, width});
    const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
    expectTensorsClose(layerOutputTensor, inputTensor);
  });

  it('Returns a tensor of the correct dtype', () => {
    // Do a same-size resizing operation, check tensor dtypes and content.
    const height = 40;
    const width = 60;
    const numChannels = 3;
    const inputTensor: Tensor<Rank.R3> =
        zeros([height, width, numChannels]);
    const size: [number, number] = [height, width];
    const expectedOutputTensor = image.resizeBilinear(inputTensor, size);
    const resizingLayer = new Resizing({height, width});
    const layerOutputTensor = resizingLayer.apply(inputTensor) as Tensor;
    expect(layerOutputTensor.dtype).toBe(inputTensor.dtype);
    expectTensorsClose(layerOutputTensor, expectedOutputTensor);
  });

  it('Throws an error given incorrect parameters', () => {
    // Pass incorrect interpolation method string to layer init.
    const height = 16;
    const width = 16;
    const interpolation = 'unimplemented';
    const incorrectArgs = {height, width, interpolation};
    const expectedError =
        `Invalid interpolation parameter: ${interpolation} is not implemented`;
    expect(() => new Resizing(incorrectArgs as ResizingArgs))
        .toThrowError(expectedError);
  });

  it('Config holds correct name', () => {
    // Layer name property set properly.
    const height = 40;
    const width = 60;
    const resizingLayer = new Resizing({height, width, name: 'Resizing'});
    const config = resizingLayer.getConfig();
    expect(config.name).toEqual('Resizing');
  });
});

0 commit comments

Comments
 (0)