Skip to content

Commit c2b04d0

Browse files
committed
This commit introduces support for calling functions within expressions.
The key changes are: - Interpreter: - Added a new `NodeFunctionCall` AST node to represent a function call. - The interpreter can now execute functions passed in the variables map. It supports functions with 0 to 3 parameters of type `any`. - Implemented short-circuiting for `AND` and `OR` logical operators. - Parser: - The parser now recognizes function call syntax (`identifier(...)`). - It can parse comma-separated arguments in function calls. - Lexer: - Added a `TokenComma` to tokenize function argument lists. - Conversions: - Type conversion functions now handle `func()` types, allowing for lazy evaluation of values. - Testing: - Added a comprehensive test suite for the new function call functionality.
1 parent cc7bbaa commit c2b04d0

File tree

5 files changed

+298
-2
lines changed

5 files changed

+298
-2
lines changed

conversions.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ func isNumber(v interface{}) bool {
1212
return true
1313
case float32, float64:
1414
return true
15+
case func() int:
16+
return true
17+
case func() float64:
18+
return true
1519
}
1620
return false
1721
}
@@ -42,6 +46,10 @@ func toNumber(ast *Node, v interface{}) (float64, Error) {
4246
return float64(n), nil
4347
case float32:
4448
return float64(n), nil
49+
case func() int:
50+
return float64(n()), nil
51+
case func() float64:
52+
return n(), nil
4553
}
4654
return 0, NewError(ast.Offset, ast.Length, "unable to convert to number: %v", v)
4755
}
@@ -64,6 +72,8 @@ func toString(v interface{}) string {
6472
return string(s)
6573
case []byte:
6674
return string(s)
75+
case func() string:
76+
return s()
6777
}
6878
return fmt.Sprintf("%v", v)
6979
}
@@ -162,6 +172,10 @@ func normalize(v interface{}) interface{} {
162172
return float64(n)
163173
case []byte:
164174
return string(n)
175+
case func() int:
176+
return float64(n())
177+
case func() float64:
178+
return n()
165179
}
166180

167181
return v

interpreter.go

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,44 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) {
9696
return value, nil
9797
case "length":
9898
// Special pseudo-property to get the value's length.
99+
if s, ok := value.(func() string); ok {
100+
return len(s()), nil
101+
}
99102
if s, ok := value.(string); ok {
100103
return len(s), nil
101104
}
102105
if a, ok := value.([]any); ok {
103106
return len(a), nil
104107
}
105108
case "lower":
109+
if s, ok := value.(func() string); ok {
110+
return strings.ToLower(s()), nil
111+
}
106112
if s, ok := value.(string); ok {
107113
return strings.ToLower(s), nil
108114
}
109115
case "upper":
116+
if s, ok := value.(func() string); ok {
117+
return strings.ToUpper(s()), nil
118+
}
110119
if s, ok := value.(string); ok {
111120
return strings.ToUpper(s), nil
112121
}
113122
}
114123
if m, ok := value.(map[string]any); ok {
115124
if v, ok := m[ast.Value.(string)]; ok {
116-
return v, nil
125+
switch n := v.(type) {
126+
case func() int:
127+
return n(), nil
128+
case func() float64:
129+
return n(), nil
130+
case func() bool:
131+
return n(), nil
132+
case func() string:
133+
return n(), nil
134+
default:
135+
return v, nil
136+
}
117137
}
118138
}
119139
if m, ok := value.(map[any]any); ok {
@@ -335,11 +355,21 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) {
335355
if err != nil {
336356
return nil, err
337357
}
358+
left := toBool(resultLeft)
359+
switch ast.Type {
360+
case NodeAnd:
361+
if !left {
362+
return left, nil
363+
}
364+
case NodeOr:
365+
if left {
366+
return left, nil
367+
}
368+
}
338369
resultRight, err := i.run(ast.Right, value)
339370
if err != nil {
340371
return nil, err
341372
}
342-
left := toBool(resultLeft)
343373
right := toBool(resultRight)
344374
switch ast.Type {
345375
case NodeAnd:
@@ -470,6 +500,75 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) {
470500
}
471501
}
472502
return results, nil
503+
case NodeFunctionCall:
504+
funcName := ast.Left.Value.(string)
505+
if m, ok := value.(map[string]any); ok {
506+
if fn, ok := m[funcName]; ok {
507+
// Get function parameters
508+
params := []any{}
509+
for _, param := range ast.Value.([]Node) {
510+
paramValue, err := i.run(&param, value)
511+
if err != nil {
512+
return nil, err
513+
}
514+
params = append(params, paramValue)
515+
}
516+
517+
// Execute function based on parameter count
518+
switch f := fn.(type) {
519+
case func() any:
520+
if len(params) != 0 {
521+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 0 parameter, got %d", funcName, len(params))
522+
}
523+
result := f()
524+
switch result.(type) {
525+
case error:
526+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
527+
default:
528+
return result, nil
529+
}
530+
case func(any) any:
531+
if len(params) != 1 {
532+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 1 parameter, got %d", funcName, len(params))
533+
}
534+
result := f(params[0])
535+
switch result.(type) {
536+
case error:
537+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
538+
default:
539+
return result, nil
540+
}
541+
case func(any, any) any:
542+
if len(params) != 2 {
543+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 2 parameters, got %d", funcName, len(params))
544+
}
545+
result := f(params[0], params[1])
546+
switch result.(type) {
547+
case error:
548+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
549+
default:
550+
return result, nil
551+
}
552+
case func(any, any, any) any:
553+
if len(params) != 3 {
554+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 3 parameters, got %d", funcName, len(params))
555+
}
556+
result := f(params[0], params[1], params[2])
557+
switch result.(type) {
558+
case error:
559+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
560+
default:
561+
return result, nil
562+
}
563+
564+
}
565+
return nil, NewError(ast.Offset, ast.Length, "unsupported function type for %s", funcName)
566+
}
567+
}
568+
if i.strict {
569+
return nil, NewError(ast.Offset, ast.Length, "function %s not found", funcName)
570+
}
571+
return nil, nil
473572
}
474573
return nil, nil
475574
}

interpreter_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mexpr
22

33
import (
44
"encoding/json"
5+
"fmt"
56
"reflect"
67
"strings"
78
"testing"
@@ -232,6 +233,124 @@ func TestInterpreter(t *testing.T) {
232233
}
233234
}
234235

236+
func TestFunctions(t *testing.T) {
237+
238+
varMap := make(map[string]interface{})
239+
240+
varMap["func0"] = func() any {
241+
return 43.0
242+
}
243+
244+
varMap["func1"] = func(param1 any) any {
245+
switch param1.(type) {
246+
case float64:
247+
return param1.(float64) * 2
248+
default:
249+
return fmt.Errorf("Invalid type for param1")
250+
}
251+
}
252+
253+
varMap["func2"] = func(param1 any, param2 any) any {
254+
switch param1.(type) {
255+
case float64:
256+
switch param2.(type) {
257+
case float64:
258+
return param1.(float64) * param2.(float64)
259+
default:
260+
return fmt.Errorf("Invalid type for param2")
261+
}
262+
default:
263+
return fmt.Errorf("Invalid type for param1")
264+
}
265+
}
266+
267+
varMap["func3"] = func(param1 any, param2 any, param3 any) any {
268+
switch param1.(type) {
269+
case float64:
270+
switch param2.(type) {
271+
case float64:
272+
switch param3.(type) {
273+
case float64:
274+
return param1.(float64) * param2.(float64) * param3.(float64)
275+
default:
276+
return fmt.Errorf("Invalid type for param3")
277+
}
278+
default:
279+
return fmt.Errorf("Invalid type for param2")
280+
}
281+
default:
282+
return fmt.Errorf("Invalid type for param1")
283+
}
284+
}
285+
286+
type test struct {
287+
expr string
288+
output interface{}
289+
err string
290+
}
291+
cases := []test{
292+
{expr: "func0()", output: 43.0},
293+
{expr: "func1(42)", output: 84.0},
294+
{expr: "func2(3,4)", output: 12.0},
295+
{expr: "func3(2,3,4)", output: 24.0},
296+
{expr: "func0(42)", err: "expects 0 parameter"},
297+
{expr: "func1()", err: "expects 1 parameter"},
298+
{expr: "func1(1,2)", err: "expects 1 parameter"},
299+
{expr: "func2()", err: "expects 2 parameters"},
300+
{expr: "func2(1)", err: "expects 2 parameters"},
301+
{expr: "func2(1,2,3)", err: "expects 2 parameters"},
302+
{expr: "func3()", err: "expects 3 parameters"},
303+
{expr: "func3(1)", err: "expects 3 parameters"},
304+
{expr: "func3(1,2)", err: "expects 3 parameters"},
305+
{expr: "func3(1,2,3,4)", err: "expects 3 parameters"},
306+
{expr: "func1(\"foo\")", err: "Invalid type for"},
307+
{expr: "func2(\"foo\",\"bar\")", err: "Invalid type for"},
308+
{expr: "func3(\"foo\",\"qux\",\"quz\")", err: "Invalid type for"},
309+
}
310+
311+
for _, tc := range cases {
312+
t.Run(tc.expr, func(t *testing.T) {
313+
314+
ast, err := Parse(tc.expr, nil)
315+
316+
if ast != nil {
317+
t.Log("graph G {\n" + ast.Dot("") + "\n}")
318+
}
319+
320+
if tc.err != "" {
321+
if err != nil {
322+
if strings.Contains(err.Error(), tc.err) {
323+
return
324+
}
325+
t.Fatal(err.Pretty(tc.expr))
326+
}
327+
} else {
328+
if err != nil {
329+
t.Fatal(err.Pretty(tc.expr))
330+
}
331+
}
332+
333+
result, err := Run(ast, varMap, StrictMode)
334+
if tc.err != "" {
335+
if err == nil {
336+
t.Fatal("expected error but found none")
337+
}
338+
if strings.Contains(err.Error(), tc.err) {
339+
return
340+
}
341+
t.Fatal(err.Pretty(tc.expr))
342+
} else {
343+
if err != nil {
344+
t.Fatal(err.Pretty(tc.expr))
345+
}
346+
if !reflect.DeepEqual(tc.output, result) {
347+
t.Fatalf("expected %v but found %v", tc.output, result)
348+
}
349+
}
350+
})
351+
}
352+
}
353+
235354
func FuzzMexpr(f *testing.F) {
236355
f.Fuzz(func(t *testing.T, s string) {
237356
Eval(s, nil)

lexer.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ const (
3131
TokenStringCompare
3232
TokenWhere
3333
TokenEOF
34+
TokenComma // New token type for separating function parameters
3435
)
3536

3637
func (t TokenType) String() string {
@@ -73,6 +74,8 @@ func (t TokenType) String() string {
7374
return "where"
7475
case TokenEOF:
7576
return "eof"
77+
case TokenComma:
78+
return "comma"
7679
}
7780
return "unknown"
7881
}
@@ -97,6 +100,8 @@ func basic(input rune) TokenType {
97100
return TokenMulDiv
98101
case '^':
99102
return TokenPower
103+
case ',':
104+
return TokenComma
100105
}
101106

102107
return TokenUnknown

0 commit comments

Comments
 (0)