Skip to content

Commit 71e2276

Browse files
committed
bind/java: pin java byte array elements until Seq send is done.
When passing a byte array from Java to Go, Seq.writeByteArray JNI call encodes only the array size and the pointer to the array. Go-side receives the (size, ptr) pair info during the subsequent Seq.send JNI call, and copies the elements into a Go byte slice. We must pin the array elements until Go-side completes copying so that they are not moved or collected by Java runtime. This change keeps track of the pinned array info in a 'pinned' linked list, and unpin them as the Seq memory is freed. The jbyteArray argument passed to Seq.writeByteArray is needed to release the pinned byte array elements, but that is a "local reference". It is not guaranteed that the reference is valid after the method returns. Thus, we stash its global reference in the 'pinned' list and delete it later as well. A similar problem can occur on the byte slice returned from a Go function. This change does not address the case yet. Fixes golang/go#9486 Change-Id: I1255aefbc80b21ccbe9b2bf37699faaf0c5f0bae Reviewed-on: https://go-review.googlesource.com/2586 Reviewed-by: David Crawshaw <[email protected]>
1 parent d97f4d8 commit 71e2276

File tree

6 files changed

+147
-37
lines changed

6 files changed

+147
-37
lines changed

bind/java/SeqTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,27 @@ public void testByteArray() {
8181
}
8282
}
8383

84+
// Test for golang.org/issue/9486.
85+
public void testByteArrayAfterString() {
86+
byte[] bytes = new byte[1024];
87+
for (int i=0; i < bytes.length; i++) {
88+
bytes[i] = 8;
89+
}
90+
91+
String stuff = "stuff";
92+
byte[] got = Testpkg.AppendToString(stuff, bytes);
93+
94+
try {
95+
byte[] s = stuff.getBytes("UTF-8");
96+
byte[] want = new byte[s.length + bytes.length];
97+
System.arraycopy(s, 0, want, 0, s.length);
98+
System.arraycopy(bytes, 0, want, s.length, bytes.length);
99+
MoreAsserts.assertEquals("Bytes should match", want, got);
100+
} catch (Exception e) {
101+
fail("Cannot perform the test: " + e.toString());
102+
}
103+
}
104+
84105
public void testGoRefGC() {
85106
Testpkg.S s = Testpkg.New();
86107
runGC();

bind/java/seq_android.c

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ static jfieldID receive_refnum_id;
1919
static jfieldID receive_code_id;
2020
static jfieldID receive_handle_id;
2121

22+
static jclass jbytearray_clazz;
23+
24+
// pinned represents a pinned array to be released at the end of Send call.
25+
typedef struct pinned {
26+
jobject ref;
27+
void* ptr;
28+
struct pinned* next;
29+
} pinned;
30+
2231
// mem is a simple C equivalent of seq.Buffer.
2332
//
2433
// Many of the allocations around mem could be avoided to improve
@@ -28,6 +37,9 @@ typedef struct mem {
2837
uint32_t off;
2938
uint32_t len;
3039
uint32_t cap;
40+
41+
// TODO(hyangah): have it as a separate field outside mem?
42+
pinned* pinned;
3143
} mem;
3244

3345
// mem_ensure ensures that m has at least size bytes free.
@@ -42,6 +54,7 @@ static mem *mem_ensure(mem *m, uint32_t size) {
4254
m->off = 0;
4355
m->len = 0;
4456
m->buf = NULL;
57+
m->pinned = NULL;
4558
}
4659
if (m->cap > m->off+size) {
4760
return m;
@@ -95,6 +108,47 @@ uint8_t *mem_write(JNIEnv *env, jobject obj, uint32_t size) {
95108
return res;
96109
}
97110

111+
static void *pin_array(JNIEnv *env, jobject obj, jobject arr) {
112+
mem *m = mem_get(env, obj);
113+
if (m == NULL) {
114+
m = mem_ensure(m, 64);
115+
}
116+
pinned *p = (pinned*) malloc(sizeof(pinned));
117+
if (p == NULL) {
118+
LOG_FATAL("pin_array malloc failed");
119+
}
120+
p->ref = (*env)->NewGlobalRef(env, arr);
121+
122+
if ((*env)->IsInstanceOf(env, p->ref, jbytearray_clazz)) {
123+
p->ptr = (*env)->GetByteArrayElements(env, p->ref, NULL);
124+
} else {
125+
LOG_FATAL("unsupported array type");
126+
}
127+
128+
p->next = m->pinned;
129+
m->pinned = p;
130+
return p->ptr;
131+
}
132+
133+
static void unpin_arrays(JNIEnv *env, mem *m) {
134+
pinned* p = m->pinned;
135+
while (p != NULL) {
136+
if ((*env)->IsInstanceOf(env, p->ref, jbytearray_clazz)) {
137+
(*env)->ReleaseByteArrayElements(env, p->ref, (jbyte*)p->ptr, JNI_ABORT);
138+
} else {
139+
LOG_FATAL("invalid array type");
140+
}
141+
142+
(*env)->DeleteGlobalRef(env, p->ref);
143+
144+
pinned* o = p;
145+
p = p->next;
146+
free(o);
147+
}
148+
m->pinned = NULL;
149+
}
150+
151+
98152
static jfieldID find_field(JNIEnv *env, const char *class_name, const char *field_name, const char *field_type) {
99153
jclass clazz = (*env)->FindClass(env, class_name);
100154
if (clazz == NULL) {
@@ -109,6 +163,15 @@ static jfieldID find_field(JNIEnv *env, const char *class_name, const char *fiel
109163
return id;
110164
}
111165

166+
static jclass find_class(JNIEnv *env, const char *class_name) {
167+
jclass clazz = (*env)->FindClass(env, class_name);
168+
if (clazz == NULL) {
169+
LOG_FATAL("cannot find %s", class_name);
170+
return NULL;
171+
}
172+
return (*env)->NewGlobalRef(env, clazz);
173+
}
174+
112175
void init_seq(void *javavm) {
113176
JavaVM *vm = (JavaVM*)javavm;
114177
JNIEnv *env;
@@ -128,6 +191,8 @@ void init_seq(void *javavm) {
128191
receive_handle_id = find_field(env, "go/Seq$Receive", "handle", "I");
129192
receive_code_id = find_field(env, "go/Seq$Receive", "code", "I");
130193

194+
jbytearray_clazz = find_class(env, "[B");
195+
131196
LOG_INFO("loaded go/Seq");
132197

133198
if (res == JNI_EDETACHED) {
@@ -148,6 +213,7 @@ JNIEXPORT void JNICALL
148213
Java_go_Seq_free(JNIEnv *env, jobject obj) {
149214
mem *m = mem_get(env, obj);
150215
if (m != NULL) {
216+
unpin_arrays(env, m);
151217
free((void*)m->buf);
152218
free((void*)m);
153219
}
@@ -276,17 +342,8 @@ Java_go_Seq_writeByteArray(JNIEnv *env, jobject obj, jbyteArray v) {
276342
return;
277343
}
278344

279-
jboolean isCopy;
280-
jbyte* b = (*env)->GetByteArrayElements(env, v, &isCopy);
281-
if (isCopy) {
282-
// TODO: It's not clear how to handle if b is pointing to
283-
// a copy that may become invalid with ReleaseByteArrayElements.
284-
// Should we fall back to copy the byte array into the buffer?
285-
LOG_FATAL("got a copied byte array (len=%d)", len);
286-
}
287-
// gross pointer-to-int64 conversion.
288-
MEM_WRITE(int64_t) = (int64_t)((intptr_t)b);
289-
(*env)->ReleaseByteArrayElements(env, v, (jbyte*)b, 0);
345+
jbyte* b = pin_array(env, obj, v);
346+
MEM_WRITE(int64_t) = (jlong)(uintptr_t)b;
290347
}
291348

292349
JNIEXPORT void JNICALL
@@ -337,6 +394,7 @@ Java_go_Seq_send(JNIEnv *env, jclass clazz, jstring descriptor, jint code, jobje
337394
desc.n = (*env)->GetStringUTFLength(env, descriptor);
338395
Send(desc, (GoInt)code, src->buf, src->len, &dst->buf, &dst->len);
339396
(*env)->ReleaseStringUTFChars(env, descriptor, desc.p);
397+
unpin_arrays(env, src); // assume 'src' is no longer needed.
340398
}
341399

342400
JNIEXPORT void JNICALL

bind/java/seq_android.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ func Send(descriptor string, code int, req *C.uint8_t, reqlen C.size_t, res **C.
3636
}
3737
out := new(seq.Buffer)
3838
fn(out, in)
39+
// BUG(hyangah): the function returning a go byte slice (so fn writes a pointer into 'out') is unsafe.
40+
// After fn is complete here, Go runtime is free to collect or move the pointed byte slice
41+
// contents. (Explicitly calling runtime.GC here will surface the problem?)
42+
// Without pinning support from Go side, it will be hard to fix it without extra copying.
43+
3944
seqToBuf(res, reslen, out)
4045
}
4146

bind/java/testpkg/Testpkg.java

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ public static long Add(long x, long y) {
2020
return _result;
2121
}
2222

23+
public static byte[] AppendToString(String str, byte[] someBytes) {
24+
go.Seq _in = new go.Seq();
25+
go.Seq _out = new go.Seq();
26+
byte[] _result;
27+
_in.writeUTF16(str);
28+
_in.writeByteArray(someBytes);
29+
Seq.send(DESCRIPTOR, CALL_AppendToString, _in, _out);
30+
_result = _out.readByteArray();
31+
return _result;
32+
}
33+
2334
public static byte[] BytesAppend(byte[] a, byte[] b) {
2435
go.Seq _in = new go.Seq();
2536
go.Seq _out = new go.Seq();
@@ -372,18 +383,19 @@ public static String StrDup(String s) {
372383
}
373384

374385
private static final int CALL_Add = 1;
375-
private static final int CALL_BytesAppend = 2;
376-
private static final int CALL_CallE = 3;
377-
private static final int CALL_CallF = 4;
378-
private static final int CALL_CallI = 5;
379-
private static final int CALL_CallS = 6;
380-
private static final int CALL_CallV = 7;
381-
private static final int CALL_CallVE = 8;
382-
private static final int CALL_Err = 9;
383-
private static final int CALL_GC = 10;
384-
private static final int CALL_Keep = 11;
385-
private static final int CALL_New = 12;
386-
private static final int CALL_NumSCollected = 13;
387-
private static final int CALL_StrDup = 14;
386+
private static final int CALL_AppendToString = 2;
387+
private static final int CALL_BytesAppend = 3;
388+
private static final int CALL_CallE = 4;
389+
private static final int CALL_CallF = 5;
390+
private static final int CALL_CallI = 6;
391+
private static final int CALL_CallS = 7;
392+
private static final int CALL_CallV = 8;
393+
private static final int CALL_CallVE = 9;
394+
private static final int CALL_Err = 10;
395+
private static final int CALL_GC = 11;
396+
private static final int CALL_Keep = 12;
397+
private static final int CALL_New = 13;
398+
private static final int CALL_NumSCollected = 14;
399+
private static final int CALL_StrDup = 15;
388400
private static final String DESCRIPTOR = "testpkg";
389401
}

bind/java/testpkg/go_testpkg/go_testpkg.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ func proxy_Add(out, in *seq.Buffer) {
1616
out.WriteInt(res)
1717
}
1818

19+
func proxy_AppendToString(out, in *seq.Buffer) {
20+
param_str := in.ReadUTF16()
21+
param_someBytes := in.ReadByteArray()
22+
res := testpkg.AppendToString(param_str, param_someBytes)
23+
out.WriteByteArray(res)
24+
}
25+
1926
func proxy_BytesAppend(out, in *seq.Buffer) {
2027
param_a := in.ReadByteArray()
2128
param_b := in.ReadByteArray()
@@ -308,17 +315,18 @@ func proxy_StrDup(out, in *seq.Buffer) {
308315

309316
func init() {
310317
seq.Register("testpkg", 1, proxy_Add)
311-
seq.Register("testpkg", 2, proxy_BytesAppend)
312-
seq.Register("testpkg", 3, proxy_CallE)
313-
seq.Register("testpkg", 4, proxy_CallF)
314-
seq.Register("testpkg", 5, proxy_CallI)
315-
seq.Register("testpkg", 6, proxy_CallS)
316-
seq.Register("testpkg", 7, proxy_CallV)
317-
seq.Register("testpkg", 8, proxy_CallVE)
318-
seq.Register("testpkg", 9, proxy_Err)
319-
seq.Register("testpkg", 10, proxy_GC)
320-
seq.Register("testpkg", 11, proxy_Keep)
321-
seq.Register("testpkg", 12, proxy_New)
322-
seq.Register("testpkg", 13, proxy_NumSCollected)
323-
seq.Register("testpkg", 14, proxy_StrDup)
318+
seq.Register("testpkg", 2, proxy_AppendToString)
319+
seq.Register("testpkg", 3, proxy_BytesAppend)
320+
seq.Register("testpkg", 4, proxy_CallE)
321+
seq.Register("testpkg", 5, proxy_CallF)
322+
seq.Register("testpkg", 6, proxy_CallI)
323+
seq.Register("testpkg", 7, proxy_CallS)
324+
seq.Register("testpkg", 8, proxy_CallV)
325+
seq.Register("testpkg", 9, proxy_CallVE)
326+
seq.Register("testpkg", 10, proxy_Err)
327+
seq.Register("testpkg", 11, proxy_GC)
328+
seq.Register("testpkg", 12, proxy_Keep)
329+
seq.Register("testpkg", 13, proxy_New)
330+
seq.Register("testpkg", 14, proxy_NumSCollected)
331+
seq.Register("testpkg", 15, proxy_StrDup)
324332
}

bind/java/testpkg/testpkg.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,9 @@ func Err(s string) error {
108108
func BytesAppend(a []byte, b []byte) []byte {
109109
return append(a, b...)
110110
}
111+
112+
func AppendToString(str string, someBytes []byte) []byte {
113+
a := []byte(str)
114+
fmt.Printf("str=%q (len=%d), someBytes=%v (len=%d)\n", str, len(str), someBytes, len(someBytes))
115+
return append(a, someBytes...)
116+
}

0 commit comments

Comments
 (0)