Skip to content
This repository was archived by the owner on Jun 14, 2024. It is now read-only.

Commit df3e54d

Browse files
authored
implemented tensor array ops and hooked up with the executor (#170)
1 parent 01e7933 commit df3e54d

22 files changed

+618
-30
lines changed

src/data/types.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
* =============================================================================
1616
*/
1717
import {DataType, Tensor} from '@tensorflow/tfjs-core';
18+
import {TensorArray} from '../executor/tensor_array';
1819

1920
export type NamedTensorMap = {
2021
[key: string]: Tensor
@@ -24,6 +25,10 @@ export type NamedTensorsMap = {
2425
[key: string]: Tensor[]
2526
};
2627

28+
export type TensorArrayMap = {
29+
[key: number]: TensorArray
30+
};
31+
2732
export interface TensorInfo {
2833
name: string;
2934
shape?: number[];

src/executor/execution_context.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
*/
1717
import {Tensor} from '@tensorflow/tfjs-core';
1818

19-
import {NamedTensorsMap} from '../data/types';
19+
import {NamedTensorsMap, TensorArrayMap} from '../data/types';
20+
21+
import {TensorArray} from './tensor_array';
2022

2123
export interface ExecutionContextInfo {
2224
id: number; // the unique id of the context info
@@ -40,7 +42,9 @@ export class ExecutionContext {
4042
private lastId = 0;
4143
private _currentContextIds: string[];
4244

43-
constructor(public weightMap: NamedTensorsMap) {
45+
constructor(
46+
public readonly weightMap: NamedTensorsMap,
47+
public readonly tensorArrayMap: TensorArrayMap) {
4448
this.generateCurrentContextIds();
4549
}
4650

@@ -151,4 +155,12 @@ export class ExecutionContext {
151155
getWeight(name: string): Tensor[] {
152156
return this.weightMap[name];
153157
}
158+
159+
addTensorArray(tensorArray: TensorArray) {
160+
this.tensorArrayMap[tensorArray.id] = tensorArray;
161+
}
162+
163+
getTensorArray(id: number): TensorArray {
164+
return this.tensorArrayMap[id];
165+
}
154166
}

src/executor/execution_context_test.ts

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
*/
1717

1818
import {ExecutionContext} from './execution_context';
19+
import {TensorArray} from './tensor_array';
1920

2021
let context: ExecutionContext;
22+
let tensorArray: TensorArray;
2123
describe('ExecutionContext', () => {
2224
beforeEach(() => {
23-
context = new ExecutionContext({});
25+
context = new ExecutionContext({}, {});
2426
});
25-
afterEach(() => {});
2627

2728
it('should initialize', () => {
2829
expect(context.currentContext).toEqual([
@@ -31,6 +32,21 @@ describe('ExecutionContext', () => {
3132
expect(context.currentContextId).toEqual('');
3233
});
3334

35+
describe('tensor array', () => {
36+
beforeEach(() => {
37+
tensorArray = new TensorArray('', 'float32', 10, [1], true, true, true);
38+
});
39+
40+
it('should be able to add tensor array', () => {
41+
context.addTensorArray(tensorArray);
42+
expect(context.getTensorArray(tensorArray.id)).toBe(tensorArray);
43+
});
44+
45+
it('should be able to read tensor array', () => {
46+
expect(context.getTensorArray(tensorArray.id)).toBeUndefined();
47+
});
48+
});
49+
3450
describe('enterFrame', () => {
3551
it('should add new Frame', () => {
3652
context.enterFrame('1');

src/executor/graph_executor.ts

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
* =============================================================================
1616
*/
1717

18-
// tslint:disable-next-line:max-line-length
1918
import {DataType, Tensor, tidy, util} from '@tensorflow/tfjs-core';
2019

21-
import {NamedTensorMap, NamedTensorsMap, TensorInfo} from '../data/types';
22-
import {getNodeNameAndIndex, getTensor} from '../operations/executors/utils';
20+
// tslint:disable-next-line:max-line-length
21+
import {NamedTensorMap, NamedTensorsMap, TensorArrayMap, TensorInfo} from '../data/types';
22+
// tslint:disable-next-line:max-line-length
23+
import {getNodeNameAndIndex, getParamValue, getTensor} from '../operations/executors/utils';
2324
import {executeOp} from '../operations/operation_executor';
2425
import {Graph, Node} from '../operations/types';
2526

@@ -128,8 +129,9 @@ export class GraphExecutor {
128129
execute(inputs: NamedTensorsMap, outputs?: string|string[]): NamedTensorMap {
129130
this.checkInput(inputs);
130131
this.checkInputShapeAndType(inputs);
132+
const tensorArrayMap: TensorArrayMap = {};
131133
const result = tidy(() => {
132-
const context = new ExecutionContext(this._weightMap);
134+
const context = new ExecutionContext(this._weightMap, tensorArrayMap);
133135
const tensors =
134136
this.compiledOrder.reduce<NamedTensorsMap>((map, node) => {
135137
map[node.name] = executeOp(node, map, context) as Tensor[];
@@ -153,7 +155,8 @@ export class GraphExecutor {
153155
Promise<NamedTensorMap> {
154156
this.checkInput(inputs);
155157
this.checkInputShapeAndType(inputs);
156-
const context = new ExecutionContext(this._weightMap);
158+
const tensorArrayMap: TensorArrayMap = {};
159+
const context = new ExecutionContext(this._weightMap, tensorArrayMap);
157160
// Graph with control flow op requires runtime evaluation of the execution
158161
// order, while without control flow the execution order is pre-determined
159162
// in the compile method.
@@ -196,10 +199,20 @@ export class GraphExecutor {
196199
while (stack.length > 0) {
197200
const item = stack.pop();
198201
context.currentContext = item.contexts;
199-
202+
let nodeName = '';
203+
// The tensor of the Enter op with isConstant set should be set
204+
// in the parent scope, so it will be available as constant for the
205+
// whole loop.
206+
if (item.node.op === 'enter' &&
207+
getParamValue('isConstant', item.node, tensorMap, context)) {
208+
[nodeName] = getNodeNameAndIndex(item.node.name, context);
209+
}
200210
const tensors = executeOp(item.node, tensorMap, context);
201211

202-
const [nodeName, ] = getNodeNameAndIndex(item.node.name, context);
212+
if (!nodeName) {
213+
[nodeName] = getNodeNameAndIndex(item.node.name, context);
214+
}
215+
203216
tensorMap[nodeName] = await tensors;
204217
item.node.children.forEach((childNode) => {
205218
const [nodeName, ] = getNodeNameAndIndex(childNode.name, context);

src/executor/tensor_array.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,18 @@ export interface TensorWithState {
2828
* allows reading from the array and writing to the array.
2929
*/
3030
export class TensorArray {
31+
private static nextId = 0;
3132
private tensors: TensorWithState[] = [];
3233
private closed_ = false;
34+
readonly id: number;
3335
constructor(
3436
public readonly name: string, public readonly dtype: DataType,
3537
private maxSize: number, private elementShape: number[],
3638
public readonly identicalElementShapes: boolean,
3739
public readonly dynamicSize: boolean,
38-
public readonly clearAfterRead: boolean) {}
40+
public readonly clearAfterRead: boolean) {
41+
this.id = TensorArray.nextId++;
42+
}
3943

4044
get closed() {
4145
return this.closed_;
@@ -114,6 +118,12 @@ export class TensorArray {
114118
because the value dtype is ${
115119
tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
116120
}
121+
122+
// Set the shape for the first time write to unknow shape tensor array
123+
if (this.size() === 0 && this.elementShape.length === 0) {
124+
this.elementShape = tensor.shape;
125+
}
126+
117127
util.assertShapesMatch(
118128
this.elementShape, tensor.shape,
119129
`TensorArray ${this.name}: Could not write to TensorArray index ${

src/operations/executors/arithmetic_executor_test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ describe('arithmetic', () => {
2626
let node: Node;
2727
const input1 = [tfc.scalar(1)];
2828
const input2 = [tfc.scalar(1)];
29-
const context = new ExecutionContext({});
29+
const context = new ExecutionContext({}, {});
3030

3131
beforeEach(() => {
3232
node = {

src/operations/executors/basic_math_executor_test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {createNumberAttr, createTensorAttr} from './test_helper';
2525
describe('basic math', () => {
2626
let node: Node;
2727
const input1 = [tfc.scalar(1)];
28-
const context = new ExecutionContext({});
28+
const context = new ExecutionContext({}, {});
2929

3030
beforeEach(() => {
3131
node = {

src/operations/executors/control_executor.ts

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
*/
1717

1818
import * as tfc from '@tensorflow/tfjs-core';
19+
import {scalar} from '@tensorflow/tfjs-core';
1920

2021
import {NamedTensorsMap} from '../../data/types';
2122
import {ExecutionContext} from '../../executor/execution_context';
23+
import {TensorArray} from '../../executor/tensor_array';
2224
import {Node} from '../types';
2325

2426
import {getParamValue, getTensor} from './utils';
@@ -61,6 +63,97 @@ export async function executeOp(
6163
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
6264
context.nextIteration();
6365
return [input];
66+
67+
case 'tensorArray':
68+
const size = getParamValue('size', node, tensorMap, context) as number;
69+
const dtype =
70+
getParamValue('dtype', node, tensorMap, context) as tfc.DataType;
71+
const elementShape =
72+
getParamValue('elementShape', node, tensorMap, context) as number[];
73+
const dynamicSize =
74+
getParamValue('dynamicSize', node, tensorMap, context) as boolean;
75+
const clearAfterRead =
76+
getParamValue('clearAfterRead', node, tensorMap, context) as boolean;
77+
const identicalElementShapes =
78+
getParamValue('identicalElementShapes', node, tensorMap, context) as
79+
boolean;
80+
const name = getParamValue('name', node, tensorMap, context) as string;
81+
const tensorArray = new TensorArray(
82+
name, dtype, size, elementShape, identicalElementShapes, dynamicSize,
83+
clearAfterRead);
84+
context.addTensorArray(tensorArray);
85+
return [scalar(tensorArray.id), scalar(1.0)];
86+
87+
case 'tensorArrayWrite':
88+
const id =
89+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
90+
const index = getParamValue('index', node, tensorMap, context) as number;
91+
const writeTensor =
92+
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
93+
const writeTensorArray = context.getTensorArray(id);
94+
writeTensorArray.write(index, writeTensor);
95+
return [scalar(1.0)];
96+
97+
case 'tensorArrayRead':
98+
const readId =
99+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
100+
const readIndex =
101+
getParamValue('index', node, tensorMap, context) as number;
102+
const readTensorArray = context.getTensorArray(readId);
103+
return [readTensorArray.read(readIndex)];
104+
105+
case 'tensorArrayGather':
106+
const gatherId =
107+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
108+
const gatherIndices =
109+
getParamValue('indices', node, tensorMap, context) as number[];
110+
const gatherDtype =
111+
getParamValue('dtype', node, tensorMap, context) as tfc.DataType;
112+
const gatherTensorArray = context.getTensorArray(gatherId);
113+
return [gatherTensorArray.gather(gatherIndices, gatherDtype)];
114+
115+
case 'tensorArrayScatter':
116+
const scatterId =
117+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
118+
const scatterIndices =
119+
getParamValue('indices', node, tensorMap, context) as number[];
120+
const scatterTensor =
121+
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
122+
const scatterTensorArray = context.getTensorArray(scatterId);
123+
scatterTensorArray.scatter(scatterIndices, scatterTensor);
124+
return [scalar(1.0)];
125+
126+
case 'tensorArrayConcat':
127+
const concatId =
128+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
129+
const concatTensorArray = context.getTensorArray(concatId);
130+
const concatDtype =
131+
getParamValue('dtype', node, tensorMap, context) as tfc.DataType;
132+
return [concatTensorArray.concat(concatDtype)];
133+
134+
case 'tensorArraySplit':
135+
const splitId =
136+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
137+
const splitTensor =
138+
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
139+
const lengths =
140+
getParamValue('lengths', node, tensorMap, context) as number[];
141+
const splitTensorArray = context.getTensorArray(splitId);
142+
splitTensorArray.split(lengths, splitTensor);
143+
return [scalar(1.0)];
144+
145+
case 'tensorArraySize':
146+
const sizeId =
147+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
148+
const sizeTensorArray = context.getTensorArray(sizeId);
149+
return [scalar(sizeTensorArray.size(), 'int32')];
150+
151+
case 'tensorArrayClose':
152+
const closeId =
153+
getParamValue('tensorArrayId', node, tensorMap, context) as number;
154+
const closeTensorArray = context.getTensorArray(closeId);
155+
closeTensorArray.clearAndClose();
156+
return [];
64157
default:
65158
throw TypeError(`Node type ${node.op} is not implemented`);
66159
}

0 commit comments

Comments
 (0)