From 812f7e85a2b2e402b27d126185f845e86d2cd60b Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Thu, 12 Jul 2018 10:14:04 -0700 Subject: [PATCH] implemented tensor array ops and hooked up with the executor --- src/data/types.ts | 5 + src/executor/execution_context.ts | 16 +- src/executor/execution_context_test.ts | 20 +- src/executor/graph_executor.ts | 27 +- src/executor/tensor_array.ts | 12 +- .../executors/arithmetic_executor_test.ts | 2 +- .../executors/basic_math_executor_test.ts | 2 +- src/operations/executors/control_executor.ts | 93 +++++++ .../executors/control_executor_test.ts | 182 ++++++++++++- .../executors/convolution_executor_test.ts | 2 +- .../executors/creation_executor_test.ts | 2 +- .../executors/graph_executor_test.ts | 2 +- .../executors/image_executor_test.ts | 2 +- .../executors/logical_executor_test.ts | 2 +- .../executors/matrices_executor_test.ts | 2 +- .../executors/normalization_executor_test.ts | 2 +- .../executors/reduction_executor_test.ts | 2 +- .../executors/slice_join_executor.ts | 11 +- .../executors/slice_join_executor_test.ts | 2 +- .../executors/transformation_executor_test.ts | 2 +- src/operations/op_list/control.json | 256 ++++++++++++++++++ src/operations/operation_executor_test.ts | 2 +- 22 files changed, 618 insertions(+), 30 deletions(-) diff --git a/src/data/types.ts b/src/data/types.ts index 25a95211..ad27b16e 100644 --- a/src/data/types.ts +++ b/src/data/types.ts @@ -15,6 +15,7 @@ * ============================================================================= */ import {DataType, Tensor} from '@tensorflow/tfjs-core'; +import {TensorArray} from '../executor/tensor_array'; export type NamedTensorMap = { [key: string]: Tensor @@ -24,6 +25,10 @@ export type NamedTensorsMap = { [key: string]: Tensor[] }; +export type TensorArrayMap = { + [key: number]: TensorArray +}; + export interface TensorInfo { name: string; shape?: number[]; diff --git a/src/executor/execution_context.ts b/src/executor/execution_context.ts index a589c894..e792d7e6 100644 --- a/src/executor/execution_context.ts +++ b/src/executor/execution_context.ts @@ -16,7 +16,9 @@ */ import {Tensor} from '@tensorflow/tfjs-core'; -import {NamedTensorsMap} from '../data/types'; +import {NamedTensorsMap, TensorArrayMap} from '../data/types'; + +import {TensorArray} from './tensor_array'; export interface ExecutionContextInfo { id: number; // the unique id of the context info @@ -40,7 +42,9 @@ export class ExecutionContext { private lastId = 0; private _currentContextIds: string[]; - constructor(public weightMap: NamedTensorsMap) { + constructor( + public readonly weightMap: NamedTensorsMap, + public readonly tensorArrayMap: TensorArrayMap) { this.generateCurrentContextIds(); } @@ -151,4 +155,12 @@ export class ExecutionContext { getWeight(name: string): Tensor[] { return this.weightMap[name]; } + + addTensorArray(tensorArray: TensorArray) { + this.tensorArrayMap[tensorArray.id] = tensorArray; + } + + getTensorArray(id: number): TensorArray { + return this.tensorArrayMap[id]; + } } diff --git a/src/executor/execution_context_test.ts b/src/executor/execution_context_test.ts index 97d0be3a..c2756505 100644 --- a/src/executor/execution_context_test.ts +++ b/src/executor/execution_context_test.ts @@ -16,13 +16,14 @@ */ import {ExecutionContext} from './execution_context'; +import {TensorArray} from './tensor_array'; let context: ExecutionContext; +let tensorArray: TensorArray; describe('ExecutionContext', () => { beforeEach(() => { - context = new ExecutionContext({}); + context = new ExecutionContext({}, {}); }); - afterEach(() => {}); it('should initialize', () => { expect(context.currentContext).toEqual([ @@ -31,6 +32,21 @@ describe('ExecutionContext', () => { expect(context.currentContextId).toEqual(''); }); + describe('tensor array', () => { + beforeEach(() => { + tensorArray = new TensorArray('', 'float32', 10, [1], true, true, true); + }); + + it('should be able to add tensor array', () => { + context.addTensorArray(tensorArray); + expect(context.getTensorArray(tensorArray.id)).toBe(tensorArray); + }); + + it('should be able to read tensor array', () => { + expect(context.getTensorArray(tensorArray.id)).toBeUndefined(); + }); + }); + describe('enterFrame', () => { it('should add new Frame', () => { context.enterFrame('1'); diff --git a/src/executor/graph_executor.ts b/src/executor/graph_executor.ts index b3516ad2..b3f2fde4 100644 --- a/src/executor/graph_executor.ts +++ b/src/executor/graph_executor.ts @@ -15,11 +15,12 @@ * ============================================================================= */ -// tslint:disable-next-line:max-line-length import {DataType, Tensor, tidy, util} from '@tensorflow/tfjs-core'; -import {NamedTensorMap, NamedTensorsMap, TensorInfo} from '../data/types'; -import {getNodeNameAndIndex, getTensor} from '../operations/executors/utils'; +// tslint:disable-next-line:max-line-length +import {NamedTensorMap, NamedTensorsMap, TensorArrayMap, TensorInfo} from '../data/types'; +// tslint:disable-next-line:max-line-length +import {getNodeNameAndIndex, getParamValue, getTensor} from '../operations/executors/utils'; import {executeOp} from '../operations/operation_executor'; import {Graph, Node} from '../operations/types'; @@ -128,8 +129,9 @@ export class GraphExecutor { execute(inputs: NamedTensorsMap, outputs?: string|string[]): NamedTensorMap { this.checkInput(inputs); this.checkInputShapeAndType(inputs); + const tensorArrayMap: TensorArrayMap = {}; const result = tidy(() => { - const context = new ExecutionContext(this._weightMap); + const context = new ExecutionContext(this._weightMap, tensorArrayMap); const tensors = this.compiledOrder.reduce((map, node) => { map[node.name] = executeOp(node, map, context) as Tensor[]; @@ -153,7 +155,8 @@ export class GraphExecutor { Promise { this.checkInput(inputs); this.checkInputShapeAndType(inputs); - const context = new ExecutionContext(this._weightMap); + const tensorArrayMap: TensorArrayMap = {}; + const context = new ExecutionContext(this._weightMap, tensorArrayMap); // Graph with control flow op requires runtime evaluation of the execution // order, while without control flow the execution order is pre-determined // in the compile method. @@ -196,10 +199,20 @@ export class GraphExecutor { while (stack.length > 0) { const item = stack.pop(); context.currentContext = item.contexts; - + let nodeName = ''; + // The tensor of the Enter op with isConstant set should be set + // in the parent scope, so it will be available as constant for the + // whole loop. + if (item.node.op === 'enter' && + getParamValue('isConstant', item.node, tensorMap, context)) { + [nodeName] = getNodeNameAndIndex(item.node.name, context); + } const tensors = executeOp(item.node, tensorMap, context); - const [nodeName, ] = getNodeNameAndIndex(item.node.name, context); + if (!nodeName) { + [nodeName] = getNodeNameAndIndex(item.node.name, context); + } + tensorMap[nodeName] = await tensors; item.node.children.forEach((childNode) => { const [nodeName, ] = getNodeNameAndIndex(childNode.name, context); diff --git a/src/executor/tensor_array.ts b/src/executor/tensor_array.ts index f4e7c63d..d20d2c69 100644 --- a/src/executor/tensor_array.ts +++ b/src/executor/tensor_array.ts @@ -28,14 +28,18 @@ export interface TensorWithState { * allows reading from the array and writing to the array. */ export class TensorArray { + private static nextId = 0; private tensors: TensorWithState[] = []; private closed_ = false; + readonly id: number; constructor( public readonly name: string, public readonly dtype: DataType, private maxSize: number, private elementShape: number[], public readonly identicalElementShapes: boolean, public readonly dynamicSize: boolean, - public readonly clearAfterRead: boolean) {} + public readonly clearAfterRead: boolean) { + this.id = TensorArray.nextId++; + } get closed() { return this.closed_; @@ -114,6 +118,12 @@ export class TensorArray { because the value dtype is ${ tensor.dtype}, but TensorArray dtype is ${this.dtype}.`); } + + // Set the shape for the first time write to unknow shape tensor array + if (this.size() === 0 && this.elementShape.length === 0) { + this.elementShape = tensor.shape; + } + util.assertShapesMatch( this.elementShape, tensor.shape, `TensorArray ${this.name}: Could not write to TensorArray index ${ diff --git a/src/operations/executors/arithmetic_executor_test.ts b/src/operations/executors/arithmetic_executor_test.ts index be6c8bc3..478e468e 100644 --- a/src/operations/executors/arithmetic_executor_test.ts +++ b/src/operations/executors/arithmetic_executor_test.ts @@ -26,7 +26,7 @@ describe('arithmetic', () => { let node: Node; const input1 = [tfc.scalar(1)]; const input2 = [tfc.scalar(1)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/basic_math_executor_test.ts b/src/operations/executors/basic_math_executor_test.ts index 4ec6fa8c..3dad5b6a 100644 --- a/src/operations/executors/basic_math_executor_test.ts +++ b/src/operations/executors/basic_math_executor_test.ts @@ -25,7 +25,7 @@ import {createNumberAttr, createTensorAttr} from './test_helper'; describe('basic math', () => { let node: Node; const input1 = [tfc.scalar(1)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/control_executor.ts b/src/operations/executors/control_executor.ts index dc4db716..ff5a3780 100644 --- a/src/operations/executors/control_executor.ts +++ b/src/operations/executors/control_executor.ts @@ -16,9 +16,11 @@ */ import * as tfc from '@tensorflow/tfjs-core'; +import {scalar} from '@tensorflow/tfjs-core'; import {NamedTensorsMap} from '../../data/types'; import {ExecutionContext} from '../../executor/execution_context'; +import {TensorArray} from '../../executor/tensor_array'; import {Node} from '../types'; import {getParamValue, getTensor} from './utils'; @@ -61,6 +63,97 @@ export async function executeOp( getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; context.nextIteration(); return [input]; + + case 'tensorArray': + const size = getParamValue('size', node, tensorMap, context) as number; + const dtype = + getParamValue('dtype', node, tensorMap, context) as tfc.DataType; + const elementShape = + getParamValue('elementShape', node, tensorMap, context) as number[]; + const dynamicSize = + getParamValue('dynamicSize', node, tensorMap, context) as boolean; + const clearAfterRead = + getParamValue('clearAfterRead', node, tensorMap, context) as boolean; + const identicalElementShapes = + getParamValue('identicalElementShapes', node, tensorMap, context) as + boolean; + const name = getParamValue('name', node, tensorMap, context) as string; + const tensorArray = new TensorArray( + name, dtype, size, elementShape, identicalElementShapes, dynamicSize, + clearAfterRead); + context.addTensorArray(tensorArray); + return [scalar(tensorArray.id), scalar(1.0)]; + + case 'tensorArrayWrite': + const id = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const index = getParamValue('index', node, tensorMap, context) as number; + const writeTensor = + getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; + const writeTensorArray = context.getTensorArray(id); + writeTensorArray.write(index, writeTensor); + return [scalar(1.0)]; + + case 'tensorArrayRead': + const readId = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const readIndex = + getParamValue('index', node, tensorMap, context) as number; + const readTensorArray = context.getTensorArray(readId); + return [readTensorArray.read(readIndex)]; + + case 'tensorArrayGather': + const gatherId = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const gatherIndices = + getParamValue('indices', node, tensorMap, context) as number[]; + const gatherDtype = + getParamValue('dtype', node, tensorMap, context) as tfc.DataType; + const gatherTensorArray = context.getTensorArray(gatherId); + return [gatherTensorArray.gather(gatherIndices, gatherDtype)]; + + case 'tensorArrayScatter': + const scatterId = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const scatterIndices = + getParamValue('indices', node, tensorMap, context) as number[]; + const scatterTensor = + getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; + const scatterTensorArray = context.getTensorArray(scatterId); + scatterTensorArray.scatter(scatterIndices, scatterTensor); + return [scalar(1.0)]; + + case 'tensorArrayConcat': + const concatId = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const concatTensorArray = context.getTensorArray(concatId); + const concatDtype = + getParamValue('dtype', node, tensorMap, context) as tfc.DataType; + return [concatTensorArray.concat(concatDtype)]; + + case 'tensorArraySplit': + const splitId = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const splitTensor = + getParamValue('tensor', node, tensorMap, context) as tfc.Tensor; + const lengths = + getParamValue('lengths', node, tensorMap, context) as number[]; + const splitTensorArray = context.getTensorArray(splitId); + splitTensorArray.split(lengths, splitTensor); + return [scalar(1.0)]; + + case 'tensorArraySize': + const sizeId = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const sizeTensorArray = context.getTensorArray(sizeId); + return [scalar(sizeTensorArray.size(), 'int32')]; + + case 'tensorArrayClose': + const closeId = + getParamValue('tensorArrayId', node, tensorMap, context) as number; + const closeTensorArray = context.getTensorArray(closeId); + closeTensorArray.clearAndClose(); + return []; default: throw TypeError(`Node type ${node.op} is not implemented`); } diff --git a/src/operations/executors/control_executor_test.ts b/src/operations/executors/control_executor_test.ts index 39f31ff1..60607817 100644 --- a/src/operations/executors/control_executor_test.ts +++ b/src/operations/executors/control_executor_test.ts @@ -15,17 +15,21 @@ * ============================================================================= */ import * as tfc from '@tensorflow/tfjs-core'; +import {scalar, tensor1d, tensor2d} from '@tensorflow/tfjs-core'; +import {expectArraysClose} from '@tensorflow/tfjs-core/dist/test_util'; import {ExecutionContext} from '../../executor/execution_context'; +import {TensorArray} from '../../executor/tensor_array'; import {Node} from '../types'; import {executeOp} from './control_executor'; -import {createTensorAttr} from './test_helper'; +// tslint:disable-next-line:max-line-length +import {createBoolAttr, createDtypeAttr, createNumberAttrFromIndex, createNumericArrayAttr, createNumericArrayAttrFromIndex, createStrAttr, createTensorAttr} from './test_helper'; describe('control', () => { let node: Node; - const input1 = [tfc.scalar(1)]; - const context = new ExecutionContext({}); + const input1 = [tfc.scalar(1, 'int32')]; + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { @@ -113,5 +117,177 @@ describe('control', () => { expect(context.nextIteration).toHaveBeenCalled(); }); }); + + describe('tensorArray', () => { + it('should create new tensor on the context', async () => { + node.op = 'tensorArray'; + node.params['name'] = createStrAttr(''); + node.params['dtype'] = createDtypeAttr('int32'); + node.params['elementShape'] = createNumericArrayAttr([10, 10]); + node.params['dynamicSize'] = createBoolAttr(false); + node.params['clearAfterRead'] = createBoolAttr(true); + node.params['identicalElementShapes'] = createBoolAttr(true); + node.inputNames = ['input1']; + + const tensorId = + (await executeOp(node, {input1}, context))[0].dataSync()[0]; + expect(context.getTensorArray(tensorId)).toBeDefined(); + }); + }); + + describe('tensorArrayWrite', () => { + it('should write the tensor to tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 5, [], true, false, true); + context.addTensorArray(tensorArray); + node.op = 'tensorArrayWrite'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.params['index'] = createNumberAttrFromIndex(1); + node.params['tensor'] = createTensorAttr(2); + node.inputNames = ['input2', 'input3', 'input1']; + const input2 = [scalar(tensorArray.id)]; + const input3 = [scalar(0)]; + await executeOp(node, {input1, input2, input3}, context); + + expect(tensorArray.size()).toEqual(1); + }); + }); + + describe('tensorArrayRead', () => { + it('should read the tensor from tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 5, [3], true, false, true); + const input4 = tensor1d([0, 0, 0], 'int32'); + tensorArray.write(0, input4); + context.addTensorArray(tensorArray); + node.op = 'tensorArrayRead'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.params['index'] = createNumberAttrFromIndex(1); + node.inputNames = ['input2', 'input3']; + const input2 = [scalar(tensorArray.id)]; + const input3 = [scalar(0)]; + const read = await executeOp(node, {input1, input2, input3}, context); + + expectArraysClose(read[0], input4); + }); + }); + + describe('tensorArrayGather', () => { + it('should gather the tensors from tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 5, [3], true, false, true); + const input4 = tensor1d([0, 0, 0], 'int32'); + const input5 = tensor1d([1, 1, 1], 'int32'); + tensorArray.writeMany([0, 1], [input4, input5]); + context.addTensorArray(tensorArray); + node.op = 'tensorArrayGather'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.params['indices'] = createNumericArrayAttrFromIndex(1); + node.params['dtype'] = createDtypeAttr('int32'); + node.inputNames = ['input2', 'input3']; + const input2 = [scalar(tensorArray.id)]; + const input3 = [tensor1d([0, 1])]; + const gather = await executeOp(node, {input2, input3}, context); + expect(gather.length).toEqual(1); + expect(gather[0].shape).toEqual([2, 3]); + expectArraysClose( + gather[0].dataSync(), new Int32Array([0, 0, 0, 1, 1, 1])); + }); + }); + + describe('tensorArrayScatter', () => { + it('should scatter the tensor to tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 5, [3], true, false, true); + const input4 = [tensor2d([0, 0, 0, 1, 1, 1], [2, 3], 'int32')]; + context.addTensorArray(tensorArray); + node.op = 'tensorArrayScatter'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.params['indices'] = createNumericArrayAttrFromIndex(1); + node.params['tensor'] = createTensorAttr(2); + node.inputNames = ['input2', 'input3', 'input4']; + const input2 = [scalar(tensorArray.id)]; + const input3 = [tensor1d([0, 1], 'int32')]; + await executeOp(node, {input2, input3, input4}, context); + + expect(tensorArray.size()).toEqual(2); + }); + }); + + describe('tensorArraySplit', () => { + it('should split the tensor to tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 2, [3], true, false, true); + const input4 = [tensor2d([0, 0, 0, 1, 1, 1], [2, 3], 'int32')]; + context.addTensorArray(tensorArray); + node.op = 'tensorArraySplit'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.params['tensor'] = createTensorAttr(1); + node.params['lengths'] = createNumericArrayAttrFromIndex(2); + node.inputNames = ['input2', 'input4', 'input3']; + const input2 = [scalar(tensorArray.id)]; + const input3 = [tensor1d([1, 1], 'int32')]; + await executeOp(node, {input2, input3, input4}, context); + + expect(tensorArray.size()).toEqual(2); + }); + }); + + describe('tensorArrayConcat', () => { + it('should concat the tensors from tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 5, [3], true, false, true); + const input4 = tensor1d([0, 0, 0], 'int32'); + const input5 = tensor1d([1, 1, 1], 'int32'); + tensorArray.writeMany([0, 1], [input4, input5]); + context.addTensorArray(tensorArray); + node.op = 'tensorArrayConcat'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.params['dtype'] = createDtypeAttr('int32'); + node.inputNames = ['input2']; + const input2 = [scalar(tensorArray.id)]; + const concat = await executeOp(node, {input2}, context); + expect(concat.length).toEqual(1); + expect(concat[0].shape).toEqual([6]); + expectArraysClose( + concat[0].dataSync(), new Int32Array([0, 0, 0, 1, 1, 1])); + }); + }); + + describe('tensorArraySize', () => { + it('should get the size of tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 5, [3], true, false, true); + const input4 = tensor1d([0, 0, 0], 'int32'); + const input5 = tensor1d([1, 1, 1], 'int32'); + tensorArray.writeMany([0, 1], [input4, input5]); + context.addTensorArray(tensorArray); + node.op = 'tensorArraySize'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputNames = ['input2']; + const input2 = [scalar(tensorArray.id)]; + const size = await executeOp(node, {input2}, context); + expect(size.length).toEqual(1); + expect(size[0].shape).toEqual([]); + expectArraysClose(size[0].dataSync(), new Int32Array([2])); + }); + }); + + describe('tensorArrayClose', () => { + it('should close the tensorArray', async () => { + const tensorArray = + new TensorArray('', 'int32', 5, [3], true, false, true); + const input4 = tensor1d([0, 0, 0], 'int32'); + const input5 = tensor1d([1, 1, 1], 'int32'); + tensorArray.writeMany([0, 1], [input4, input5]); + context.addTensorArray(tensorArray); + node.op = 'tensorArrayClose'; + node.params['tensorArrayId'] = createNumberAttrFromIndex(0); + node.inputNames = ['input2']; + const input2 = [scalar(tensorArray.id)]; + await executeOp(node, {input2}, context); + expect(tensorArray.closed).toBeTruthy(); + }); + }); }); }); diff --git a/src/operations/executors/convolution_executor_test.ts b/src/operations/executors/convolution_executor_test.ts index cd418151..f48fc9b1 100644 --- a/src/operations/executors/convolution_executor_test.ts +++ b/src/operations/executors/convolution_executor_test.ts @@ -26,7 +26,7 @@ import {createNumberAttr, createNumericArrayAttr, createStrAttr, createTensorAtt describe('convolution', () => { let node: Node; const input = [tfc.scalar(1)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/creation_executor_test.ts b/src/operations/executors/creation_executor_test.ts index 4a74c176..6e1e7104 100644 --- a/src/operations/executors/creation_executor_test.ts +++ b/src/operations/executors/creation_executor_test.ts @@ -27,7 +27,7 @@ describe('creation', () => { let node: Node; const input1 = [tfc.tensor1d([1, 2, 3])]; const input2 = [tfc.scalar(1)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/graph_executor_test.ts b/src/operations/executors/graph_executor_test.ts index c8e162b3..c6f33aca 100644 --- a/src/operations/executors/graph_executor_test.ts +++ b/src/operations/executors/graph_executor_test.ts @@ -29,7 +29,7 @@ describe('graph', () => { const input1 = [tfc.tensor1d([1])]; const input2 = [tfc.tensor1d([1])]; const input3 = [tfc.tensor3d([1, 1, 1, 2, 2, 2], [1, 2, 3])]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/image_executor_test.ts b/src/operations/executors/image_executor_test.ts index 385edd87..a9b3d558 100644 --- a/src/operations/executors/image_executor_test.ts +++ b/src/operations/executors/image_executor_test.ts @@ -26,7 +26,7 @@ import {createBoolAttr, createNumericArrayAttr, createTensorAttr} from './test_h describe('image', () => { let node: Node; const input1 = [tfc.tensor1d([1])]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/logical_executor_test.ts b/src/operations/executors/logical_executor_test.ts index ebb26354..fb347477 100644 --- a/src/operations/executors/logical_executor_test.ts +++ b/src/operations/executors/logical_executor_test.ts @@ -26,7 +26,7 @@ describe('logical', () => { let node: Node; const input1 = [tfc.scalar(1)]; const input2 = [tfc.scalar(2)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/matrices_executor_test.ts b/src/operations/executors/matrices_executor_test.ts index 7254d8c6..b1798c2b 100644 --- a/src/operations/executors/matrices_executor_test.ts +++ b/src/operations/executors/matrices_executor_test.ts @@ -27,7 +27,7 @@ describe('matrices', () => { let node: Node; const input1 = [tfc.scalar(1)]; const input2 = [tfc.scalar(2)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/normalization_executor_test.ts b/src/operations/executors/normalization_executor_test.ts index 1bf9bc4c..22bb0233 100644 --- a/src/operations/executors/normalization_executor_test.ts +++ b/src/operations/executors/normalization_executor_test.ts @@ -25,7 +25,7 @@ import {createNumberAttr, createTensorAttr} from './test_helper'; describe('normalization', () => { let node: Node; const input1 = [tfc.scalar(1)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/reduction_executor_test.ts b/src/operations/executors/reduction_executor_test.ts index ad52ef6c..10f06211 100644 --- a/src/operations/executors/reduction_executor_test.ts +++ b/src/operations/executors/reduction_executor_test.ts @@ -26,7 +26,7 @@ import {createBoolAttr, createNumberAttr, createTensorAttr} from './test_helper' describe('reduction', () => { let node: Node; const input1 = [tfc.scalar(1)]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/executors/slice_join_executor.ts b/src/operations/executors/slice_join_executor.ts index de348228..6365429a 100644 --- a/src/operations/executors/slice_join_executor.ts +++ b/src/operations/executors/slice_join_executor.ts @@ -65,9 +65,16 @@ export let executeOp: OpExecutor = (node: Node, tensorMap: NamedTensorsMap, getParamValue('beginMask', node, tensorMap, context) as number; const endMask = getParamValue('endMask', node, tensorMap, context) as number; + const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; + if (begin.length === 1 && tensor.shape.length > 1) { + for (let i = 1; i < tensor.shape.length; i++) { + begin.push(0); + end.push(tensor.shape[i]); + strides.push(strides[0]); + } + } return [tfc.stridedSlice( - getParamValue('x', node, tensorMap, context) as tfc.Tensor, begin, - end, strides, beginMask, endMask)]; + tensor, begin, end, strides, beginMask, endMask)]; } case 'stack': { return tfc.tidy(() => { diff --git a/src/operations/executors/slice_join_executor_test.ts b/src/operations/executors/slice_join_executor_test.ts index b8b0b27e..97deba56 100644 --- a/src/operations/executors/slice_join_executor_test.ts +++ b/src/operations/executors/slice_join_executor_test.ts @@ -30,7 +30,7 @@ describe('slice join', () => { const input3 = [tfc.scalar(3)]; const input4 = [tfc.tensor1d([3])]; const input5 = [tfc.tensor1d([3, 4])]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); describe('multi-tensor ops', () => { beforeEach(() => { diff --git a/src/operations/executors/transformation_executor_test.ts b/src/operations/executors/transformation_executor_test.ts index db850178..8420bf5d 100644 --- a/src/operations/executors/transformation_executor_test.ts +++ b/src/operations/executors/transformation_executor_test.ts @@ -27,7 +27,7 @@ describe('transformation', () => { let node: Node; const input1 = [tfc.scalar(1)]; const input2 = [tfc.tensor1d([1, 1])]; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = { diff --git a/src/operations/op_list/control.json b/src/operations/op_list/control.json index 004bed2c..db902326 100644 --- a/src/operations/op_list/control.json +++ b/src/operations/op_list/control.json @@ -104,5 +104,261 @@ "notSupported": true } ] + }, + { + "tfOpName": "TensorArrayV3", + "dlOpName": "tensorArray", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "size", + "type": "number" + }, + { + "tfParamName": "dtype", + "dlParamName": "dtype", + "type": "dtype" + }, + { + "tfParamName": "element_shape", + "dlParamName": "elementShape", + "type": "shape" + }, + { + "tfParamName": "dynamic_size", + "dlParamName": "dynamicSize", + "type": "bool" + }, + { + "tfParamName": "clear_after_read", + "dlParamName": "clearAfterRead", + "type": "bool" + }, + { + "tfParamName": "identical_element_shapes", + "dlParamName": "identicalElementShapes", + "type": "bool" + }, + { + "tfParamName": "tensor_array_name", + "dlParamName": "name", + "type": "string" + } + ] + }, + { + "tfOpName": "TensorArrayWriteV3", + "dlOpName": "tensorArrayWrite", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + }, + { + "tfInputIndex": 1, + "dlParamName": "index", + "type": "number" + }, + { + "tfInputIndex": 2, + "dlParamName": "tensor", + "type": "tensor" + }, + { + "tfInputIndex": 3, + "dlParamName": "flowIn", + "type": "number" + }, + { + "tfParamName": "T", + "dlParamName": "dtype", + "type": "dtype", + "notSupported": true + } + ] + }, + { + "tfOpName": "TensorArrayReadV3", + "dlOpName": "tensorArrayRead", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + }, + { + "tfInputIndex": 1, + "dlParamName": "index", + "type": "number" + }, + { + "tfInputIndex": 2, + "dlParamName": "flowIn", + "type": "number" + }, + { + "tfParamName": "dtype", + "dlParamName": "dtype", + "type": "dtype", + "notSupported": true + } + ] + }, + { + "tfOpName": "TensorArrayGatherV3", + "dlOpName": "tensorArrayGather", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + }, + { + "tfInputIndex": 1, + "dlParamName": "indices", + "type": "number[]" + }, + { + "tfInputIndex": 2, + "dlParamName": "flowIn", + "type": "number" + }, + { + "tfParamName": "dtype", + "dlParamName": "dtype", + "type": "dtype" + }, + { + "tfParamName": "element_shape", + "dlParamName": "elementShape", + "type": "shape" + } + ] + }, + { + "tfOpName": "TensorArrayScatterV3", + "dlOpName": "tensorArrayScatter", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + }, + { + "tfInputIndex": 1, + "dlParamName": "indices", + "type": "number[]" + }, + { + "tfInputIndex": 2, + "dlParamName": "tensor", + "type": "number[]" + }, + { + "tfInputIndex": 3, + "dlParamName": "flowIn", + "type": "number" + }, + { + "tfParamName": "T", + "dlParamName": "dtype", + "type": "dtype" + } + ] + }, + { + "tfOpName": "TensorArrayConcatV3", + "dlOpName": "tensorArrayConcat", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + }, + { + "tfInputIndex": 1, + "dlParamName": "flowIn", + "type": "number" + }, + { + "tfParamName": "dtype", + "dlParamName": "dtype", + "type": "dtype" + }, + { + "tfParamName": "element_shape_except0", + "dlParamName": "elementShapeExcept0", + "type": "shape", + "notSupported": true + } + ] + }, + { + "tfOpName": "TensorArraySplitV3", + "dlOpName": "tensorArraySplit", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + }, + { + "tfInputIndex": 1, + "dlParamName": "tensor", + "type": "tensor" + }, + { + "tfInputIndex": 2, + "dlParamName": "lengths", + "type": "number[]" + }, + { + "tfInputIndex": 3, + "dlParamName": "flowIn", + "type": "number" + }, + { + "tfParamName": "T", + "dlParamName": "dtype", + "type": "dtype" + } + ] + }, + { + "tfOpName": "TensorArraySizeV3", + "dlOpName": "tensorArraySize", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + }, + { + "tfInputIndex": 1, + "dlParamName": "flowIn", + "type": "number" + } + ] + }, + { + "tfOpName": "TensorArrayCloseV3", + "dlOpName": "tensorArrayClose", + "category": "control", + "params": [ + { + "tfInputIndex": 0, + "dlParamName": "tensorArrayId", + "type": "number" + } + ] } ] diff --git a/src/operations/operation_executor_test.ts b/src/operations/operation_executor_test.ts index 81cba621..2527a117 100644 --- a/src/operations/operation_executor_test.ts +++ b/src/operations/operation_executor_test.ts @@ -34,7 +34,7 @@ import {Node} from './types'; describe('OperationExecutor', () => { let node: Node; - const context = new ExecutionContext({}); + const context = new ExecutionContext({}, {}); beforeEach(() => { node = {