From c81a76b2bc03151b009fe4502a90565dc1ec1071 Mon Sep 17 00:00:00 2001 From: Brian Atkinson Date: Sat, 21 Jan 2017 15:23:20 -0600 Subject: [PATCH] Enable passing Python functions to Go for invocation. --- runtime/function.go | 5 ++ runtime/native.go | 102 ++++++++++++++++++++++++++++++++++-- runtime/native_test.go | 116 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 218 insertions(+), 5 deletions(-) diff --git a/runtime/function.go b/runtime/function.go index f3ed41ba..d25641fa 100644 --- a/runtime/function.go +++ b/runtime/function.go @@ -125,6 +125,10 @@ func functionGet(_ *Frame, desc, instance *Object, owner *Type) (*Object, *BaseE return NewMethod(toFunctionUnsafe(desc), instance, owner).ToObject(), nil } +func functionNative(f *Frame, o *Object) (reflect.Value, *BaseException) { + return reflect.ValueOf(o.Call), nil +} + func functionRepr(_ *Frame, o *Object) (*Object, *BaseException) { fun := toFunctionUnsafe(o) return NewStr(fmt.Sprintf("<%s %s at %p>", fun.typ.Name(), fun.Name(), fun)).ToObject(), nil @@ -134,6 +138,7 @@ func initFunctionType(map[string]*Object) { FunctionType.flags &= ^(typeFlagInstantiable | typeFlagBasetype) FunctionType.slots.Call = &callSlot{functionCall} FunctionType.slots.Get = &getSlot{functionGet} + FunctionType.slots.Native = &nativeSlot{functionNative} FunctionType.slots.Repr = &unaryOpSlot{functionRepr} } diff --git a/runtime/native.go b/runtime/native.go index 05681e22..1486960d 100644 --- a/runtime/native.go +++ b/runtime/native.go @@ -56,6 +56,9 @@ var ( } nativeTypesMutex = sync.Mutex{} sliceIteratorType = newBasisType("sliceiterator", reflect.TypeOf(sliceIterator{}), toSliceIteratorUnsafe, ObjectType) + + baseExceptionReflectType = reflect.TypeOf((*BaseException)(nil)) + frameReflectType = reflect.TypeOf((*Frame)(nil)) ) type nativeMetaclass struct { @@ -489,22 +492,111 @@ func maybeConvertValue(f *Frame, o *Object, expectedRType reflect.Type) (reflect if raised != nil { return reflect.Value{}, raised } - rtype := val.Type() for { + rtype := val.Type() if rtype == expectedRType { return val, nil } if rtype.ConvertibleTo(expectedRType) { return val.Convert(expectedRType), nil } - if rtype.Kind() == reflect.Ptr { + switch rtype.Kind() { + case reflect.Ptr: val = val.Elem() - rtype = val.Type() continue + + case reflect.Func: + if fn, ok := val.Interface().(func(*Frame, Args, KWArgs) (*Object, *BaseException)); ok { + val = nativeToPyFuncBridge(fn, expectedRType) + continue + } } - break + return val, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType)) + } +} + +// pyToNativeRaised supports pushing a `raised` exception from python code to +// native calling code. If the raised exception can't be returned to native +// code, then the raised exception is panic-ed. +func pyToNativeRaised(outs []reflect.Type, raised *BaseException) []reflect.Value { + last := len(outs) - 1 + if len(outs) == 0 || outs[last] != baseExceptionReflectType { + panic(raised) } - return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType)) + ret := make([]reflect.Value, len(outs)) + for i, out := range outs[:last] { + ret[i] = reflect.Zero(out) + } + ret[last] = reflect.ValueOf(raised) + return ret +} + +func nativeToPyFuncBridge(fn func(*Frame, Args, KWArgs) (*Object, *BaseException), target reflect.Type) reflect.Value { + firstInIsFrame := target.NumIn() > 0 && target.In(0) == frameReflectType + + outs := make([]reflect.Type, target.NumOut()) + for i := range outs { + outs[i] = target.Out(i) + } + + return reflect.MakeFunc(target, func(args []reflect.Value) []reflect.Value { + var f *Frame + if firstInIsFrame { + f, args = args[0].Interface().(*Frame), args[1:] + } else { + f = NewRootFrame() + } + + pyArgs := f.MakeArgs(len(args)) + for i, arg := range args { + var raised *BaseException + pyArgs[i], raised = WrapNative(f, arg) + if raised != nil { + return pyToNativeRaised(outs, raised) + } + } + + ret, raised := fn(f, pyArgs, nil) + f.FreeArgs(pyArgs) + if raised != nil { + return pyToNativeRaised(outs, raised) + } + + switch len(outs) { + case 0: + if ret != nil && ret != None { + return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("unexpected return of %v when None expected", ret))) + } + return nil + + case 1: + v, raised := maybeConvertValue(f, ret, outs[0]) + if raised != nil { + return pyToNativeRaised(outs, raised) + } + return []reflect.Value{v} + + default: + converted := make([]reflect.Value, 0, len(outs)) + if raised := seqForEach(f, ret, func(o *Object) *BaseException { + i := len(converted) + if i >= len(outs) { + return f.RaiseType(TypeErrorType, fmt.Sprintf("return value too long, want %d items", len(outs))) + } + v, raised := maybeConvertValue(f, o, outs[i]) + converted = append(converted, v) + return raised + }); raised != nil { + return pyToNativeRaised(outs, raised) + } + + if len(converted) != len(outs) { + return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("return value wrong size %d, want %d", len(converted), len(outs)))) + } + + return converted + } + }) } func nativeFuncTypeName(rtype reflect.Type) string { diff --git a/runtime/native_test.go b/runtime/native_test.go index 3c69e636..7bed836d 100644 --- a/runtime/native_test.go +++ b/runtime/native_test.go @@ -422,6 +422,122 @@ func TestMaybeConvertValue(t *testing.T) { } } +func TestNativveToPyFuncBridge(t *testing.T) { + tests := []struct { + name string + fn func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) + typ reflect.Type + args []interface{} + ret []interface{} + panc *BaseException + }{ + { + name: "no args", + fn: func(t *testing.T, f *Frame, a Args, k KWArgs) (*Object, *BaseException) { + if f == nil || len(a) != 0 || len(k) != 0 { + t.Errorf("fn called with (%v, %v, %v), want (non-nil, %v, %v)", f, a, k, Args{}, KWArgs{}) + } + return nil, nil + }, + typ: reflect.TypeOf(func() {}), + ret: []interface{}{}, + }, + { + name: "return wrong size", + fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) { + return NewInt(1).ToObject(), nil + }, + typ: reflect.TypeOf(func() {}), + panc: mustCreateException(TypeErrorType, "unexpected return of 1 when None expected"), + }, + { + name: "single return value", + fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) { + return NewInt(1).ToObject(), nil + }, + typ: reflect.TypeOf((*func() int)(nil)).Elem(), + ret: []interface{}{1}, + }, + { + name: "wrong size multiple return value", + fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) { + return NewTuple(NewInt(1).ToObject(), NewInt(2).ToObject(), NewInt(3).ToObject()).ToObject(), nil + }, + typ: reflect.TypeOf((*func() (int, int))(nil)).Elem(), + panc: mustCreateException(TypeErrorType, "return value too long, want 2 items"), + }, + { + name: "multiple return value", + fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) { + return NewTuple(NewInt(1).ToObject(), NewInt(2).ToObject(), NewInt(3).ToObject()).ToObject(), nil + }, + typ: reflect.TypeOf((*func() (int, int, int))(nil)).Elem(), + ret: []interface{}{1, 2, 3}, + }, + + { + name: "func takes args", + fn: func(t *testing.T, f *Frame, a Args, k KWArgs) (*Object, *BaseException) { + want := Args{ + NewInt(1).ToObject(), + NewInt(2).ToObject(), + } + if f == nil || !reflect.DeepEqual(a, want) || len(k) != 0 { + t.Errorf("fn called with (%v, %v, %v), want (non-nil, %v, %v)", f, a, k, want, KWArgs{}) + } + return nil, nil + }, + typ: reflect.TypeOf(func(int, int) {}), + args: []interface{}{1, 2}, + ret: []interface{}{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + called := false + fn := func(f *Frame, a Args, k KWArgs) (*Object, *BaseException) { + called = true + return test.fn(t, f, a, k) + } + + args := make([]reflect.Value, len(test.args)) + for i, a := range test.args { + args[i] = reflect.ValueOf(a) + } + + nativeFn := nativeToPyFuncBridge(fn, test.typ) + ret := func() []reflect.Value { + if test.panc != nil { + defer func() { + r := recover() + raised, ok := r.(*BaseException) + if r == nil || !ok || !exceptionsAreEquivalent(raised, test.panc) { + t.Errorf("recover()=%v (type %T), want %v", r, r, test.panc) + } + }() + } + return nativeFn.Call(args) + }() + + if test.panc == nil { + got := make([]interface{}, 0, len(test.ret)) + for _, v := range ret { + got = append(got, v.Interface()) + } + + if !reflect.DeepEqual(got, test.ret) { + t.Errorf("fn returned %v, want %v", got, test.ret) + } + } + + if !called { + t.Errorf("fn not called, want to be called") + } + }) + } +} + func TestNativeTypedefNative(t *testing.T) { fun := wrapFuncForTest(func(f *Frame, o *Object, wantType reflect.Type) (bool, *BaseException) { val, raised := ToNative(f, o)