Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 833475d

Browse files
manrajgroverbileschi
authored andcommitted
One Hot: Fixes CPU result dtype (#1110)
BUG: tensorflow/tfjs#435
1 parent b5b5430 commit 833475d

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/kernels/backend_cpu.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1888,7 +1888,7 @@ export class MathBackendCPU implements KernelBackend {
18881888
for (let event = 0; event < indices.size; ++event) {
18891889
res[event * depth + indices.get(event)] = onValue;
18901890
}
1891-
return ops.tensor2d(res, [indices.size, depth]);
1891+
return ops.tensor2d(res, [indices.size, depth], 'int32');
18921892
}
18931893

18941894
private broadcastedBinaryOp(

src/ops/array_ops_test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,6 +2162,14 @@ describeWithFlags('oneHot', ALL_ENVS, () => {
21622162
const indices = tf.tensor1d([0, 1], 'float32');
21632163
expect(() => tf.oneHot(indices, 2)).toThrowError();
21642164
});
2165+
2166+
it('check output dtype', () => {
2167+
const expectedType = 'int32';
2168+
const indices = tf.tensor1d([0, 1], 'int32');
2169+
const res = tf.oneHot(indices, 2);
2170+
2171+
expect(res.dtype).toEqual(expectedType);
2172+
});
21652173
});
21662174

21672175
describeWithFlags('linspace', ALL_ENVS, () => {

0 commit comments

Comments
 (0)