Skip to content

Commit af8ecca

Browse files
authored
Merge pull request #250 from namannimmo10/overload
Overload all existing built-in functions
2 parents 314443f + 149f6cf commit af8ecca

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+245
-100
lines changed

integration_tests/test_builtin_abs.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ltypes import f64
1+
from ltypes import f32, f64, i32
22

33
def test_abs():
44
x: f64
@@ -9,5 +9,19 @@ def test_abs():
99
assert abs(5.5) == 5.5
1010
assert abs(-5.5) == 5.5
1111

12+
x2: f32
13+
x2 = -5.5
14+
assert abs(x2) == 5.5
15+
16+
i: i32
17+
i = -5
18+
assert abs(i) == 5
19+
assert abs(-1) == 1
20+
21+
b: bool
22+
b = True
23+
assert abs(b) == 1
24+
b = False
25+
assert abs(b) == 0
1226

1327
test_abs()

integration_tests/test_builtin_pow.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,39 @@ def test_pow():
1717
assert pow(a, b) == 1
1818
a = 2
1919
b = -1
20-
assert abs(pow(2, -1) - 0.5) < eps
21-
# assert abs(pow(a, b) - 0.5) < eps
20+
print(pow(a, b))
2221
a = 6
2322
b = -4
23+
print(pow(a, b))
24+
25+
a1: f64
26+
a2: f64
27+
a1 = 4.5
28+
a2 = 2.3
29+
assert abs(pow(a1, a2) - 31.7971929089206) < eps
30+
assert abs(pow(a2, a1) - 42.43998894277659) < eps
31+
32+
x: i32
33+
x = 3
34+
y: f64
35+
y = 2.3
36+
assert abs(pow(x, y) - 12.513502532843182) < eps
37+
assert abs(pow(y, x) - 12.166999999999998) < eps
38+
assert abs(pow(x, 5.5) - 420.8883462392372) < eps
39+
40+
assert abs(pow(2, -1) - 0.5) < eps
2441
assert abs(pow(6, -4) - 0.0007716049382716049) < eps
25-
# assert abs(pow(a, b) - 0.0007716049382716049) < eps
2642
assert abs(pow(-3, -5) + 0.00411522633744856) < eps
2743
assert abs(pow(6, -4) - 0.0007716049382716049) < eps
44+
assert abs(pow(4.5, 2.3) - 31.7971929089206) < eps
45+
assert abs(pow(2.3, 0.0) - 1.0) < eps
46+
assert abs(pow(2.3, -1.5) - 0.2866871623459944) < eps
47+
assert abs(pow(2, 3.4) - 10.556063286183154) < eps
48+
assert abs(pow(2, -3.4) - 0.09473228540689989) < eps
49+
assert abs(pow(3.4, 9) - 60716.99276646398) < eps
50+
assert abs(pow(0.0, 53) - 0.0) < eps
51+
assert pow(4, 2) == 16
52+
assert abs(pow(-4235.0, 52) - 3.948003805985264e+188) < eps
2853

2954

3055
test_pow()

integration_tests/test_builtin_round.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ltypes import f64
1+
from ltypes import i32, f32, f64
22

33
def test_round():
44
f: f64
@@ -22,5 +22,16 @@ def test_round():
2222
assert round(50.5) == 50
2323
assert round(56.78) == 57
2424

25+
i: i32
26+
i = -5
27+
assert round(i) == -5
28+
assert round(4) == 4
29+
30+
b: bool
31+
b = True
32+
assert round(b) == 1
33+
b = False
34+
assert round(b) == 0
35+
assert round(False) == 0
2536

2637
test_round()

src/lpython/semantics/python_comptime_eval.h

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,26 +97,32 @@ struct PythonIntrinsicProcedures {
9797
) {
9898
LFORTRAN_ASSERT(ASRUtils::all_args_evaluated(args));
9999
if (args.size() != 1) {
100-
throw SemanticError("Intrinsic abs function accepts exactly 1 argument", loc);
100+
throw SemanticError("abs() takes exactly one argument (" +
101+
std::to_string(args.size()) + " given)", loc);
101102
}
102-
ASR::expr_t* trig_arg = args[0];
103+
ASR::expr_t* arg = args[0];
103104
ASR::ttype_t* t = ASRUtils::expr_type(args[0]);
104-
if (ASR::is_a<ASR::Real_t>(*t)) {
105-
double rv = ASR::down_cast<ASR::ConstantReal_t>(trig_arg)->m_r;
105+
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
106+
ASR::ttype_t *real_type = ASRUtils::TYPE(ASR::make_Real_t(al, loc, 8, nullptr, 0));
107+
if (ASRUtils::is_real(*t)) {
108+
double rv = ASR::down_cast<ASR::ConstantReal_t>(arg)->m_r;
106109
double val = std::abs(rv);
107-
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, val, t));
108-
} else if (ASR::is_a<ASR::Integer_t>(*t)) {
109-
int64_t rv = ASR::down_cast<ASR::ConstantInteger_t>(trig_arg)->m_n;
110+
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, val, real_type));
111+
} else if (ASRUtils::is_integer(*t)) {
112+
int64_t rv = ASR::down_cast<ASR::ConstantInteger_t>(arg)->m_n;
110113
int64_t val = std::abs(rv);
111-
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantInteger_t(al, loc, val, t));
112-
} else if (ASR::is_a<ASR::Complex_t>(*t)) {
113-
double re = ASR::down_cast<ASR::ConstantComplex_t>(trig_arg)->m_re;
114-
double im = ASR::down_cast<ASR::ConstantComplex_t>(trig_arg)->m_im;
114+
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantInteger_t(al, loc, val, int_type));
115+
} else if (ASRUtils::is_logical(*t)) {
116+
int8_t val = ASR::down_cast<ASR::ConstantLogical_t>(arg)->m_value;
117+
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantInteger_t(al, loc, val, int_type));
118+
} else if (ASRUtils::is_complex(*t)) {
119+
double re = ASR::down_cast<ASR::ConstantComplex_t>(arg)->m_re;
120+
double im = ASR::down_cast<ASR::ConstantComplex_t>(arg)->m_im;
115121
std::complex<double> x(re, im);
116122
double result = std::abs(x);
117-
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, result, t));
123+
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, result, real_type));
118124
} else {
119-
throw SemanticError("Argument of the abs function must be Integer, Real or Complex", loc);
125+
throw SemanticError("Argument of the abs function must be Integer, Real, Logical or Complex", loc);
120126
}
121127
}
122128

@@ -249,15 +255,12 @@ struct PythonIntrinsicProcedures {
249255

250256
static ASR::expr_t *eval_pow(Allocator &al, const Location &loc, Vec<ASR::expr_t*> &args) {
251257
LFORTRAN_ASSERT(ASRUtils::all_args_evaluated(args));
252-
ASR::expr_t* arg1 = ASRUtils::expr_value(args[0]);
253-
ASR::expr_t* arg2 = ASRUtils::expr_value(args[1]);
258+
ASR::expr_t* arg1 = args[0];
259+
ASR::expr_t* arg2 = args[1];
254260
ASR::ttype_t* arg1_type = ASRUtils::expr_type(arg1);
255261
ASR::ttype_t* arg2_type = ASRUtils::expr_type(arg2);
256262
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
257263
ASR::ttype_t *real_type = ASRUtils::TYPE(ASR::make_Real_t(al, loc, 8, nullptr, 0));
258-
if (!ASRUtils::check_equal_type(arg1_type, arg2_type)) {
259-
throw SemanticError("The arguments to pow() must have the same type.", loc);
260-
}
261264
if (ASRUtils::is_integer(*arg1_type) && ASRUtils::is_integer(*arg2_type)) {
262265
int64_t a = ASR::down_cast<ASR::ConstantInteger_t>(arg1)->m_n;
263266
int64_t b = ASR::down_cast<ASR::ConstantInteger_t>(arg2)->m_n;
@@ -271,8 +274,35 @@ struct PythonIntrinsicProcedures {
271274
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc,
272275
(int64_t)pow(a, b), int_type));
273276

277+
} else if (ASRUtils::is_real(*arg1_type) && ASRUtils::is_real(*arg2_type)) {
278+
double a = ASR::down_cast<ASR::ConstantReal_t>(arg1)->m_r;
279+
double b = ASR::down_cast<ASR::ConstantReal_t>(arg2)->m_r;
280+
if (a == 0.0 && b < 0.0) { // Zero Division
281+
throw SemanticError("0.0 cannot be raised to a negative power.", loc);
282+
}
283+
return ASR::down_cast<ASR::expr_t>(make_ConstantReal_t(al, loc,
284+
pow(a, b), real_type));
285+
286+
} else if (ASRUtils::is_integer(*arg1_type) && ASRUtils::is_real(*arg2_type)) {
287+
int64_t a = ASR::down_cast<ASR::ConstantInteger_t>(arg1)->m_n;
288+
double b = ASR::down_cast<ASR::ConstantReal_t>(arg2)->m_r;
289+
if (a == 0 && b < 0.0) { // Zero Division
290+
throw SemanticError("0.0 cannot be raised to a negative power.", loc);
291+
}
292+
return ASR::down_cast<ASR::expr_t>(make_ConstantReal_t(al, loc,
293+
pow(a, b), real_type));
294+
295+
} else if (ASRUtils::is_real(*arg1_type) && ASRUtils::is_integer(*arg2_type)) {
296+
double a = ASR::down_cast<ASR::ConstantReal_t>(arg1)->m_r;
297+
int64_t b = ASR::down_cast<ASR::ConstantInteger_t>(arg2)->m_n;
298+
if (a == 0.0 && b < 0) { // Zero Division
299+
throw SemanticError("0.0 cannot be raised to a negative power.", loc);
300+
}
301+
return ASR::down_cast<ASR::expr_t>(make_ConstantReal_t(al, loc,
302+
pow(a, b), real_type));
303+
274304
} else {
275-
throw SemanticError("The arguments to pow() must be of type integers for now.", loc);
305+
throw SemanticError("The two arguments to pow() must be of type integer or float.", loc);
276306
}
277307
}
278308

@@ -423,8 +453,14 @@ struct PythonIntrinsicProcedures {
423453
if (fabs(rv-rounded) == 0.5)
424454
rounded = 2.0*round(rv/2.0);
425455
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc, rounded, type));
456+
} else if (ASRUtils::is_integer(*t)) {
457+
int64_t rv = ASR::down_cast<ASR::ConstantInteger_t>(expr)->m_n;
458+
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc, rv, type));
459+
} else if (ASRUtils::is_logical(*t)) {
460+
int64_t rv = ASR::down_cast<ASR::ConstantLogical_t>(expr)->m_value;
461+
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc, rv, type));
426462
} else {
427-
throw SemanticError("round() argument must be float for now, not '" +
463+
throw SemanticError("round() argument must be float, integer, or logical for now, not '" +
428464
ASRUtils::type_to_str(t) + "'", loc);
429465
}
430466
}

src/runtime/lpython_builtin.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ def chr(i: i32) -> str:
2828
# exit(1)
2929

3030

31-
# This is an implementation for f64.
32-
# TODO: implement abs() as a generic procedure, and implement for all types
31+
#: abs() as a generic procedure.
32+
#: supported types for argument:
33+
#: i32, f32, f64, bool, c32, c64
34+
@overload
3335
def abs(x: f64) -> f64:
3436
"""
3537
Return the absolute value of `x`.
@@ -39,6 +41,35 @@ def abs(x: f64) -> f64:
3941
else:
4042
return -x
4143

44+
@overload
45+
def abs(x: f32) -> f32:
46+
if x >= 0.0:
47+
return x
48+
else:
49+
return -x
50+
51+
@overload
52+
def abs(x: i32) -> i64:
53+
if x >= 0:
54+
return x
55+
else:
56+
return -x
57+
58+
@overload
59+
def abs(b: bool) -> i32:
60+
if b:
61+
return 1
62+
else:
63+
return 0
64+
65+
@overload
66+
def abs(c: c32) -> f32:
67+
pass
68+
69+
@overload
70+
def abs(c: c64) -> f64:
71+
pass
72+
4273

4374
def str(x: i32) -> str:
4475
"""
@@ -115,12 +146,30 @@ def len(s: str) -> i32:
115146
"""
116147
pass
117148

149+
#: pow() as a generic procedure.
150+
#: supported types for arguments:
151+
#: (i32, i32), (f64, f64), (i32, f64), (f64, i32)
152+
@overload
153+
def pow(x: i32, y: i32) -> i32:
154+
"""
155+
Returns x**y.
156+
"""
157+
return x**y
118158

119-
def pow(x: i32, y: i32) -> f64:
159+
@overload
160+
def pow(x: f64, y: f64) -> f64:
120161
"""
121162
Returns x**y.
122163
"""
123-
return 1.0*x**y
164+
return x**y
165+
166+
@overload
167+
def pow(x: i32, y: f64) -> f64:
168+
return x**y
169+
170+
@overload
171+
def pow(x: f64, y: i32) -> f64:
172+
return x**y
124173

125174

126175
def int(f: f64) -> i32:
@@ -206,7 +255,10 @@ def oct(n: i32) -> str:
206255
res += _values[remainder]
207256
return prep + res[::-1]
208257

209-
258+
#: round() as a generic procedure.
259+
#: supported types for argument:
260+
#: i32, f64, bool
261+
@overload
210262
def round(value: f64) -> i32:
211263
"""
212264
Rounds a floating point number to the nearest integer.
@@ -216,6 +268,14 @@ def round(value: f64) -> i32:
216268
else:
217269
return int(value) + 1
218270

271+
@overload
272+
def round(value: i32) -> i64:
273+
return value
274+
275+
@overload
276+
def round(b: bool) -> i32:
277+
return abs(b)
278+
219279
def complex(x: f64, y: f64) -> c64:
220280
pass
221281

tests/constants1.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ def test_ord_chr():
1919

2020

2121
def test_abs():
22-
# a: i32
23-
# a = abs(5)
24-
# a = abs(-500)
25-
# a = abs(False)
26-
# a = abs(True)
22+
a: i32
23+
a = abs(5)
24+
a = abs(-500)
25+
a = abs(False)
26+
a = abs(True)
2727
b: f32
2828
b = abs(3.45)
2929
b = abs(-5346.34)
30-
# b = abs(complex(3.45, 5.6))
30+
b = abs(complex(3.45, 5.6))
3131

3232

3333
def test_len():

tests/expr7.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
def test_pow():
2-
a: f64
2+
a: i32
33
a = pow(2, 2)
44

5-
6-
def test_pow_1(a: i32, b: i32) -> f64:
7-
res: f64
5+
def test_pow_1(a: i32, b: i32) -> i32:
6+
res: i32
87
res = pow(a, b)
98
return res
109

1110
def main0():
1211
test_pow()
13-
c: f64
12+
c: i32
1413
c = test_pow_1(1, 2)
1514

1615
main0()

tests/reference/asr-complex1-f26c460.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-complex1-f26c460.stdout",
9-
"stdout_hash": "ad6d981d374283dd344f860604121d3777ad15fd84cf52982d400ef4",
9+
"stdout_hash": "219fe9f1acbedd5d18e1ed0d17c82f9a616e4b2e2b7ec8618fe43b1b",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)