Skip to content

Commit 1eefcce

Browse files
authored
Add kernel RaggedTensorToTensor for CPU and WebGL backend (#6686)
* Add kernel RaggedTensorToTensor for CPU and WebGL backend * Add WebGL kernel forward * Change version
1 parent 7ae8860 commit 1eefcce

File tree

14 files changed

+1140
-0
lines changed

14 files changed

+1140
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/**
2+
* @license
3+
* Copyright 2022 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {KernelConfig, KernelFunc, RaggedTensorToTensor, RaggedTensorToTensorAttrs, RaggedTensorToTensorInputs, TensorInfo, TypedArray} from '@tensorflow/tfjs-core';
19+
20+
import {MathBackendCPU} from '../backend_cpu';
21+
22+
import {raggedTensorToTensorImpl} from './RaggedTensorToTensor_impl';
23+
24+
export function raggedTensorToTensor(args: {
25+
inputs: RaggedTensorToTensorInputs,
26+
backend: MathBackendCPU,
27+
attrs: RaggedTensorToTensorAttrs
28+
}): TensorInfo {
29+
const {inputs, backend, attrs} = args;
30+
const {shape, values, defaultValue, rowPartitionTensors} = inputs;
31+
const {rowPartitionTypes} = attrs;
32+
33+
const $shape = backend.data.get(shape.dataId).values as TypedArray;
34+
const $values = backend.data.get(values.dataId).values as TypedArray;
35+
const $defaultValue =
36+
backend.data.get(defaultValue.dataId).values as TypedArray;
37+
const $rowPartitionValues = rowPartitionTensors.map(
38+
t => backend.data.get(t.dataId).values as TypedArray);
39+
const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape);
40+
41+
const [outputShape, output] = raggedTensorToTensorImpl(
42+
$shape, shape.shape, $values, values.shape, values.dtype, $defaultValue,
43+
defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes,
44+
rowPartitionTypes);
45+
return backend.makeTensorInfo(outputShape, values.dtype, output);
46+
}
47+
48+
export const raggedTensorToTensorConfig: KernelConfig = {
49+
kernelName: RaggedTensorToTensor,
50+
backendName: 'cpu',
51+
kernelFunc: raggedTensorToTensor as {} as KernelFunc,
52+
};

0 commit comments

Comments
 (0)