Skip to content

Commit bcf8b25

Browse files
authored
Merge pull request #876 from Madhav2310/pow_improvements
Modulo(3rd argument) added to pow
2 parents d766dcf + 2b2aa9a commit bcf8b25

File tree

52 files changed

+224
-53
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+224
-53
lines changed

integration_tests/test_builtin_pow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ def test_pow():
7373
assert pow(4, 2) == 16
7474
assert abs(pow(-4235.0, 52) - 3.948003805985264e+188) < eps
7575

76+
i: i64
77+
i = 7
78+
j: i64
79+
j = 2
80+
k: i64
81+
k = 5
82+
assert pow(i, j, k) == 4
83+
assert pow(102, 3, 121) == 38
84+
7685
c1: c32
7786
c1 = complex(4, 5)
7887
c1 = pow(c1, 4)

integration_tests/test_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_pow():
7272
a = 2
7373
b: i64
7474
b = 4
75-
assert abs(pow(a, b) - 16) < eps
75+
assert pow(a, b) == 16
7676

7777
def test_ldexp():
7878
i: f64

src/lpython/semantics/python_comptime_eval.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,16 @@ struct PythonIntrinsicProcedures {
207207
ASR::expr_t* arg2 = args[1];
208208
ASR::ttype_t* arg1_type = ASRUtils::expr_type(arg1);
209209
ASR::ttype_t* arg2_type = ASRUtils::expr_type(arg2);
210+
int64_t mod_by = -1;
211+
if (args.size() == 3) {
212+
ASR::expr_t* arg3 = args[2];
213+
ASR::ttype_t* arg3_type = ASRUtils::expr_type(arg3);
214+
if (!ASRUtils::is_integer(*arg3_type) ) { // Zero Division
215+
throw SemanticError("Third argument must be an integer. Found: " + \
216+
ASRUtils::type_to_str_python(arg3_type), loc);
217+
}
218+
mod_by = ASR::down_cast<ASR::IntegerConstant_t>(arg3)->m_n;
219+
}
210220
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0));
211221
ASR::ttype_t *real_type = ASRUtils::TYPE(ASR::make_Real_t(al, loc, 8, nullptr, 0));
212222
ASR::ttype_t *complex_type = ASRUtils::TYPE(ASR::make_Complex_t(al, loc, 8, nullptr, 0));
@@ -219,9 +229,16 @@ struct PythonIntrinsicProcedures {
219229
if (b < 0) // Negative power
220230
return ASR::down_cast<ASR::expr_t>(make_RealConstant_t(al, loc,
221231
pow(a, b), real_type));
222-
else // Positive power
223-
return ASR::down_cast<ASR::expr_t>(make_IntegerConstant_t(al, loc,
224-
(int64_t)pow(a, b), int_type));
232+
else {// Positive power
233+
if (mod_by == -1)
234+
return ASR::down_cast<ASR::expr_t>(make_IntegerConstant_t(al, loc,
235+
(int64_t)pow(a, b), int_type));
236+
else {
237+
int64_t res = (int64_t)pow(a, b);
238+
return ASR::down_cast<ASR::expr_t>(make_IntegerConstant_t(al, loc,
239+
res % mod_by, int_type));
240+
}
241+
}
225242

226243
} else if (ASRUtils::is_real(*arg1_type) && ASRUtils::is_real(*arg2_type)) {
227244
double a = ASR::down_cast<ASR::RealConstant_t>(arg1)->m_r;

src/runtime/lpython_builtin.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,3 +541,64 @@ def min(a: f64, b: f64) -> f64:
541541
return a
542542
else:
543543
return b
544+
545+
546+
@overload
547+
def _floor(x: f64) -> i64:
548+
r: i64
549+
r = int(x)
550+
if x >= 0 or x == r:
551+
return r
552+
return r - 1
553+
554+
@overload
555+
def _floor(x: f32) -> i32:
556+
r: i32
557+
r = int(x)
558+
if x >= 0 or x == r:
559+
return r
560+
return r - 1
561+
562+
563+
@overload
564+
def _mod(a: i32, b: i32) -> i32:
565+
"""
566+
Returns a%b
567+
"""
568+
r: i32
569+
r = _floor(a/b)
570+
return a - r*b
571+
572+
573+
@overload
574+
def _mod(a: i64, b: i64) -> i64:
575+
"""
576+
Returns a%b
577+
"""
578+
r: i64
579+
r = _floor(a/b)
580+
return a - r*b
581+
582+
583+
@overload
584+
def pow(x: i32, y: i32, z: i32) -> i32:
585+
"""
586+
Return `x` raised to the power `y`.
587+
"""
588+
if y < 0:
589+
raise ValueError('y should be nonnegative')
590+
result: i32
591+
result = _mod(x**y, z)
592+
return result
593+
594+
595+
@overload
596+
def pow(x: i64, y: i64, z: i64) -> i64:
597+
"""
598+
Return `x` raised to the power `y`.
599+
"""
600+
if y < 0:
601+
raise ValueError('y should be nonnegative')
602+
result: i64
603+
result = _mod(x**y, z)
604+
return result

tests/reference/asr-complex1-f26c460.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-complex1-f26c460.stdout",
9-
"stdout_hash": "0962e62f1546143e61a7f2b8dd8f5a15f3f769accb991a8a382a1390",
9+
"stdout_hash": "c9367e1509aa99dfeede3c7379e41ca9ca30083e9f7cb4d4a44aaa04",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/asr-complex1-f26c460.stdout

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

tests/reference/asr-constants1-5828e8a.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-constants1-5828e8a.stdout",
9-
"stdout_hash": "a501ee33a6c5ceb7e94b68933e664826aee6e97ec90bdb17078be9ef",
9+
"stdout_hash": "e0ab7f321cf9154c14127feb2ff3843d43456a28163ed2b117468711",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/asr-constants1-5828e8a.stdout

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

tests/reference/asr-elemental_01-b58df26.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-elemental_01-b58df26.stdout",
9-
"stdout_hash": "15427f36dad87379bfdd0cd738b5f633d5b15fe517db961ca218d5db",
9+
"stdout_hash": "04c9c971491d2fcbad688ca2a88ac38cb5ae071416eb8229d2c70cde",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/asr-elemental_01-b58df26.stdout

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

tests/reference/asr-expr10-efcbb1b.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-expr10-efcbb1b.stdout",
9-
"stdout_hash": "f7a29a22601c5455d7900da733b156ebd02b8cc3b74d8c3e112ec8c2",
9+
"stdout_hash": "9fff0e599599a527bebbb1bb6c4a1e8465db464704e5b03f1d8450fb",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
(TranslationUnit (SymbolTable 1 {complex@__lpython_overloaded_13__complex: (ExternalSymbol 1 complex@__lpython_overloaded_13__complex 4 __lpython_overloaded_13__complex lpython_builtin [] __lpython_overloaded_13__complex Public), complex@__lpython_overloaded_9__complex: (ExternalSymbol 1 complex@__lpython_overloaded_9__complex 4 __lpython_overloaded_9__complex lpython_builtin [] __lpython_overloaded_9__complex Public), lpython_builtin: (IntrinsicModule lpython_builtin), main_program: (Program (SymbolTable 75 {}) main_program [] []), test_UnaryOp: (Subroutine (SymbolTable 2 {a: (Variable 2 a Local () () Default (Integer 4 []) Source Public Required .false.), b: (Variable 2 b Local () () Default (Logical 4 []) Source Public Required .false.), b1: (Variable 2 b1 Local () () Default (Logical 4 []) Source Public Required .false.), b2: (Variable 2 b2 Local () () Default (Logical 4 []) Source Public Required .false.), b3: (Variable 2 b3 Local () () Default (Logical 4 []) Source Public Required .false.), c: (Variable 2 c Local () () Default (Complex 4 []) Source Public Required .false.), complex: (ExternalSymbol 2 complex 4 complex lpython_builtin [] complex Private), f: (Variable 2 f Local () () Default (Real 4 []) Source Public Required .false.)}) test_UnaryOp [] [(= (Var 2 a) (IntegerConstant 4 (Integer 4 [])) ()) (= (Var 2 a) (IntegerUnaryMinus (IntegerConstant 500 (Integer 4 [])) (Integer 4 []) (IntegerConstant -500 (Integer 4 []))) ()) (= (Var 2 a) (IntegerBitNot (IntegerConstant 5 (Integer 4 [])) (Integer 4 []) (IntegerConstant -6 (Integer 4 []))) ()) (= (Var 2 b) (LogicalNot (Cast (IntegerConstant 5 (Integer 4 [])) IntegerToLogical (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) ()) (= (Var 2 b) (LogicalNot (Cast (IntegerUnaryMinus (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) (IntegerConstant -1 (Integer 4 []))) IntegerToLogical (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) ()) (= (Var 2 b) (LogicalNot (Cast (IntegerConstant 0 (Integer 4 [])) IntegerToLogical (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ()) (= (Var 2 f) (Cast (RealConstant 1.00000000000000000e+00 (Real 8 [])) RealToReal (Real 4 []) (RealConstant 1.00000000000000000e+00 (Real 4 []))) ()) (= (Var 2 f) (Cast (RealUnaryMinus (RealConstant 1.83745534000000014e+05 (Real 8 [])) (Real 8 []) (RealConstant -1.83745534000000014e+05 (Real 8 []))) RealToReal (Real 4 []) (RealConstant -1.83745534000000014e+05 (Real 4 []))) ()) (= (Var 2 b1) (LogicalConstant .true. (Logical 4 [])) ()) (= (Var 2 b2) (LogicalNot (LogicalConstant .false. (Logical 4 [])) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ()) (= (Var 2 b3) (LogicalNot (Var 2 b2) (Logical 4 []) ()) ()) (= (Var 2 a) (IntegerConstant 1 (Integer 4 [])) ()) (= (Var 2 a) (IntegerUnaryMinus (Cast (LogicalConstant .false. (Logical 4 [])) LogicalToInteger (Integer 4 []) (IntegerConstant 0 (Integer 4 []))) (Integer 4 []) (IntegerConstant 0 (Integer 4 []))) ()) (= (Var 2 a) (IntegerBitNot (Cast (LogicalConstant .true. (Logical 4 [])) LogicalToInteger (Integer 4 []) (IntegerConstant -2 (Integer 4 []))) (Integer 4 []) (IntegerConstant -2 (Integer 4 []))) ()) (= (Var 2 c) (Cast (ComplexConstant 1.00000000000000000e+00 2.00000000000000000e+00 (Complex 8 [])) ComplexToComplex (Complex 4 []) (ComplexConstant 1.00000000000000000e+00 2.00000000000000000e+00 (Complex 4 []))) ()) (= (Var 2 c) (Cast (ComplexUnaryMinus (FunctionCall 1 complex@__lpython_overloaded_13__complex 2 complex [((IntegerConstant 3 (Integer 4 []))) ((RealConstant 6.50000000000000000e+01 (Real 8 [])))] (Complex 8 []) (ComplexConstant 3.00000000000000000e+00 6.50000000000000000e+01 (Complex 8 [])) ()) (Complex 8 []) (ComplexConstant -3.00000000000000000e+00 -6.50000000000000000e+01 (Complex 8 []))) ComplexToComplex (Complex 4 []) (ComplexConstant -3.00000000000000000e+00 -6.50000000000000000e+01 (Complex 4 []))) ()) (= (Var 2 b1) (LogicalConstant .false. (Logical 4 [])) ()) (= (Var 2 b2) (LogicalConstant .true. (Logical 4 [])) ())] Source Public Implementation () .false. .false.)}) [])
1+
(TranslationUnit (SymbolTable 1 {complex@__lpython_overloaded_13__complex: (ExternalSymbol 1 complex@__lpython_overloaded_13__complex 4 __lpython_overloaded_13__complex lpython_builtin [] __lpython_overloaded_13__complex Public), complex@__lpython_overloaded_9__complex: (ExternalSymbol 1 complex@__lpython_overloaded_9__complex 4 __lpython_overloaded_9__complex lpython_builtin [] __lpython_overloaded_9__complex Public), lpython_builtin: (IntrinsicModule lpython_builtin), main_program: (Program (SymbolTable 81 {}) main_program [] []), test_UnaryOp: (Subroutine (SymbolTable 2 {a: (Variable 2 a Local () () Default (Integer 4 []) Source Public Required .false.), b: (Variable 2 b Local () () Default (Logical 4 []) Source Public Required .false.), b1: (Variable 2 b1 Local () () Default (Logical 4 []) Source Public Required .false.), b2: (Variable 2 b2 Local () () Default (Logical 4 []) Source Public Required .false.), b3: (Variable 2 b3 Local () () Default (Logical 4 []) Source Public Required .false.), c: (Variable 2 c Local () () Default (Complex 4 []) Source Public Required .false.), complex: (ExternalSymbol 2 complex 4 complex lpython_builtin [] complex Private), f: (Variable 2 f Local () () Default (Real 4 []) Source Public Required .false.)}) test_UnaryOp [] [(= (Var 2 a) (IntegerConstant 4 (Integer 4 [])) ()) (= (Var 2 a) (IntegerUnaryMinus (IntegerConstant 500 (Integer 4 [])) (Integer 4 []) (IntegerConstant -500 (Integer 4 []))) ()) (= (Var 2 a) (IntegerBitNot (IntegerConstant 5 (Integer 4 [])) (Integer 4 []) (IntegerConstant -6 (Integer 4 []))) ()) (= (Var 2 b) (LogicalNot (Cast (IntegerConstant 5 (Integer 4 [])) IntegerToLogical (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) ()) (= (Var 2 b) (LogicalNot (Cast (IntegerUnaryMinus (IntegerConstant 1 (Integer 4 [])) (Integer 4 []) (IntegerConstant -1 (Integer 4 []))) IntegerToLogical (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) (Logical 4 []) (LogicalConstant .false. (Logical 4 []))) ()) (= (Var 2 b) (LogicalNot (Cast (IntegerConstant 0 (Integer 4 [])) IntegerToLogical (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ()) (= (Var 2 f) (Cast (RealConstant 1.00000000000000000e+00 (Real 8 [])) RealToReal (Real 4 []) (RealConstant 1.00000000000000000e+00 (Real 4 []))) ()) (= (Var 2 f) (Cast (RealUnaryMinus (RealConstant 1.83745534000000014e+05 (Real 8 [])) (Real 8 []) (RealConstant -1.83745534000000014e+05 (Real 8 []))) RealToReal (Real 4 []) (RealConstant -1.83745534000000014e+05 (Real 4 []))) ()) (= (Var 2 b1) (LogicalConstant .true. (Logical 4 [])) ()) (= (Var 2 b2) (LogicalNot (LogicalConstant .false. (Logical 4 [])) (Logical 4 []) (LogicalConstant .true. (Logical 4 []))) ()) (= (Var 2 b3) (LogicalNot (Var 2 b2) (Logical 4 []) ()) ()) (= (Var 2 a) (IntegerConstant 1 (Integer 4 [])) ()) (= (Var 2 a) (IntegerUnaryMinus (Cast (LogicalConstant .false. (Logical 4 [])) LogicalToInteger (Integer 4 []) (IntegerConstant 0 (Integer 4 []))) (Integer 4 []) (IntegerConstant 0 (Integer 4 []))) ()) (= (Var 2 a) (IntegerBitNot (Cast (LogicalConstant .true. (Logical 4 [])) LogicalToInteger (Integer 4 []) (IntegerConstant -2 (Integer 4 []))) (Integer 4 []) (IntegerConstant -2 (Integer 4 []))) ()) (= (Var 2 c) (Cast (ComplexConstant 1.00000000000000000e+00 2.00000000000000000e+00 (Complex 8 [])) ComplexToComplex (Complex 4 []) (ComplexConstant 1.00000000000000000e+00 2.00000000000000000e+00 (Complex 4 []))) ()) (= (Var 2 c) (Cast (ComplexUnaryMinus (FunctionCall 1 complex@__lpython_overloaded_13__complex 2 complex [((IntegerConstant 3 (Integer 4 []))) ((RealConstant 6.50000000000000000e+01 (Real 8 [])))] (Complex 8 []) (ComplexConstant 3.00000000000000000e+00 6.50000000000000000e+01 (Complex 8 [])) ()) (Complex 8 []) (ComplexConstant -3.00000000000000000e+00 -6.50000000000000000e+01 (Complex 8 []))) ComplexToComplex (Complex 4 []) (ComplexConstant -3.00000000000000000e+00 -6.50000000000000000e+01 (Complex 4 []))) ()) (= (Var 2 b1) (LogicalConstant .false. (Logical 4 [])) ()) (= (Var 2 b2) (LogicalConstant .true. (Logical 4 [])) ())] Source Public Implementation () .false. .false.)}) [])

tests/reference/asr-expr13-81bdb5a.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-expr13-81bdb5a.stdout",
9-
"stdout_hash": "8404ff452194d6b7f12ddd016ebf75253391cc020712ddee2624e001",
9+
"stdout_hash": "577c708474fbfeda9c09f83756b7d22929c0644397b578eff508ad32",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)