Skip to content

Commit bfb1bfd

Browse files
authored
Merge pull request #166 from Smit-create/overload1
Implement Overload decorator in pure python
2 parents 27ee768 + d1b8bec commit bfb1bfd

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

integration_tests/run_tests.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
"test_math1.py"
1919
]
2020

21+
# At present we run these tests on cpython, later we should also move to lpython
22+
test_cpython = [
23+
"test_generics_01.py"
24+
]
25+
26+
2127
def main():
2228
print("Compiling...")
2329
for pyfile in tests:
@@ -30,7 +36,9 @@ def main():
3036
if r != 0:
3137
print("Command '%s' failed." % cmd)
3238
sys.exit(1)
39+
3340
print("Running...")
41+
python_path="src/runtime/ltypes"
3442
for pyfile in tests:
3543
basename = os.path.splitext(pyfile)[0]
3644
cmd = "integration_tests/%s" % (basename)
@@ -39,7 +47,16 @@ def main():
3947
if r != 0:
4048
print("Command '%s' failed." % cmd)
4149
sys.exit(1)
42-
python_path="src/runtime/ltypes"
50+
cmd = "PYTHONPATH=%s python integration_tests/%s" % (python_path,
51+
pyfile)
52+
print("+ " + cmd)
53+
r = os.system(cmd)
54+
if r != 0:
55+
print("Command '%s' failed." % cmd)
56+
sys.exit(1)
57+
58+
print("Running cpython tests...")
59+
for pyfile in test_cpython:
4360
cmd = "PYTHONPATH=%s python integration_tests/%s" % (python_path,
4461
pyfile)
4562
print("+ " + cmd)

integration_tests/test_generics_01.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from ltypes import overload
2+
3+
@overload
4+
def foo(a: int, b: int) -> int:
5+
return a*b
6+
7+
@overload
8+
def foo(a: int) -> int:
9+
return a**2
10+
11+
@overload
12+
def foo(a: str) -> str:
13+
return "lpython-" + a
14+
15+
@overload
16+
def test(a: int) -> int:
17+
return a + 10
18+
19+
@overload
20+
def test(a: bool) -> int:
21+
if a:
22+
return 10
23+
return -10
24+
25+
26+
assert foo(2) == 4
27+
assert foo(2, 10) == 20
28+
assert foo("hello") == "lpython-hello"
29+
assert test(10) == 20
30+
assert test(False) == -test(True) and test(True) == 10

src/runtime/ltypes/ltypes.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,53 @@
1+
from inspect import getfullargspec, getcallargs
2+
3+
# data-types
4+
15
i32 = []
26
i64 = []
37
f32 = []
48
f64 = []
59
c32 = []
610
c64 = []
11+
12+
# overloading support
13+
14+
class OverloadedFunction:
15+
"""
16+
A wrapper class for allowing overloading.
17+
"""
18+
global_map = {}
19+
20+
def __init__(self, func):
21+
self.func_name = func.__name__
22+
f_list = self.global_map.get(func.__name__, [])
23+
f_list.append((func, getfullargspec(func)))
24+
self.global_map[func.__name__] = f_list
25+
26+
def __call__(self, *args, **kwargs):
27+
func_map_list = self.global_map.get(self.func_name, False)
28+
if not func_map_list:
29+
raise Exception("Function not defined")
30+
for item in func_map_list:
31+
func, key = item
32+
try:
33+
# This might fail for the cases when arguments don't match
34+
ann_dict = getcallargs(func, *args, **kwargs)
35+
except TypeError:
36+
continue
37+
flag = True
38+
for k, v in ann_dict.items():
39+
if not key.annotations.get(k, False):
40+
flag = False
41+
break
42+
else:
43+
if type(v) != key.annotations.get(k):
44+
flag = False
45+
break
46+
if flag:
47+
return func(*args, **kwargs)
48+
raise Exception("Function not found with matching signature")
49+
50+
51+
def overload(f):
52+
overloaded_f = OverloadedFunction(f)
53+
return overloaded_f

0 commit comments

Comments
 (0)