Skip to content

Commit fd5b4ca

Browse files
committed
This commit introduces support for calling functions within expressions.
The key changes are: - Interpreter: - Added a new `NodeFunctionCall` AST node to represent a function call. - The interpreter can now execute functions passed in the variables map. It supports functions with 0 to 3 parameters of type `any`. - Implemented short-circuiting for `AND` and `OR` logical operators. - Parser: - The parser now recognizes function call syntax (`identifier(...)`). - It can parse comma-separated arguments in function calls. - Lexer: - Added a `TokenComma` to tokenize function argument lists. - Conversions: - Type conversion functions now handle `func()` types, allowing for lazy evaluation of values. - Testing: - Added a comprehensive test suite for the new function call functionality.
1 parent cc7bbaa commit fd5b4ca

File tree

6 files changed

+330
-2
lines changed

6 files changed

+330
-2
lines changed

README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Features:
1111
- Fast, low-allocation parser and runtime
1212
- Many simple expressions are zero-allocation
1313
- Type checking during parsing
14+
- Type conversion for `func()`
1415
- Simple
1516
- Easy to learn
1617
- Easy to read
@@ -137,6 +138,8 @@ Math operations between constants are precomputed when possible, so it is effici
137138
- `and`
138139
- `or`
139140

141+
Both `and` and `or` are short-circuited.
142+
140143
```py
141144
1 < 2 and 3 < 4
142145
```
@@ -148,6 +151,20 @@ Non-boolean values are converted to booleans. The following result in `true`:
148151
- array with at least one item
149152
- map with at least one key/value pair
150153

154+
### Functions
155+
156+
- `identifier(...)`
157+
158+
Functions can be called by providing them in the variables map.
159+
160+
```go
161+
result, err := mexpr.Eval("myFunc(a, b)", map[string]interface{}{
162+
"myFunc": func(a, b int) int { return a + b },
163+
"a": 1,
164+
"b": 2,
165+
})
166+
```
167+
151168
### String operators
152169

153170
- Indexing, e.g. `foo[0]`
@@ -221,6 +238,21 @@ not (items where id > 3)
221238
- `in` (has key), e.g. `"key" in foo`
222239
- `contains` e.g. `foo contains "key"`
223240

241+
### Conversions
242+
243+
Any value concatenated with a string will result in a string. For example `"id" + 1` will result in `"id1"`.
244+
245+
The value of a variable can be mapped to a function. This allows the implementor to use functions to retrieve actual values of variables rather than pre-computing values:
246+
247+
```go
248+
result, _ := mexpr.Eval(`id + 1`, map[string]interface{}{
249+
"id": func() int { return 123 },
250+
})
251+
// result is 124
252+
```
253+
254+
In combination with short-circuiting with and/or it allows lazy evaluation.
255+
224256
#### Map wildcard filtering
225257

226258
A `where` clause can be used as a wildcard key to filter values for all keys in a map. The left side of the clause is the map to be filtered, while the right side is an expression to run on each value of the map. If the right side expression evaluates to true then the value is added to the result slice. For example, given:

conversions.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ func isNumber(v interface{}) bool {
1212
return true
1313
case float32, float64:
1414
return true
15+
case func() int:
16+
return true
17+
case func() float64:
18+
return true
1519
}
1620
return false
1721
}
@@ -42,6 +46,10 @@ func toNumber(ast *Node, v interface{}) (float64, Error) {
4246
return float64(n), nil
4347
case float32:
4448
return float64(n), nil
49+
case func() int:
50+
return float64(n()), nil
51+
case func() float64:
52+
return n(), nil
4553
}
4654
return 0, NewError(ast.Offset, ast.Length, "unable to convert to number: %v", v)
4755
}
@@ -64,6 +72,8 @@ func toString(v interface{}) string {
6472
return string(s)
6573
case []byte:
6674
return string(s)
75+
case func() string:
76+
return s()
6777
}
6878
return fmt.Sprintf("%v", v)
6979
}
@@ -162,6 +172,10 @@ func normalize(v interface{}) interface{} {
162172
return float64(n)
163173
case []byte:
164174
return string(n)
175+
case func() int:
176+
return float64(n())
177+
case func() float64:
178+
return n()
165179
}
166180

167181
return v

interpreter.go

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,44 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) {
9696
return value, nil
9797
case "length":
9898
// Special pseudo-property to get the value's length.
99+
if s, ok := value.(func() string); ok {
100+
return len(s()), nil
101+
}
99102
if s, ok := value.(string); ok {
100103
return len(s), nil
101104
}
102105
if a, ok := value.([]any); ok {
103106
return len(a), nil
104107
}
105108
case "lower":
109+
if s, ok := value.(func() string); ok {
110+
return strings.ToLower(s()), nil
111+
}
106112
if s, ok := value.(string); ok {
107113
return strings.ToLower(s), nil
108114
}
109115
case "upper":
116+
if s, ok := value.(func() string); ok {
117+
return strings.ToUpper(s()), nil
118+
}
110119
if s, ok := value.(string); ok {
111120
return strings.ToUpper(s), nil
112121
}
113122
}
114123
if m, ok := value.(map[string]any); ok {
115124
if v, ok := m[ast.Value.(string)]; ok {
116-
return v, nil
125+
switch n := v.(type) {
126+
case func() int:
127+
return n(), nil
128+
case func() float64:
129+
return n(), nil
130+
case func() bool:
131+
return n(), nil
132+
case func() string:
133+
return n(), nil
134+
default:
135+
return v, nil
136+
}
117137
}
118138
}
119139
if m, ok := value.(map[any]any); ok {
@@ -335,11 +355,21 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) {
335355
if err != nil {
336356
return nil, err
337357
}
358+
left := toBool(resultLeft)
359+
switch ast.Type {
360+
case NodeAnd:
361+
if !left {
362+
return left, nil
363+
}
364+
case NodeOr:
365+
if left {
366+
return left, nil
367+
}
368+
}
338369
resultRight, err := i.run(ast.Right, value)
339370
if err != nil {
340371
return nil, err
341372
}
342-
left := toBool(resultLeft)
343373
right := toBool(resultRight)
344374
switch ast.Type {
345375
case NodeAnd:
@@ -470,6 +500,75 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) {
470500
}
471501
}
472502
return results, nil
503+
case NodeFunctionCall:
504+
funcName := ast.Left.Value.(string)
505+
if m, ok := value.(map[string]any); ok {
506+
if fn, ok := m[funcName]; ok {
507+
// Get function parameters
508+
params := []any{}
509+
for _, param := range ast.Value.([]Node) {
510+
paramValue, err := i.run(&param, value)
511+
if err != nil {
512+
return nil, err
513+
}
514+
params = append(params, paramValue)
515+
}
516+
517+
// Execute function based on parameter count
518+
switch f := fn.(type) {
519+
case func() any:
520+
if len(params) != 0 {
521+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 0 parameter, got %d", funcName, len(params))
522+
}
523+
result := f()
524+
switch result.(type) {
525+
case error:
526+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
527+
default:
528+
return result, nil
529+
}
530+
case func(any) any:
531+
if len(params) != 1 {
532+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 1 parameter, got %d", funcName, len(params))
533+
}
534+
result := f(params[0])
535+
switch result.(type) {
536+
case error:
537+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
538+
default:
539+
return result, nil
540+
}
541+
case func(any, any) any:
542+
if len(params) != 2 {
543+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 2 parameters, got %d", funcName, len(params))
544+
}
545+
result := f(params[0], params[1])
546+
switch result.(type) {
547+
case error:
548+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
549+
default:
550+
return result, nil
551+
}
552+
case func(any, any, any) any:
553+
if len(params) != 3 {
554+
return nil, NewError(ast.Offset, ast.Length, "function %s expects 3 parameters, got %d", funcName, len(params))
555+
}
556+
result := f(params[0], params[1], params[2])
557+
switch result.(type) {
558+
case error:
559+
return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error))
560+
default:
561+
return result, nil
562+
}
563+
564+
}
565+
return nil, NewError(ast.Offset, ast.Length, "unsupported function type for %s", funcName)
566+
}
567+
}
568+
if i.strict {
569+
return nil, NewError(ast.Offset, ast.Length, "function %s not found", funcName)
570+
}
571+
return nil, nil
473572
}
474573
return nil, nil
475574
}

interpreter_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mexpr
22

33
import (
44
"encoding/json"
5+
"fmt"
56
"reflect"
67
"strings"
78
"testing"
@@ -232,6 +233,124 @@ func TestInterpreter(t *testing.T) {
232233
}
233234
}
234235

236+
func TestFunctions(t *testing.T) {
237+
238+
varMap := make(map[string]interface{})
239+
240+
varMap["func0"] = func() any {
241+
return 43.0
242+
}
243+
244+
varMap["func1"] = func(param1 any) any {
245+
switch param1.(type) {
246+
case float64:
247+
return param1.(float64) * 2
248+
default:
249+
return fmt.Errorf("Invalid type for param1")
250+
}
251+
}
252+
253+
varMap["func2"] = func(param1 any, param2 any) any {
254+
switch param1.(type) {
255+
case float64:
256+
switch param2.(type) {
257+
case float64:
258+
return param1.(float64) * param2.(float64)
259+
default:
260+
return fmt.Errorf("Invalid type for param2")
261+
}
262+
default:
263+
return fmt.Errorf("Invalid type for param1")
264+
}
265+
}
266+
267+
varMap["func3"] = func(param1 any, param2 any, param3 any) any {
268+
switch param1.(type) {
269+
case float64:
270+
switch param2.(type) {
271+
case float64:
272+
switch param3.(type) {
273+
case float64:
274+
return param1.(float64) * param2.(float64) * param3.(float64)
275+
default:
276+
return fmt.Errorf("Invalid type for param3")
277+
}
278+
default:
279+
return fmt.Errorf("Invalid type for param2")
280+
}
281+
default:
282+
return fmt.Errorf("Invalid type for param1")
283+
}
284+
}
285+
286+
type test struct {
287+
expr string
288+
output interface{}
289+
err string
290+
}
291+
cases := []test{
292+
{expr: "func0()", output: 43.0},
293+
{expr: "func1(42)", output: 84.0},
294+
{expr: "func2(3,4)", output: 12.0},
295+
{expr: "func3(2,3,4)", output: 24.0},
296+
{expr: "func0(42)", err: "expects 0 parameter"},
297+
{expr: "func1()", err: "expects 1 parameter"},
298+
{expr: "func1(1,2)", err: "expects 1 parameter"},
299+
{expr: "func2()", err: "expects 2 parameters"},
300+
{expr: "func2(1)", err: "expects 2 parameters"},
301+
{expr: "func2(1,2,3)", err: "expects 2 parameters"},
302+
{expr: "func3()", err: "expects 3 parameters"},
303+
{expr: "func3(1)", err: "expects 3 parameters"},
304+
{expr: "func3(1,2)", err: "expects 3 parameters"},
305+
{expr: "func3(1,2,3,4)", err: "expects 3 parameters"},
306+
{expr: "func1(\"foo\")", err: "Invalid type for"},
307+
{expr: "func2(\"foo\",\"bar\")", err: "Invalid type for"},
308+
{expr: "func3(\"foo\",\"qux\",\"quz\")", err: "Invalid type for"},
309+
}
310+
311+
for _, tc := range cases {
312+
t.Run(tc.expr, func(t *testing.T) {
313+
314+
ast, err := Parse(tc.expr, nil)
315+
316+
if ast != nil {
317+
t.Log("graph G {\n" + ast.Dot("") + "\n}")
318+
}
319+
320+
if tc.err != "" {
321+
if err != nil {
322+
if strings.Contains(err.Error(), tc.err) {
323+
return
324+
}
325+
t.Fatal(err.Pretty(tc.expr))
326+
}
327+
} else {
328+
if err != nil {
329+
t.Fatal(err.Pretty(tc.expr))
330+
}
331+
}
332+
333+
result, err := Run(ast, varMap, StrictMode)
334+
if tc.err != "" {
335+
if err == nil {
336+
t.Fatal("expected error but found none")
337+
}
338+
if strings.Contains(err.Error(), tc.err) {
339+
return
340+
}
341+
t.Fatal(err.Pretty(tc.expr))
342+
} else {
343+
if err != nil {
344+
t.Fatal(err.Pretty(tc.expr))
345+
}
346+
if !reflect.DeepEqual(tc.output, result) {
347+
t.Fatalf("expected %v but found %v", tc.output, result)
348+
}
349+
}
350+
})
351+
}
352+
}
353+
235354
func FuzzMexpr(f *testing.F) {
236355
f.Fuzz(func(t *testing.T, s string) {
237356
Eval(s, nil)

0 commit comments

Comments
 (0)