Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Add tf.image.transform function #1637

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/kernels/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,17 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
extrapolationValue: number): Tensor4D {
throw new Error('Not yet implemented');
}


/**
 * Applies projective transform(s) to a batch of images.
 * Base-class stub: always throws; concrete backends (CPU, WebGL) override it.
 *
 * @param images 4D tensor `[batch, height, width, channels]`.
 * @param transforms `[numTransforms, 8]` tensor of projective transform rows.
 * @param method Sampling method, `'bilinear'` or `'nearest'`.
 * @param outputSize `[outHeight, outWidth]` of the result.
 * @param fillValue Value substituted for reads outside the input image.
 */
transform(
images: Tensor4D,
transforms: Tensor2D,
method: 'bilinear'|'nearest',
outputSize: [number, number],
fillValue: number
): Tensor4D {
throw new Error('Not yet implemented');
}

depthToSpace(x: Tensor4D, blockSize: number, dataFormat: string): Tensor4D {
throw new Error('Not yet implemented');
}
Expand Down
70 changes: 70 additions & 0 deletions src/kernels/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3272,7 +3272,77 @@ export class MathBackendCPU implements KernelBackend {
}
return output.toTensor() as Tensor4D;
}

transform(
images: Tensor4D,
transforms: Tensor2D,
method: string,
outputSize: [number, number],
fillValue: number
) {
const [batch, inHeight, inWidth, numChannels] = images.shape;
const numTransforms = transforms.shape[0];

const outHeight = outputSize[0];
const outWidth = outputSize[1];
const imageBuffers = images.bufferSync();

const output = ops.buffer([batch, outHeight, outWidth, numChannels], images.dtype);

const outStride = output.strides;

const transformVals = transforms.dataSync();

// Reference implementation
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/image/kernels/image_ops.h
if((numTransforms !== batch && numTransforms !== 1 )|| transforms.shape[1] !== 8){
throw (new Error("Input transform should be num_images x 8 or 1 x 8"));
}
const readWithFillValue = function(batch : number, y: number, x: number, channel: number, fillValue: number){
return (0 <= y && y < inHeight && 0 <= x && x < inWidth) ? imageBuffers.get(batch, y, x, channel) : fillValue;
}
const bilinearInterpolation = function(batch : number, y: number, x: number, channel: number, fillValue: number){
const xFloor = Math.floor(x);
const yFloor = Math.floor(y);
const xCeil = xFloor+1;
const yCeil = yFloor+1;

const valueYFloor = (xCeil - x) * readWithFillValue(batch, yFloor, xFloor, channel, fillValue)
+ (x - xFloor) * readWithFillValue(batch, yFloor, xCeil, channel, fillValue);

const valueYCeil = (xCeil - x) * readWithFillValue(batch, yCeil, xFloor, channel, fillValue)
+ (x - xFloor) * readWithFillValue(batch, yCeil, xCeil, channel, fillValue);
const res = (yCeil - y)*valueYFloor + (y - yFloor)*valueYCeil;

return res;
}
const nearestInterpolation = function(batch : number, y: number, x: number, channel: number, fillValue: number){
return readWithFillValue(batch, Math.round(y), Math.round(x), channel, fillValue);
}
for (let bInd = 0; bInd < batch; bInd++){
const transform = numTransforms === 1 ? transformVals: transformVals.slice(bInd*8, (bInd+1)*8);

for (let topInd = 0; topInd < outHeight; topInd++){
for (let leftInd = 0; leftInd < outWidth; leftInd++){
const projection = transform[6]*leftInd + transform[7]*topInd + 1;
const floatInputLeft = (transform[0] * leftInd + transform[1] * topInd + transform[2]) / projection;
const floatInputTop = (transform[3] * leftInd + transform[4] * topInd + transform[5]) / projection;

for (let c = 0; c < numChannels; c ++){
const outInd = c + leftInd * outStride[2] + topInd * outStride[1] + bInd * outStride[0];
if (method === 'bilinear') {
output.values[outInd] = bilinearInterpolation(bInd, floatInputTop, floatInputLeft, c, fillValue);
} else { // method == "nearest"
output.values[outInd] = nearestInterpolation(bInd, floatInputTop, floatInputLeft, c, fillValue);
}
}
}
}
}
return output.toTensor() as Tensor4D;
}

sparseToDense<R extends Rank>(
sparseIndices: Tensor, sparseValues: Tensor, outputShape: ShapeMap[R],
defaultValue: Scalar): Tensor<R> {
Expand Down
15 changes: 14 additions & 1 deletion src/kernels/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import {Conv2DProgram, Conv3DProgram} from './webgl/conv_gpu';
import {DepthwiseConv2DProgram} from './webgl/conv_gpu_depthwise';
import {DepthwiseConvPacked2DProgram} from './webgl/conv_packed_gpu_depthwise';
import {CropAndResizeProgram} from './webgl/crop_and_resize_gpu';
import {TransformProgram} from './webgl/transform_gpu';
import {CumSumProgram} from './webgl/cumsum_gpu';
import {DepthToSpaceProgram} from './webgl/depth_to_space_gpu';
import {EncodeFloatProgram} from './webgl/encode_float_gpu';
Expand Down Expand Up @@ -2048,7 +2049,19 @@ export class MathBackendWebGL implements KernelBackend {
image.shape, boxes.shape, cropSize, method, extrapolationValue);
return this.compileAndRun(program, [image, boxes, boxIndex]);
}


/**
 * Applies projective transform(s) to the image(s) on the WebGL backend by
 * compiling and running a `TransformProgram` over the image and transform
 * tensors.
 */
transform(
    image: Tensor4D,
    transforms: Tensor2D,
    method: 'bilinear'|'nearest',
    outputSize: [number, number],
    fillValue: number
): Tensor4D {
  return this.compileAndRun(
      new TransformProgram(
          image.shape, transforms.shape, method, outputSize, fillValue),
      [image, transforms]);
}

depthToSpace(x: Tensor4D, blockSize: number, dataFormat: 'NHWC'|'NCHW'):
Tensor4D {
util.assert(
Expand Down
110 changes: 110 additions & 0 deletions src/kernels/webgl/transform_gpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import { GPGPUProgram } from './gpgpu_math';

/**
 * GPGPU program that applies projective transform(s) to a batch of images.
 *
 * Each transform row `[a0, a1, a2, b0, b1, b2, c0, c1]` maps an output
 * point (x, y) to the input point
 * `((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, `k = c0 x + c1 y + 1`.
 * Out-of-bounds samples read `fillValue`.
 */
export class TransformProgram implements GPGPUProgram {
  variableNames = ['Image', 'Transform'];
  outputShape: number[] = [];
  userCode: string;

  constructor(
      imageShape: [number, number, number, number],
      transformShape: [number, number],
      method: 'bilinear'|'nearest',
      size: [number, number],
      fillValue: number
  ) {
    const [batch, imageHeight, imageWidth, depth] = imageShape;
    const [numTransforms, ] = transformShape;
    const [outHeight, outWidth] = size;
    // Output batch matches the input batch (same as the CPU kernel): a
    // single transform row is broadcast across all images, so using
    // numTransforms here would shrink the output when numTransforms === 1.
    this.outputShape = [batch, outHeight, outWidth, depth];
    const methodId = method === 'bilinear' ? 1 : 0;

    // With a single transform, every image samples transform row 0.
    const transformRow = numTransforms === 1 ? '0' : 'b';

    // Reference implementation
    // tslint:disable-next-line:max-line-length
    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
    this.userCode = `
      float readFillValue(int b, int y, int x, int d) {
        // Out-of-bounds reads return the caller-supplied fill value.
        if (y < 0 || y >= ${imageHeight} || x < 0 || x >= ${imageWidth}) {
          return float(${fillValue});
        }
        return getImage(b, y, x, d);
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int y = coords[1];
        int x = coords[2];
        int d = coords[3];

        // Transform row: [a0 a1 a2 b0 b1 b2 c0 c1].
        float a0 = getTransform(${transformRow}, 0);
        float a1 = getTransform(${transformRow}, 1);
        float a2 = getTransform(${transformRow}, 2);
        float b0 = getTransform(${transformRow}, 3);
        float b1 = getTransform(${transformRow}, 4);
        float b2 = getTransform(${transformRow}, 5);
        float c0 = getTransform(${transformRow}, 6);
        float c1 = getTransform(${transformRow}, 7);

        // Project the output coordinate back into input space.
        float projection = c0 * float(x) + c1 * float(y) + 1.0;
        float in_y = (b0 * float(x) + b1 * float(y) + b2) / projection;
        float in_x = (a0 * float(x) + a1 * float(y) + a2) / projection;

        if (${methodId} == 1) {
          // Bilinear: blend the four integer neighbors of (in_x, in_y).
          int xFloor = int(floor(in_x));
          int yFloor = int(floor(in_y));
          int xCeil = int(floor(in_x + 1.0));
          int yCeil = int(floor(in_y + 1.0));

          float topLeft = readFillValue(b, yFloor, xFloor, d);
          float bottomLeft = readFillValue(b, yCeil, xFloor, d);
          float topRight = readFillValue(b, yFloor, xCeil, d);
          float bottomRight = readFillValue(b, yCeil, xCeil, d);

          float valueYFloor = topLeft * (float(xCeil) - in_x)
              + topRight * (in_x - float(xFloor));
          float valueYCeil = bottomLeft * (float(xCeil) - in_x)
              + bottomRight * (in_x - float(xFloor));

          setOutput((float(yCeil) - in_y) * valueYFloor
              + (in_y - float(yFloor)) * valueYCeil);
        } else {
          // Nearest: sample the closest integer coordinate.
          int xRound = int(round(in_x));
          int yRound = int(round(in_y));
          setOutput(readFillValue(b, yRound, xRound, d));
        }
      }
    `;
  }
}
55 changes: 55 additions & 0 deletions src/ops/image_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,63 @@ function cropAndResize_(
return res as Tensor4D;
}

/**
 * Applies the given projective transform(s) to the image(s).
 *
 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`,
 *     where imageHeight and imageWidth must be positive, specifying the
 *     batch of images to transform.
 * @param transforms 2d float32 tensor of shape `[batch, 8]` or `[1, 8]`.
 *     Each row is a projective transform `[a0, a1, a2, b0, b1, b2, c0, c1]`
 *     that maps the output point (x, y) to a transformed input point
 *     `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
 *     where `k = c0 x + c1 y + 1`. A single row is applied to all images.
 * @param method Optional, string from `'bilinear' | 'nearest'`,
 *     defaults to bilinear, which specifies the sampling method.
 * @param size Optional, the new size `[newHeight, newWidth]` for the output
 *     image, defaults to `[imageHeight, imageWidth]`.
 * @param fillValue Optional, the value used for pixels sampled outside the
 *     input image, defaults to 0.
 * @return A 4D tensor of shape `[batch, newHeight, newWidth, depth]`.
 */
/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */
function transform_(
    image: Tensor4D|TensorLike,
    transforms: Tensor2D|TensorLike,
    method?: 'bilinear'|'nearest',
    size?: [number, number],
    fillValue?: number
): Tensor4D {
  const $image = convertToTensor(image, 'image', 'transform', 'float32');
  const $transforms =
      convertToTensor(transforms, 'transforms', 'transform', 'float32');
  method = method || 'bilinear';
  fillValue = fillValue || 0;
  size = size || [$image.shape[1], $image.shape[2]];

  util.assert(
      $image.rank === 4,
      () => 'Error in transform: image must be rank 4, ' +
          `but got rank ${$image.rank}.`);
  util.assert(
      $transforms.rank === 2 && $transforms.shape[1] === 8 &&
          ($transforms.shape[0] === 1 ||
           $transforms.shape[0] === $image.shape[0]),
      () => `Error in transform: transforms must have shape ` +
          `[${$image.shape[0]},8] or [1,8], ` +
          `but had shape ${$transforms.shape}.`);
  util.assert(
      size[0] >= 1 && size[1] >= 1,
      () => `size must be at least [1,1], but was ${size}`);
  util.assert(
      method === 'bilinear' || method === 'nearest',
      () => `method must be bilinear or nearest, but was ${method}`);

  const forward: ForwardFunc<Tensor4D> = (backend, save) =>
      backend.transform($image, $transforms, method, size, fillValue);

  const res = ENV.engine.runKernel(forward, {$image, $transforms});
  return res as Tensor4D;
}
// Public image ops: each is wrapped with op() so it is named, chainable and
// participates in the engine's op machinery; the async variant is exported
// directly because op() wraps synchronous kernels only.
export const resizeBilinear = op({resizeBilinear_});
export const resizeNearestNeighbor = op({resizeNearestNeighbor_});
export const nonMaxSuppression = op({nonMaxSuppression_});
export const nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
export const cropAndResize = op({cropAndResize_});
export const transform = op({transform_});
Loading