Skip to content

Overload all existing built-in functions #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion integration_tests/test_builtin_abs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ltypes import f64
from ltypes import f32, f64, i32

def test_abs():
x: f64
Expand All @@ -9,5 +9,19 @@ def test_abs():
assert abs(5.5) == 5.5
assert abs(-5.5) == 5.5

x2: f32
x2 = -5.5
assert abs(x2) == 5.5

i: i32
i = -5
assert abs(i) == 5
assert abs(-1) == 1

b: bool
b = True
assert abs(b) == 1
b = False
assert abs(b) == 0

test_abs()
31 changes: 28 additions & 3 deletions integration_tests/test_builtin_pow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,39 @@ def test_pow():
assert pow(a, b) == 1
a = 2
b = -1
assert abs(pow(2, -1) - 0.5) < eps
# assert abs(pow(a, b) - 0.5) < eps
print(pow(a, b))
a = 6
b = -4
print(pow(a, b))

a1: f64
a2: f64
a1 = 4.5
a2 = 2.3
assert abs(pow(a1, a2) - 31.7971929089206) < eps
assert abs(pow(a2, a1) - 42.43998894277659) < eps

x: i32
x = 3
y: f64
y = 2.3
assert abs(pow(x, y) - 12.513502532843182) < eps
assert abs(pow(y, x) - 12.166999999999998) < eps
assert abs(pow(x, 5.5) - 420.8883462392372) < eps

assert abs(pow(2, -1) - 0.5) < eps
assert abs(pow(6, -4) - 0.0007716049382716049) < eps
# assert abs(pow(a, b) - 0.0007716049382716049) < eps
assert abs(pow(-3, -5) + 0.00411522633744856) < eps
assert abs(pow(6, -4) - 0.0007716049382716049) < eps
assert abs(pow(4.5, 2.3) - 31.7971929089206) < eps
assert abs(pow(2.3, 0.0) - 1.0) < eps
assert abs(pow(2.3, -1.5) - 0.2866871623459944) < eps
assert abs(pow(2, 3.4) - 10.556063286183154) < eps
assert abs(pow(2, -3.4) - 0.09473228540689989) < eps
assert abs(pow(3.4, 9) - 60716.99276646398) < eps
assert abs(pow(0.0, 53) - 0.0) < eps
assert pow(4, 2) == 16
assert abs(pow(-4235.0, 52) - 3.948003805985264e+188) < eps


test_pow()
13 changes: 12 additions & 1 deletion integration_tests/test_builtin_round.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ltypes import f64
from ltypes import i32, f32, f64

def test_round():
f: f64
Expand All @@ -22,5 +22,16 @@ def test_round():
assert round(50.5) == 50
assert round(56.78) == 57

i: i32
i = -5
assert round(i) == -5
assert round(4) == 4

b: bool
b = True
assert round(b) == 1
b = False
assert round(b) == 0
assert round(False) == 0

test_round()
76 changes: 56 additions & 20 deletions src/lpython/semantics/python_comptime_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,32 @@ struct PythonIntrinsicProcedures {
) {
LFORTRAN_ASSERT(ASRUtils::all_args_evaluated(args));
if (args.size() != 1) {
throw SemanticError("Intrinsic abs function accepts exactly 1 argument", loc);
throw SemanticError("abs() takes exactly one argument (" +
std::to_string(args.size()) + " given)", loc);
}
ASR::expr_t* trig_arg = args[0];
ASR::expr_t* arg = args[0];
ASR::ttype_t* t = ASRUtils::expr_type(args[0]);
if (ASR::is_a<ASR::Real_t>(*t)) {
double rv = ASR::down_cast<ASR::ConstantReal_t>(trig_arg)->m_r;
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
ASR::ttype_t *real_type = ASRUtils::TYPE(ASR::make_Real_t(al, loc, 8, nullptr, 0));
if (ASRUtils::is_real(*t)) {
double rv = ASR::down_cast<ASR::ConstantReal_t>(arg)->m_r;
double val = std::abs(rv);
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, val, t));
} else if (ASR::is_a<ASR::Integer_t>(*t)) {
int64_t rv = ASR::down_cast<ASR::ConstantInteger_t>(trig_arg)->m_n;
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, val, real_type));
} else if (ASRUtils::is_integer(*t)) {
int64_t rv = ASR::down_cast<ASR::ConstantInteger_t>(arg)->m_n;
int64_t val = std::abs(rv);
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantInteger_t(al, loc, val, t));
} else if (ASR::is_a<ASR::Complex_t>(*t)) {
double re = ASR::down_cast<ASR::ConstantComplex_t>(trig_arg)->m_re;
double im = ASR::down_cast<ASR::ConstantComplex_t>(trig_arg)->m_im;
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantInteger_t(al, loc, val, int_type));
} else if (ASRUtils::is_logical(*t)) {
int8_t val = ASR::down_cast<ASR::ConstantLogical_t>(arg)->m_value;
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantInteger_t(al, loc, val, int_type));
} else if (ASRUtils::is_complex(*t)) {
double re = ASR::down_cast<ASR::ConstantComplex_t>(arg)->m_re;
double im = ASR::down_cast<ASR::ConstantComplex_t>(arg)->m_im;
std::complex<double> x(re, im);
double result = std::abs(x);
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, result, t));
return ASR::down_cast<ASR::expr_t>(ASR::make_ConstantReal_t(al, loc, result, real_type));
} else {
throw SemanticError("Argument of the abs function must be Integer, Real or Complex", loc);
throw SemanticError("Argument of the abs function must be Integer, Real, Logical or Complex", loc);
}
}

Expand Down Expand Up @@ -249,15 +255,12 @@ struct PythonIntrinsicProcedures {

static ASR::expr_t *eval_pow(Allocator &al, const Location &loc, Vec<ASR::expr_t*> &args) {
LFORTRAN_ASSERT(ASRUtils::all_args_evaluated(args));
ASR::expr_t* arg1 = ASRUtils::expr_value(args[0]);
ASR::expr_t* arg2 = ASRUtils::expr_value(args[1]);
ASR::expr_t* arg1 = args[0];
ASR::expr_t* arg2 = args[1];
ASR::ttype_t* arg1_type = ASRUtils::expr_type(arg1);
ASR::ttype_t* arg2_type = ASRUtils::expr_type(arg2);
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
ASR::ttype_t *real_type = ASRUtils::TYPE(ASR::make_Real_t(al, loc, 8, nullptr, 0));
if (!ASRUtils::check_equal_type(arg1_type, arg2_type)) {
throw SemanticError("The arguments to pow() must have the same type.", loc);
}
if (ASRUtils::is_integer(*arg1_type) && ASRUtils::is_integer(*arg2_type)) {
int64_t a = ASR::down_cast<ASR::ConstantInteger_t>(arg1)->m_n;
int64_t b = ASR::down_cast<ASR::ConstantInteger_t>(arg2)->m_n;
Expand All @@ -271,8 +274,35 @@ struct PythonIntrinsicProcedures {
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc,
(int64_t)pow(a, b), int_type));

} else if (ASRUtils::is_real(*arg1_type) && ASRUtils::is_real(*arg2_type)) {
double a = ASR::down_cast<ASR::ConstantReal_t>(arg1)->m_r;
double b = ASR::down_cast<ASR::ConstantReal_t>(arg2)->m_r;
if (a == 0.0 && b < 0.0) { // Zero Division
throw SemanticError("0.0 cannot be raised to a negative power.", loc);
}
return ASR::down_cast<ASR::expr_t>(make_ConstantReal_t(al, loc,
pow(a, b), real_type));

} else if (ASRUtils::is_integer(*arg1_type) && ASRUtils::is_real(*arg2_type)) {
int64_t a = ASR::down_cast<ASR::ConstantInteger_t>(arg1)->m_n;
double b = ASR::down_cast<ASR::ConstantReal_t>(arg2)->m_r;
if (a == 0 && b < 0.0) { // Zero Division
throw SemanticError("0.0 cannot be raised to a negative power.", loc);
}
return ASR::down_cast<ASR::expr_t>(make_ConstantReal_t(al, loc,
pow(a, b), real_type));

} else if (ASRUtils::is_real(*arg1_type) && ASRUtils::is_integer(*arg2_type)) {
double a = ASR::down_cast<ASR::ConstantReal_t>(arg1)->m_r;
int64_t b = ASR::down_cast<ASR::ConstantInteger_t>(arg2)->m_n;
if (a == 0.0 && b < 0) { // Zero Division
throw SemanticError("0.0 cannot be raised to a negative power.", loc);
}
return ASR::down_cast<ASR::expr_t>(make_ConstantReal_t(al, loc,
pow(a, b), real_type));

} else {
throw SemanticError("The arguments to pow() must be of type integers for now.", loc);
throw SemanticError("The two arguments to pow() must be of type integer or float.", loc);
}
}

Expand Down Expand Up @@ -423,8 +453,14 @@ struct PythonIntrinsicProcedures {
if (fabs(rv-rounded) == 0.5)
rounded = 2.0*round(rv/2.0);
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc, rounded, type));
} else if (ASRUtils::is_integer(*t)) {
int64_t rv = ASR::down_cast<ASR::ConstantInteger_t>(expr)->m_n;
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc, rv, type));
} else if (ASRUtils::is_logical(*t)) {
int64_t rv = ASR::down_cast<ASR::ConstantLogical_t>(expr)->m_value;
return ASR::down_cast<ASR::expr_t>(make_ConstantInteger_t(al, loc, rv, type));
} else {
throw SemanticError("round() argument must be float for now, not '" +
throw SemanticError("round() argument must be float, integer, or logical for now, not '" +
ASRUtils::type_to_str(t) + "'", loc);
}
}
Expand Down
70 changes: 65 additions & 5 deletions src/runtime/lpython_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def chr(i: i32) -> str:
# exit(1)


# This is an implementation for f64.
# TODO: implement abs() as a generic procedure, and implement for all types
#: abs() as a generic procedure.
#: supported types for argument:
#: i32, f32, f64, bool, c32, c64
@overload
def abs(x: f64) -> f64:
"""
Return the absolute value of `x`.
Expand All @@ -39,6 +41,35 @@ def abs(x: f64) -> f64:
else:
return -x

@overload
def abs(x: f32) -> f32:
if x >= 0.0:
return x
else:
return -x

@overload
def abs(x: i32) -> i64:
if x >= 0:
return x
else:
return -x

@overload
def abs(b: bool) -> i32:
if b:
return 1
else:
return 0

@overload
def abs(c: c32) -> f32:
pass

@overload
def abs(c: c64) -> f64:
pass


def str(x: i32) -> str:
"""
Expand Down Expand Up @@ -115,12 +146,30 @@ def len(s: str) -> i32:
"""
pass

#: pow() as a generic procedure.
#: supported types for arguments:
#: (i32, i32), (f64, f64), (i32, f64), (f64, i32)
@overload
def pow(x: i32, y: i32) -> i32:
"""
Returns x**y.
"""
return x**y

def pow(x: i32, y: i32) -> f64:
@overload
def pow(x: f64, y: f64) -> f64:
"""
Returns x**y.
"""
return 1.0*x**y
return x**y

@overload
def pow(x: i32, y: f64) -> f64:
return x**y

@overload
def pow(x: f64, y: i32) -> f64:
return x**y


def int(f: f64) -> i32:
Expand Down Expand Up @@ -206,7 +255,10 @@ def oct(n: i32) -> str:
res += _values[remainder]
return prep + res[::-1]


#: round() as a generic procedure.
#: supported types for argument:
#: i32, f64, bool
@overload
def round(value: f64) -> i32:
"""
Rounds a floating point number to the nearest integer.
Expand All @@ -216,6 +268,14 @@ def round(value: f64) -> i32:
else:
return int(value) + 1

@overload
def round(value: i32) -> i64:
return value

@overload
def round(b: bool) -> i32:
return abs(b)

def complex(x: f64, y: f64) -> c64:
pass

Expand Down
12 changes: 6 additions & 6 deletions tests/constants1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ def test_ord_chr():


def test_abs():
# a: i32
# a = abs(5)
# a = abs(-500)
# a = abs(False)
# a = abs(True)
a: i32
a = abs(5)
a = abs(-500)
a = abs(False)
a = abs(True)
b: f32
b = abs(3.45)
b = abs(-5346.34)
# b = abs(complex(3.45, 5.6))
b = abs(complex(3.45, 5.6))


def test_len():
Expand Down
9 changes: 4 additions & 5 deletions tests/expr7.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
def test_pow():
a: f64
a: i32
a = pow(2, 2)


def test_pow_1(a: i32, b: i32) -> f64:
res: f64
def test_pow_1(a: i32, b: i32) -> i32:
res: i32
res = pow(a, b)
return res

def main0():
test_pow()
c: f64
c: i32
c = test_pow_1(1, 2)

main0()
2 changes: 1 addition & 1 deletion tests/reference/asr-complex1-f26c460.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-complex1-f26c460.stdout",
"stdout_hash": "ad6d981d374283dd344f860604121d3777ad15fd84cf52982d400ef4",
"stdout_hash": "219fe9f1acbedd5d18e1ed0d17c82f9a616e4b2e2b7ec8618fe43b1b",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
Loading