Skip to content

Commit dd30524

Browse files
committed
Merge branch 'java-unknown-type' into reland-minibench-refactor-2
2 parents 6da6a04 + f549064 commit dd30524

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,13 @@ public enum DType {
7373
DType(int jniCode) {
7474
this.jniCode = jniCode;
7575
}
76+
77+
public static DType fromJniCode(int jniCode) {
78+
for (DType dtype : values()) {
79+
if (dtype.jniCode == jniCode) {
80+
return dtype;
81+
}
82+
}
83+
throw new IllegalArgumentException("No DType found for jniCode " + jniCode);
84+
}
7685
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
package org.pytorch.executorch;
1010

11+
import android.util.Log;
1112
import com.facebook.jni.HybridData;
1213
import com.facebook.jni.annotations.DoNotStrip;
1314
import java.nio.Buffer;
@@ -630,6 +631,30 @@ public String toString() {
630631
}
631632
}
632633

634+
static class Tensor_unsupported extends Tensor {
635+
private final ByteBuffer data;
636+
private final DType mDtype;
637+
638+
private Tensor_unsupported(ByteBuffer data, long[] shape, DType dtype) {
639+
super(shape);
640+
this.data = data;
641+
this.mDtype = dtype;
642+
Log.e(
643+
"ExecuTorch",
644+
toString() + " in Java. Please consider re-export the model with proper return type");
645+
}
646+
647+
@Override
648+
public DType dtype() {
649+
return mDtype;
650+
}
651+
652+
@Override
653+
public String toString() {
654+
return String.format("Unsupported tensor(%s, dtype=%d)", Arrays.toString(shape), this.mDtype);
655+
}
656+
}
657+
633658
// region checks
634659
private static void checkArgument(boolean expression, String errorMessage, Object... args) {
635660
if (!expression) {
@@ -675,7 +700,7 @@ private static Tensor nativeNewTensor(
675700
} else if (DType.INT8.jniCode == dtype) {
676701
tensor = new Tensor_int8(data, shape);
677702
} else {
678-
throw new IllegalArgumentException("Unknown Tensor dtype");
703+
tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype));
679704
}
680705
tensor.mHybridData = hybridData;
681706
return tensor;

0 commit comments

Comments
 (0)