Skip to content

Commit e71d8ae

Browse files
authored
Ensure that tf.exp() ops up-cast 'int32' types to Float. (#107)
1 parent d49198f commit e71d8ae

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/nodejs_kernel_backend.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,8 @@ export class NodeJSKernelBackend implements KernelBackend {
416416
}
417417

418418
exp<T extends Tensor>(x: T): T {
419-
return this.executeSingleInput('Exp', x) as T;
419+
const xTensor = x.dtype === 'int32' ? x.toFloat() : x;
420+
return this.executeSingleInput('Exp', xTensor) as T;
420421
}
421422

422423
log<T extends Tensor>(x: T): T {

src/nodejs_kernel_backend_test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ describe('delayed upload', () => {
3737
expect(softmaxLogits.get(2)).toEqual(data[2]);
3838
});
3939
});
40+
41+
describe('type casting', () => {
42+
it('exp support int32', () => {
43+
tf.exp(tf.scalar(2, 'int32'));
44+
});
45+
});

0 commit comments

Comments
 (0)