Skip to content

Commit 9a580c5

Browse files
authored
[converter] added dtype support for oneHot for converter (#6782)
* added dtype support for oneHot for converter * update the doc for onehot
1 parent 25cd4f4 commit 9a580c5

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

tfjs-converter/python/tensorflowjs/op_list/creation.json

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@
8888
{
8989
"tfName": "T",
9090
"name": "dtype",
91-
"type": "dtype",
92-
"notSupported": true
91+
"type": "dtype"
9392
}
9493
]
9594
},
@@ -366,4 +365,4 @@
366365
}
367366
]
368367
}
369-
]
368+
]

tfjs-converter/src/operations/executors/creation_executor.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ import {InternalOpExecutor, Node} from '../types';
2626
import {getParamValue} from './utils';
2727

2828
export const executeOp: InternalOpExecutor =
29-
(node: Node, tensorMap: NamedTensorsMap,
30-
context: ExecutionContext, ops = tfOps): Tensor[] => {
29+
(node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,
30+
ops = tfOps): Tensor[] => {
3131
switch (node.op) {
3232
case 'Fill': {
3333
const shape =
@@ -64,7 +64,9 @@ export const executeOp: InternalOpExecutor =
6464
getParamValue('onValue', node, tensorMap, context) as number;
6565
const offValue =
6666
getParamValue('offValue', node, tensorMap, context) as number;
67-
return [ops.oneHot(indices, depth, onValue, offValue)];
67+
const dtype =
68+
getParamValue('dtype', node, tensorMap, context) as DataType;
69+
return [ops.oneHot(indices, depth, onValue, offValue, dtype)];
6870
}
6971
case 'Ones': {
7072
return [ops.ones(

tfjs-converter/src/operations/executors/creation_executor_test.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ import * as creation from '../op_list/creation';
2222
import {Node} from '../types';
2323

2424
import {executeOp} from './creation_executor';
25+
import {RecursiveSpy, spyOnAllFunctions} from './spy_ops';
2526
import {createDtypeAttr, createNumberAttr, createNumberAttrFromIndex, createNumericArrayAttrFromIndex, createTensorAttr, validateParam} from './test_helper';
26-
import {spyOnAllFunctions, RecursiveSpy} from './spy_ops';
2727

2828
describe('creation', () => {
2929
let node: Node;
@@ -99,22 +99,25 @@ describe('creation', () => {
9999
node.inputParams['depth'] = createNumberAttrFromIndex(1);
100100
node.inputParams['onValue'] = createNumberAttrFromIndex(2);
101101
node.inputParams['offValue'] = createNumberAttrFromIndex(3);
102+
node.attrParams['dtype'] = createDtypeAttr('float32');
102103
node.inputNames = ['input', 'input2', 'input3', 'input4'];
103104
const input = [tfOps.tensor1d([0])];
104105
const input3 = [tfOps.scalar(2)];
105106
const input4 = [tfOps.scalar(3)];
106107
spyOps.oneHot.and.returnValue({});
107-
executeOp(node, {input, input2, input3, input4}, context,
108-
spyOpsAsTfOps);
108+
executeOp(
109+
node, {input, input2, input3, input4}, context, spyOpsAsTfOps);
109110

110-
expect(spyOps.oneHot).toHaveBeenCalledWith(input[0], 1, 2, 3);
111+
expect(spyOps.oneHot)
112+
.toHaveBeenCalledWith(input[0], 1, 2, 3, 'float32');
111113
});
112114
it('should match json def', () => {
113115
node.op = 'OneHot';
114116
node.inputParams['indices'] = createTensorAttr(0);
115117
node.inputParams['depth'] = createNumberAttrFromIndex(1);
116118
node.inputParams['onValue'] = createNumberAttrFromIndex(2);
117119
node.inputParams['offValue'] = createNumberAttrFromIndex(3);
120+
node.attrParams['dtype'] = createDtypeAttr('float32');
118121

119122
expect(validateParam(node, creation.json)).toBeTruthy();
120123
});

tfjs-core/src/ops/one_hot.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import {op} from './operation';
4545
* the location.
4646
* @param offValue A number used to fill in the output when the index does
4747
* not match the location.
48+
* @param dtype The dtype of the output tensor, default to 'int32'.
4849
*
4950
* @doc {heading: 'Tensors', subheading: 'Creation'}
5051
*/

0 commit comments

Comments
 (0)