Skip to content

Implement Overload decorator in pure python #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion integration_tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
"test_math1.py"
]

# At present we run these tests on cpython, later we should also move to lpython
test_cpython = [
"test_generics_01.py"
]


def main():
print("Compiling...")
for pyfile in tests:
Expand All @@ -30,7 +36,9 @@ def main():
if r != 0:
print("Command '%s' failed." % cmd)
sys.exit(1)

print("Running...")
python_path="src/runtime/ltypes"
for pyfile in tests:
basename = os.path.splitext(pyfile)[0]
cmd = "integration_tests/%s" % (basename)
Expand All @@ -39,7 +47,16 @@ def main():
if r != 0:
print("Command '%s' failed." % cmd)
sys.exit(1)
python_path="src/runtime/ltypes"
cmd = "PYTHONPATH=%s python integration_tests/%s" % (python_path,
pyfile)
print("+ " + cmd)
r = os.system(cmd)
if r != 0:
print("Command '%s' failed." % cmd)
sys.exit(1)

print("Running cpython tests...")
for pyfile in test_cpython:
cmd = "PYTHONPATH=%s python integration_tests/%s" % (python_path,
pyfile)
print("+ " + cmd)
Expand Down
30 changes: 30 additions & 0 deletions integration_tests/test_generics_01.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from ltypes import overload

@overload
def foo(a: int, b: int) -> int:
return a*b

@overload
def foo(a: int) -> int:
return a**2

@overload
def foo(a: str) -> str:
return "lpython-" + a

@overload
def test(a: int) -> int:
return a + 10

@overload
def test(a: bool) -> int:
if a:
return 10
return -10


assert foo(2) == 4
assert foo(2, 10) == 20
assert foo("hello") == "lpython-hello"
assert test(10) == 20
assert test(False) == -test(True) and test(True) == 10
47 changes: 47 additions & 0 deletions src/runtime/ltypes/ltypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,53 @@
from inspect import getfullargspec, getcallargs

# data-types

i32 = []
i64 = []
f32 = []
f64 = []
c32 = []
c64 = []

# overloading support

class OverloadedFunction:
"""
A wrapper class for allowing overloading.
"""
global_map = {}

def __init__(self, func):
self.func_name = func.__name__
f_list = self.global_map.get(func.__name__, [])
f_list.append((func, getfullargspec(func)))
self.global_map[func.__name__] = f_list

def __call__(self, *args, **kwargs):
func_map_list = self.global_map.get(self.func_name, False)
if not func_map_list:
raise Exception("Function not defined")
for item in func_map_list:
func, key = item
try:
# This might fail for the cases when arguments don't match
ann_dict = getcallargs(func, *args, **kwargs)
except TypeError:
continue
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the arguments don't match, shouldn't we rather raise an exception that the arguments don't match?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suppose, the list contains two functions foo(a) and foo(a, b), then, foo(1, 2) will raise an error if we use getcallargs on the first function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. We should determine which of the overloads contains two integer arguments in this case. And just call that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right!

flag = True
for k, v in ann_dict.items():
if not key.annotations.get(k, False):
flag = False
break
else:
if type(v) != key.annotations.get(k):
flag = False
break
if flag:
return func(*args, **kwargs)
raise Exception("Function not found with matching signature")


def overload(f):
overloaded_f = OverloadedFunction(f)
return overloaded_f