Skip to content
This repository was archived by the owner on Jun 14, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/data/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* =============================================================================
*/
import {DataType, Tensor} from '@tensorflow/tfjs-core';
import {TensorArray} from '../executor/tensor_array';

export type NamedTensorMap = {
[key: string]: Tensor
Expand All @@ -24,6 +25,10 @@ export type NamedTensorsMap = {
[key: string]: Tensor[]
};

export type TensorArrayMap = {
[key: number]: TensorArray
};

export interface TensorInfo {
name: string;
shape?: number[];
Expand Down
16 changes: 14 additions & 2 deletions src/executor/execution_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
*/
import {Tensor} from '@tensorflow/tfjs-core';

import {NamedTensorsMap} from '../data/types';
import {NamedTensorsMap, TensorArrayMap} from '../data/types';

import {TensorArray} from './tensor_array';

export interface ExecutionContextInfo {
id: number; // the unique id of the context info
Expand All @@ -40,7 +42,9 @@ export class ExecutionContext {
private lastId = 0;
private _currentContextIds: string[];

constructor(public weightMap: NamedTensorsMap) {
constructor(
public readonly weightMap: NamedTensorsMap,
public readonly tensorArrayMap: TensorArrayMap) {
this.generateCurrentContextIds();
}

Expand Down Expand Up @@ -151,4 +155,12 @@ export class ExecutionContext {
getWeight(name: string): Tensor[] {
return this.weightMap[name];
}

addTensorArray(tensorArray: TensorArray) {
this.tensorArrayMap[tensorArray.id] = tensorArray;
}

getTensorArray(id: number): TensorArray {
return this.tensorArrayMap[id];
}
}
20 changes: 18 additions & 2 deletions src/executor/execution_context_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
*/

import {ExecutionContext} from './execution_context';
import {TensorArray} from './tensor_array';

let context: ExecutionContext;
let tensorArray: TensorArray;
describe('ExecutionContext', () => {
beforeEach(() => {
context = new ExecutionContext({});
context = new ExecutionContext({}, {});
});
afterEach(() => {});

it('should initialize', () => {
expect(context.currentContext).toEqual([
Expand All @@ -31,6 +32,21 @@ describe('ExecutionContext', () => {
expect(context.currentContextId).toEqual('');
});

describe('tensor array', () => {
beforeEach(() => {
tensorArray = new TensorArray('', 'float32', 10, [1], true, true, true);
});

it('should be able to add tensor array', () => {
context.addTensorArray(tensorArray);
expect(context.getTensorArray(tensorArray.id)).toBe(tensorArray);
});

it('should be able to read tensor array', () => {
expect(context.getTensorArray(tensorArray.id)).toBeUndefined();
});
});

describe('enterFrame', () => {
it('should add new Frame', () => {
context.enterFrame('1');
Expand Down
27 changes: 20 additions & 7 deletions src/executor/graph_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
* =============================================================================
*/

// tslint:disable-next-line:max-line-length
import {DataType, Tensor, tidy, util} from '@tensorflow/tfjs-core';

import {NamedTensorMap, NamedTensorsMap, TensorInfo} from '../data/types';
import {getNodeNameAndIndex, getTensor} from '../operations/executors/utils';
// tslint:disable-next-line:max-line-length
import {NamedTensorMap, NamedTensorsMap, TensorArrayMap, TensorInfo} from '../data/types';
// tslint:disable-next-line:max-line-length
import {getNodeNameAndIndex, getParamValue, getTensor} from '../operations/executors/utils';
import {executeOp} from '../operations/operation_executor';
import {Graph, Node} from '../operations/types';

Expand Down Expand Up @@ -128,8 +129,9 @@ export class GraphExecutor {
execute(inputs: NamedTensorsMap, outputs?: string|string[]): NamedTensorMap {
this.checkInput(inputs);
this.checkInputShapeAndType(inputs);
const tensorArrayMap: TensorArrayMap = {};
const result = tidy(() => {
const context = new ExecutionContext(this._weightMap);
const context = new ExecutionContext(this._weightMap, tensorArrayMap);
const tensors =
this.compiledOrder.reduce<NamedTensorsMap>((map, node) => {
map[node.name] = executeOp(node, map, context) as Tensor[];
Expand All @@ -153,7 +155,8 @@ export class GraphExecutor {
Promise<NamedTensorMap> {
this.checkInput(inputs);
this.checkInputShapeAndType(inputs);
const context = new ExecutionContext(this._weightMap);
const tensorArrayMap: TensorArrayMap = {};
const context = new ExecutionContext(this._weightMap, tensorArrayMap);
// Graph with control flow op requires runtime evaluation of the execution
// order, while without control flow the execution order is pre-determined
// in the compile method.
Expand Down Expand Up @@ -196,10 +199,20 @@ export class GraphExecutor {
while (stack.length > 0) {
const item = stack.pop();
context.currentContext = item.contexts;

let nodeName = '';
// The tensor of the Enter op with isConstant set should be set
// in the parent scope, so it will be available as constant for the
// whole loop.
if (item.node.op === 'enter' &&
getParamValue('isConstant', item.node, tensorMap, context)) {
[nodeName] = getNodeNameAndIndex(item.node.name, context);
}
const tensors = executeOp(item.node, tensorMap, context);

const [nodeName, ] = getNodeNameAndIndex(item.node.name, context);
if (!nodeName) {
[nodeName] = getNodeNameAndIndex(item.node.name, context);
}

tensorMap[nodeName] = await tensors;
item.node.children.forEach((childNode) => {
const [nodeName, ] = getNodeNameAndIndex(childNode.name, context);
Expand Down
12 changes: 11 additions & 1 deletion src/executor/tensor_array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,18 @@ export interface TensorWithState {
* allows reading from the array and writing to the array.
*/
export class TensorArray {
private static nextId = 0;
private tensors: TensorWithState[] = [];
private closed_ = false;
readonly id: number;
constructor(
public readonly name: string, public readonly dtype: DataType,
private maxSize: number, private elementShape: number[],
public readonly identicalElementShapes: boolean,
public readonly dynamicSize: boolean,
public readonly clearAfterRead: boolean) {}
public readonly clearAfterRead: boolean) {
this.id = TensorArray.nextId++;
}

get closed() {
return this.closed_;
Expand Down Expand Up @@ -114,6 +118,12 @@ export class TensorArray {
because the value dtype is ${
tensor.dtype}, but TensorArray dtype is ${this.dtype}.`);
}

// Set the shape for the first time write to unknow shape tensor array
if (this.size() === 0 && this.elementShape.length === 0) {
this.elementShape = tensor.shape;
}

util.assertShapesMatch(
this.elementShape, tensor.shape,
`TensorArray ${this.name}: Could not write to TensorArray index ${
Expand Down
2 changes: 1 addition & 1 deletion src/operations/executors/arithmetic_executor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ describe('arithmetic', () => {
let node: Node;
const input1 = [tfc.scalar(1)];
const input2 = [tfc.scalar(1)];
const context = new ExecutionContext({});
const context = new ExecutionContext({}, {});

beforeEach(() => {
node = {
Expand Down
2 changes: 1 addition & 1 deletion src/operations/executors/basic_math_executor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {createNumberAttr, createTensorAttr} from './test_helper';
describe('basic math', () => {
let node: Node;
const input1 = [tfc.scalar(1)];
const context = new ExecutionContext({});
const context = new ExecutionContext({}, {});

beforeEach(() => {
node = {
Expand Down
93 changes: 93 additions & 0 deletions src/operations/executors/control_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
*/

import * as tfc from '@tensorflow/tfjs-core';
import {scalar} from '@tensorflow/tfjs-core';

import {NamedTensorsMap} from '../../data/types';
import {ExecutionContext} from '../../executor/execution_context';
import {TensorArray} from '../../executor/tensor_array';
import {Node} from '../types';

import {getParamValue, getTensor} from './utils';
Expand Down Expand Up @@ -61,6 +63,97 @@ export async function executeOp(
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
context.nextIteration();
return [input];

case 'tensorArray':
const size = getParamValue('size', node, tensorMap, context) as number;
const dtype =
getParamValue('dtype', node, tensorMap, context) as tfc.DataType;
const elementShape =
getParamValue('elementShape', node, tensorMap, context) as number[];
const dynamicSize =
getParamValue('dynamicSize', node, tensorMap, context) as boolean;
const clearAfterRead =
getParamValue('clearAfterRead', node, tensorMap, context) as boolean;
const identicalElementShapes =
getParamValue('identicalElementShapes', node, tensorMap, context) as
boolean;
const name = getParamValue('name', node, tensorMap, context) as string;
const tensorArray = new TensorArray(
name, dtype, size, elementShape, identicalElementShapes, dynamicSize,
clearAfterRead);
context.addTensorArray(tensorArray);
return [scalar(tensorArray.id), scalar(1.0)];

case 'tensorArrayWrite':
const id =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const index = getParamValue('index', node, tensorMap, context) as number;
const writeTensor =
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
const writeTensorArray = context.getTensorArray(id);
writeTensorArray.write(index, writeTensor);
return [scalar(1.0)];

case 'tensorArrayRead':
const readId =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const readIndex =
getParamValue('index', node, tensorMap, context) as number;
const readTensorArray = context.getTensorArray(readId);
return [readTensorArray.read(readIndex)];

case 'tensorArrayGather':
const gatherId =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const gatherIndices =
getParamValue('indices', node, tensorMap, context) as number[];
const gatherDtype =
getParamValue('dtype', node, tensorMap, context) as tfc.DataType;
const gatherTensorArray = context.getTensorArray(gatherId);
return [gatherTensorArray.gather(gatherIndices, gatherDtype)];

case 'tensorArrayScatter':
const scatterId =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const scatterIndices =
getParamValue('indices', node, tensorMap, context) as number[];
const scatterTensor =
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
const scatterTensorArray = context.getTensorArray(scatterId);
scatterTensorArray.scatter(scatterIndices, scatterTensor);
return [scalar(1.0)];

case 'tensorArrayConcat':
const concatId =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const concatTensorArray = context.getTensorArray(concatId);
const concatDtype =
getParamValue('dtype', node, tensorMap, context) as tfc.DataType;
return [concatTensorArray.concat(concatDtype)];

case 'tensorArraySplit':
const splitId =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const splitTensor =
getParamValue('tensor', node, tensorMap, context) as tfc.Tensor;
const lengths =
getParamValue('lengths', node, tensorMap, context) as number[];
const splitTensorArray = context.getTensorArray(splitId);
splitTensorArray.split(lengths, splitTensor);
return [scalar(1.0)];

case 'tensorArraySize':
const sizeId =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const sizeTensorArray = context.getTensorArray(sizeId);
return [scalar(sizeTensorArray.size(), 'int32')];

case 'tensorArrayClose':
const closeId =
getParamValue('tensorArrayId', node, tensorMap, context) as number;
const closeTensorArray = context.getTensorArray(closeId);
closeTensorArray.clearAndClose();
return [];
default:
throw TypeError(`Node type ${node.op} is not implemented`);
}
Expand Down
Loading