diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index cb1d693d84..7e80eb3e51 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -211,9 +211,7 @@ stmt expr - = BoolOp(expr left, boolop op, expr right, ttype type, expr? value) - | BinOp(expr left, binop op, expr right, ttype type, expr? value, expr? overloaded) - | UnaryOp(unaryop op, expr operand, ttype type, expr? value) + = UnaryOp(unaryop op, expr operand, ttype type, expr? value) -- Such as: (x, y+z), (3.0, 2.0) generally not known at compile time | ComplexConstructor(expr re, expr im, ttype type, expr? value) | NamedExpr(expr target, expr value, ttype type) @@ -224,10 +222,14 @@ expr | DerivedTypeConstructor(symbol dt_sym, expr* args, ttype type) | ImpliedDoLoop(expr* values, expr var, expr start, expr end, expr? increment, ttype type, expr? value) + | IntegerBinOp(expr left, binop op, expr right, ttype type, expr? value) | IntegerConstant(int n, ttype type) | IntegerBOZ(int v, integerboz intboz_type, ttype? type) + | RealBinOp(expr left, binop op, expr right, ttype type, expr? value) | RealConstant(float r, ttype type) + | ComplexBinOp(expr left, binop op, expr right, ttype type, expr? value) | ComplexConstant(float re, float im, ttype type) + | LogicalBinOp(expr left, logicalbinop op, expr right, ttype type, expr? value) | LogicalConstant(bool value, ttype type) | ListConstant(expr* args, ttype type) @@ -256,6 +258,7 @@ expr | ArraySize(expr v, expr? dim, ttype type, expr? value) | ArrayBound(expr v, expr? dim ttype type, arraybound bound, expr? value) + | OverloadedBinOp(expr left, binop op, expr right, ttype type, expr? value, expr overloaded) | DerivedRef(expr v, symbol m, ttype type, expr? value) | Cast(expr arg, cast_kind kind, ttype type, expr? value) | ComplexRe(expr arg, ttype type, expr? value) @@ -301,10 +304,11 @@ ttype | Dict(ttype key_type, ttype value_type) | Pointer(ttype type) -boolop = And | Or | Xor | NEqv | Eqv binop = Add | Sub | Mul | Div | Pow +logicalbinop = And | Or | Xor | NEqv | Eqv + unaryop = Invert | Not | UAdd | USub cmpop = Eq | NotEq | Lt | LtE | Gt | GtE diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index a097051edb..5bf172c300 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -93,8 +93,10 @@ static inline ASR::ttype_t* expr_type(const ASR::expr_t *f) { LFORTRAN_ASSERT(f != nullptr); switch (f->type) { - case ASR::exprType::BoolOp: { return ((ASR::BoolOp_t*)f)->m_type; } - case ASR::exprType::BinOp: { return ((ASR::BinOp_t*)f)->m_type; } + case ASR::exprType::LogicalBinOp: { return ((ASR::LogicalBinOp_t*)f)->m_type; } + case ASR::exprType::IntegerBinOp: { return ((ASR::IntegerBinOp_t*)f)->m_type; } + case ASR::exprType::RealBinOp: { return ((ASR::RealBinOp_t*)f)->m_type; } + case ASR::exprType::ComplexBinOp: { return ((ASR::ComplexBinOp_t*)f)->m_type; } case ASR::exprType::UnaryOp: { return ((ASR::UnaryOp_t*)f)->m_type; } case ASR::exprType::ComplexConstructor: { return ((ASR::ComplexConstructor_t*)f)->m_type; } case ASR::exprType::NamedExpr: { return ((ASR::NamedExpr_t*)f)->m_type; } @@ -272,12 +274,12 @@ static inline std::string cmpop_to_str(const ASR::cmpopType t) { } } -static inline std::string boolop_to_str(const ASR::boolopType t) { +static inline std::string logicalbinop_to_str(const ASR::logicalbinopType t) { switch (t) { - case (ASR::boolopType::And): { return " && "; } - case (ASR::boolopType::Or): { return " || "; } - case (ASR::boolopType::Eqv): { return " == "; } - case (ASR::boolopType::NEqv): { return " != "; } + case (ASR::logicalbinopType::And): { return " && "; } + case (ASR::logicalbinopType::Or): { return " || "; } + case (ASR::logicalbinopType::Eqv): { return " == "; } + case (ASR::logicalbinopType::NEqv): { return " != "; } default : throw LFortranException("Cannot represent the boolean operator as a string"); } } @@ -285,8 +287,10 @@ static inline std::string boolop_to_str(const ASR::boolopType t) { static inline ASR::expr_t* expr_value(ASR::expr_t *f) { switch (f->type) { - case ASR::exprType::BoolOp: { return ASR::down_cast(f)->m_value; } - case ASR::exprType::BinOp: { return ASR::down_cast(f)->m_value; } + case ASR::exprType::LogicalBinOp: { return ASR::down_cast(f)->m_value; } + case ASR::exprType::IntegerBinOp: { return ASR::down_cast(f)->m_value; } + case ASR::exprType::RealBinOp: { return ASR::down_cast(f)->m_value; } + case ASR::exprType::ComplexBinOp: { return ASR::down_cast(f)->m_value; } case ASR::exprType::UnaryOp: { return ASR::down_cast(f)->m_value; } case ASR::exprType::ComplexConstructor: { return ASR::down_cast(f)->m_value; } case ASR::exprType::Compare: { return ASR::down_cast(f)->m_value; } @@ -1070,7 +1074,7 @@ inline bool is_same_type_pointer(ASR::ttype_t* source, ASR::ttype_t* dest) { a_len = -3; break; } - case ASR::exprType::BinOp: { + case ASR::exprType::IntegerBinOp: { a_len = -3; break; } diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index 0a5eb0c170..b40028c30f 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -663,7 +663,20 @@ R"(#include } } - void visit_BinOp(const ASR::BinOp_t &x) { + void visit_IntegerBinOp(const ASR::IntegerBinOp_t &x) { + handle_BinOp(x); + } + + void visit_RealBinOp(const ASR::RealBinOp_t &x) { + handle_BinOp(x); + } + + void visit_ComplexBinOp(const ASR::ComplexBinOp_t &x) { + handle_BinOp(x); + } + + template + void handle_BinOp(const T &x) { self().visit_expr(*x.m_left); std::string left = std::move(src); int left_precedence = last_expr_precedence; @@ -710,7 +723,7 @@ R"(#include } } - void visit_BoolOp(const ASR::BoolOp_t &x) { + void visit_LogicalBinOp(const ASR::LogicalBinOp_t &x) { self().visit_expr(*x.m_left); std::string left = std::move(src); int left_precedence = last_expr_precedence; @@ -718,19 +731,19 @@ R"(#include std::string right = std::move(src); int right_precedence = last_expr_precedence; switch (x.m_op) { - case (ASR::boolopType::And): { + case (ASR::logicalbinopType::And): { last_expr_precedence = 14; break; } - case (ASR::boolopType::Or): { + case (ASR::logicalbinopType::Or): { last_expr_precedence = 15; break; } - case (ASR::boolopType::NEqv): { + case (ASR::logicalbinopType::NEqv): { last_expr_precedence = 10; break; } - case (ASR::boolopType::Eqv): { + case (ASR::logicalbinopType::Eqv): { last_expr_precedence = 10; break; } @@ -742,7 +755,7 @@ R"(#include } else { src += "(" + left + ")"; } - src += ASRUtils::boolop_to_str(x.m_op); + src += ASRUtils::logicalbinop_to_str(x.m_op); if (right_precedence <= last_expr_precedence) { src += right; } else { diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 4060b59821..09cc40056e 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -2878,7 +2878,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor start_new_block(target); } - void visit_BoolOp(const ASR::BoolOp_t &x) { + void visit_LogicalBinOp(const ASR::LogicalBinOp_t &x) { if (x.m_value) { this->visit_expr_wrapper(x.m_value, true); return; @@ -2889,23 +2889,23 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *right_val = tmp; if (x.m_type->type == ASR::ttypeType::Logical) { switch (x.m_op) { - case ASR::boolopType::And: { + case ASR::logicalbinopType::And: { tmp = builder->CreateAnd(left_val, right_val); break; }; - case ASR::boolopType::Or: { + case ASR::logicalbinopType::Or: { tmp = builder->CreateOr(left_val, right_val); break; }; - case ASR::boolopType::Xor: { + case ASR::logicalbinopType::Xor: { tmp = builder->CreateXor(left_val, right_val); break; }; - case ASR::boolopType::NEqv: { + case ASR::logicalbinopType::NEqv: { tmp = builder->CreateXor(left_val, right_val); break; }; - case ASR::boolopType::Eqv: { + case ASR::logicalbinopType::Eqv: { tmp = builder->CreateXor(left_val, right_val); tmp = builder->CreateNot(tmp); }; @@ -2950,11 +2950,60 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = lfortran_str_len(parg); } - void visit_BinOp(const ASR::BinOp_t &x) { - if( x.m_overloaded ) { - this->visit_expr(*x.m_overloaded); - return ; + void visit_IntegerBinOp(const ASR::IntegerBinOp_t &x) { + if (x.m_value) { + this->visit_expr_wrapper(x.m_value, true); + return; + } + this->visit_expr_wrapper(x.m_left, true); + llvm::Value *left_val = tmp; + this->visit_expr_wrapper(x.m_right, true); + llvm::Value *right_val = tmp; + LFORTRAN_ASSERT(ASRUtils::is_integer(*x.m_type)) + switch (x.m_op) { + case ASR::binopType::Add: { + tmp = builder->CreateAdd(left_val, right_val); + break; + }; + case ASR::binopType::Sub: { + tmp = builder->CreateSub(left_val, right_val); + break; + }; + case ASR::binopType::Mul: { + tmp = builder->CreateMul(left_val, right_val); + break; + }; + case ASR::binopType::Div: { + tmp = builder->CreateUDiv(left_val, right_val); + break; + }; + case ASR::binopType::Pow: { + llvm::Type *type; + int a_kind; + a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; + type = getFPType(a_kind); + llvm::Value *fleft = builder->CreateSIToFP(left_val, + type); + llvm::Value *fright = builder->CreateSIToFP(right_val, + type); + std::string func_name = a_kind == 4 ? "llvm.pow.f32" : "llvm.pow.f64"; + llvm::Function *fn_pow = module->getFunction(func_name); + if (!fn_pow) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + type, { type, type}, false); + fn_pow = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, func_name, + module.get()); + } + tmp = builder->CreateCall(fn_pow, {fleft, fright}); + type = getIntType(a_kind); + tmp = builder->CreateFPToSI(tmp, type); + break; + }; } + } + + void visit_RealBinOp(const ASR::RealBinOp_t &x) { if (x.m_value) { this->visit_expr_wrapper(x.m_value, true); return; @@ -2963,142 +3012,108 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value *left_val = tmp; this->visit_expr_wrapper(x.m_right, true); llvm::Value *right_val = tmp; - if (ASRUtils::is_integer(*x.m_type)) { - switch (x.m_op) { - case ASR::binopType::Add: { - tmp = builder->CreateAdd(left_val, right_val); - break; - }; - case ASR::binopType::Sub: { - tmp = builder->CreateSub(left_val, right_val); - break; - }; - case ASR::binopType::Mul: { - tmp = builder->CreateMul(left_val, right_val); - break; - }; - case ASR::binopType::Div: { - tmp = builder->CreateUDiv(left_val, right_val); - break; - }; - case ASR::binopType::Pow: { - llvm::Type *type; - int a_kind; - a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; - type = getFPType(a_kind); - llvm::Value *fleft = builder->CreateSIToFP(left_val, - type); - llvm::Value *fright = builder->CreateSIToFP(right_val, - type); - std::string func_name = a_kind == 4 ? "llvm.pow.f32" : "llvm.pow.f64"; - llvm::Function *fn_pow = module->getFunction(func_name); - if (!fn_pow) { - llvm::FunctionType *function_type = llvm::FunctionType::get( - type, { type, type}, false); - fn_pow = llvm::Function::Create(function_type, - llvm::Function::ExternalLinkage, func_name, - module.get()); - } - tmp = builder->CreateCall(fn_pow, {fleft, fright}); - type = getIntType(a_kind); - tmp = builder->CreateFPToSI(tmp, type); - break; - }; - } - } else if (ASRUtils::is_real(*x.m_type)) { - switch (x.m_op) { - case ASR::binopType::Add: { - tmp = builder->CreateFAdd(left_val, right_val); - break; - }; - case ASR::binopType::Sub: { - tmp = builder->CreateFSub(left_val, right_val); - break; - }; - case ASR::binopType::Mul: { - tmp = builder->CreateFMul(left_val, right_val); - break; - }; - case ASR::binopType::Div: { - tmp = builder->CreateFDiv(left_val, right_val); - break; - }; - case ASR::binopType::Pow: { - llvm::Type *type; - int a_kind; - a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; - type = getFPType(a_kind); - std::string func_name = a_kind == 4 ? "llvm.pow.f32" : "llvm.pow.f64"; - llvm::Function *fn_pow = module->getFunction(func_name); - if (!fn_pow) { - llvm::FunctionType *function_type = llvm::FunctionType::get( - type, { type, type }, false); - fn_pow = llvm::Function::Create(function_type, - llvm::Function::ExternalLinkage, func_name, - module.get()); - } - tmp = builder->CreateCall(fn_pow, {left_val, right_val}); - break; - }; - } - } else if (ASRUtils::is_complex(*x.m_type)) { - llvm::Type *type; - int a_kind; - a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; - type = getComplexType(a_kind); - if( left_val->getType()->isPointerTy() ) { - left_val = CreateLoad(left_val); - } - if( right_val->getType()->isPointerTy() ) { - right_val = CreateLoad(right_val); - } - std::string fn_name; - switch (x.m_op) { - case ASR::binopType::Add: { - if (a_kind == 4) { - fn_name = "_lfortran_complex_add_32"; - } else { - fn_name = "_lfortran_complex_add_64"; - } - break; - }; - case ASR::binopType::Sub: { - if (a_kind == 4) { - fn_name = "_lfortran_complex_sub_32"; - } else { - fn_name = "_lfortran_complex_sub_64"; - } - break; - }; - case ASR::binopType::Mul: { - if (a_kind == 4) { - fn_name = "_lfortran_complex_mul_32"; - } else { - fn_name = "_lfortran_complex_mul_64"; - } - break; - }; - case ASR::binopType::Div: { - if (a_kind == 4) { - fn_name = "_lfortran_complex_div_32"; - } else { - fn_name = "_lfortran_complex_div_64"; - } - break; - }; - case ASR::binopType::Pow: { - if (a_kind == 4) { - fn_name = "_lfortran_complex_pow_32"; - } else { - fn_name = "_lfortran_complex_pow_64"; - } - break; - }; - } - tmp = lfortran_complex_bin_op(left_val, right_val, fn_name, type); - } else { - throw CodeGenError("Binop: Only Real, Integer and Complex types are allowed"); + LFORTRAN_ASSERT(ASRUtils::is_real(*x.m_type)) + switch (x.m_op) { + case ASR::binopType::Add: { + tmp = builder->CreateFAdd(left_val, right_val); + break; + }; + case ASR::binopType::Sub: { + tmp = builder->CreateFSub(left_val, right_val); + break; + }; + case ASR::binopType::Mul: { + tmp = builder->CreateFMul(left_val, right_val); + break; + }; + case ASR::binopType::Div: { + tmp = builder->CreateFDiv(left_val, right_val); + break; + }; + case ASR::binopType::Pow: { + llvm::Type *type; + int a_kind; + a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; + type = getFPType(a_kind); + std::string func_name = a_kind == 4 ? "llvm.pow.f32" : "llvm.pow.f64"; + llvm::Function *fn_pow = module->getFunction(func_name); + if (!fn_pow) { + llvm::FunctionType *function_type = llvm::FunctionType::get( + type, { type, type }, false); + fn_pow = llvm::Function::Create(function_type, + llvm::Function::ExternalLinkage, func_name, + module.get()); + } + tmp = builder->CreateCall(fn_pow, {left_val, right_val}); + break; + }; + } + } + + void visit_ComplexBinOp(const ASR::ComplexBinOp_t &x) { + if (x.m_value) { + this->visit_expr_wrapper(x.m_value, true); + return; + } + this->visit_expr_wrapper(x.m_left, true); + llvm::Value *left_val = tmp; + this->visit_expr_wrapper(x.m_right, true); + llvm::Value *right_val = tmp; + LFORTRAN_ASSERT(ASRUtils::is_complex(*x.m_type)); + llvm::Type *type; + int a_kind; + a_kind = down_cast(ASRUtils::type_get_past_pointer(x.m_type))->m_kind; + type = getComplexType(a_kind); + if( left_val->getType()->isPointerTy() ) { + left_val = CreateLoad(left_val); + } + if( right_val->getType()->isPointerTy() ) { + right_val = CreateLoad(right_val); + } + std::string fn_name; + switch (x.m_op) { + case ASR::binopType::Add: { + if (a_kind == 4) { + fn_name = "_lfortran_complex_add_32"; + } else { + fn_name = "_lfortran_complex_add_64"; + } + break; + }; + case ASR::binopType::Sub: { + if (a_kind == 4) { + fn_name = "_lfortran_complex_sub_32"; + } else { + fn_name = "_lfortran_complex_sub_64"; + } + break; + }; + case ASR::binopType::Mul: { + if (a_kind == 4) { + fn_name = "_lfortran_complex_mul_32"; + } else { + fn_name = "_lfortran_complex_mul_64"; + } + break; + }; + case ASR::binopType::Div: { + if (a_kind == 4) { + fn_name = "_lfortran_complex_div_32"; + } else { + fn_name = "_lfortran_complex_div_64"; + } + break; + }; + case ASR::binopType::Pow: { + if (a_kind == 4) { + fn_name = "_lfortran_complex_pow_32"; + } else { + fn_name = "_lfortran_complex_pow_64"; + } + break; + }; } + tmp = lfortran_complex_bin_op(left_val, right_val, fn_name, type); } void visit_UnaryOp(const ASR::UnaryOp_t &x) { diff --git a/src/libasr/codegen/asr_to_wasm.cpp b/src/libasr/codegen/asr_to_wasm.cpp index 5cbbb974c8..9b14f3f1d4 100644 --- a/src/libasr/codegen/asr_to_wasm.cpp +++ b/src/libasr/codegen/asr_to_wasm.cpp @@ -175,35 +175,31 @@ class ASRToWASMVisitor : public ASR::BaseVisitor { } } - void visit_BinOp(const ASR::BinOp_t &x) { + void visit_IntegerBinOp(const ASR::IntegerBinOp_t &x) { this->visit_expr(*x.m_left); this->visit_expr(*x.m_right); - if (ASRUtils::is_integer(*x.m_type)) { - switch (x.m_op) { - case ASR::binopType::Add: { - wasm::emit_i32_add(m_code_section, m_al); - break; - }; - case ASR::binopType::Sub: { - wasm::emit_i32_sub(m_code_section, m_al); - break; - }; - case ASR::binopType::Mul: { - wasm::emit_i32_mul(m_code_section, m_al); - break; - }; - case ASR::binopType::Div: { - wasm::emit_i32_div(m_code_section, m_al); - break; - }; - default: - throw CodeGenError( - "Binop: Pow Operation not yet implemented"); - } - } else { - throw CodeGenError("Binop: Only Integer type implemented"); + switch (x.m_op) { + case ASR::binopType::Add: { + wasm::emit_i32_add(m_code_section, m_al); + break; + }; + case ASR::binopType::Sub: { + wasm::emit_i32_sub(m_code_section, m_al); + break; + }; + case ASR::binopType::Mul: { + wasm::emit_i32_mul(m_code_section, m_al); + break; + }; + case ASR::binopType::Div: { + wasm::emit_i32_div(m_code_section, m_al); + break; + }; + default: + throw CodeGenError( + "Binop: Pow Operation not yet implemented"); } } diff --git a/src/libasr/codegen/asr_to_x86.cpp b/src/libasr/codegen/asr_to_x86.cpp index 03327b2eab..d9a2651ba1 100644 --- a/src/libasr/codegen/asr_to_x86.cpp +++ b/src/libasr/codegen/asr_to_x86.cpp @@ -302,7 +302,7 @@ class ASRToX86Visitor : public ASR::BaseVisitor } } - void visit_BinOp(const ASR::BinOp_t &x) { + void visit_IntegerBinOp(const ASR::IntegerBinOp_t &x) { this->visit_expr(*x.m_right); m_a.asm_push_r32(X86Reg::eax); this->visit_expr(*x.m_left); diff --git a/src/libasr/pass/arr_slice.cpp b/src/libasr/pass/arr_slice.cpp index c58ccac9e0..b9df99215f 100644 --- a/src/libasr/pass/arr_slice.cpp +++ b/src/libasr/pass/arr_slice.cpp @@ -83,18 +83,15 @@ class ArrSliceVisitor : public PassUtils::PassVisitor end = PassUtils::to_int32(end, int32_type, al); step = PassUtils::to_int32(step, int32_type, al); - ASR::expr_t* gap = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, + ASR::expr_t* gap = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, end, ASR::binopType::Sub, start, - int32_type, nullptr, nullptr)); - // ASR::expr_t* slice_size = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, - // gap, ASR::binopType::Add, const_1, - // int64_type, nullptr)); - ASR::expr_t* slice_size = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, + int32_type, nullptr)); + ASR::expr_t* slice_size = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, gap, ASR::binopType::Div, step, - int32_type, nullptr, nullptr)); - ASR::expr_t* actual_size = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, + int32_type, nullptr)); + ASR::expr_t* actual_size = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, slice_size, ASR::binopType::Add, const_1, - int32_type, nullptr, nullptr)); + int32_type, nullptr)); ASR::dimension_t curr_dim; curr_dim.loc = x.base.base.loc; curr_dim.m_start = const_1; @@ -205,7 +202,7 @@ class ArrSliceVisitor : public PassUtils::PassVisitor doloop_body.push_back(al, set_to_one); doloop_body.push_back(al, doloop); } - ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_target[i], ASR::binopType::Add, const_1, int32_type, nullptr, nullptr)); + ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc, idx_vars_target[i], ASR::binopType::Add, const_1, int32_type, nullptr)); ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_target[i], inc_expr, nullptr)); doloop_body.push_back(al, assign_stmt); doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size())); @@ -226,8 +223,9 @@ class ArrSliceVisitor : public PassUtils::PassVisitor } } - void visit_BinOp(const ASR::BinOp_t& x) { - ASR::BinOp_t& xx = const_cast(x); + template + void visit_BinOp(const T& x) { + T& xx = const_cast(x); create_slice_var = true; slice_var = nullptr; this->visit_expr(*x.m_left); @@ -243,6 +241,19 @@ class ArrSliceVisitor : public PassUtils::PassVisitor create_slice_var = false; } + void visit_IntegerBinOp(const ASR::IntegerBinOp_t& x) { + visit_BinOp(x); + } + void visit_RealBinOp(const ASR::RealBinOp_t& x) { + visit_BinOp(x); + } + void visit_ComplexBinOp(const ASR::ComplexBinOp_t& x) { + visit_BinOp(x); + } + void visit_LogicalBinOp(const ASR::LogicalBinOp_t& x) { + visit_BinOp(x); + } + void visit_Print(const ASR::Print_t& x) { ASR::Print_t& xx = const_cast(x); for( size_t i = 0; i < xx.n_values; i++ ) { diff --git a/src/libasr/pass/array_op.cpp b/src/libasr/pass/array_op.cpp index 3d66c14d17..6aaad2fae7 100644 --- a/src/libasr/pass/array_op.cpp +++ b/src/libasr/pass/array_op.cpp @@ -530,21 +530,33 @@ class ArrayOpVisitor : public PassUtils::PassVisitor ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al); ASR::expr_t* op_el_wise = nullptr; switch( x.class_type ) { - case ASR::exprType::BinOp: - op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t( + case ASR::exprType::IntegerBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t( al, x.base.base.loc, ref_1, (ASR::binopType)x.m_op, ref_2, - x.m_type, nullptr, nullptr)); + x.m_type, nullptr)); + break; + case ASR::exprType::RealBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_RealBinOp_t( + al, x.base.base.loc, + ref_1, (ASR::binopType)x.m_op, ref_2, + x.m_type, nullptr)); + break; + case ASR::exprType::ComplexBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_ComplexBinOp_t( + al, x.base.base.loc, + ref_1, (ASR::binopType)x.m_op, ref_2, + x.m_type, nullptr)); break; case ASR::exprType::Compare: op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_Compare_t( al, x.base.base.loc, ref_1, (ASR::cmpopType)x.m_op, ref_2, x.m_type, nullptr, nullptr)); break; - case ASR::exprType::BoolOp: - op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_BoolOp_t( + case ASR::exprType::LogicalBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_LogicalBinOp_t( al, x.base.base.loc, - ref_1, (ASR::boolopType)x.m_op, ref_2, x.m_type, nullptr)); + ref_1, (ASR::logicalbinopType)x.m_op, ref_2, x.m_type, nullptr)); break; default: throw LFortranException("The desired operation is not supported yet for arrays."); @@ -556,9 +568,7 @@ class ArrayOpVisitor : public PassUtils::PassVisitor doloop_body.push_back(al, set_to_one); doloop_body.push_back(al, doloop); } - ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_value[i], - ASR::binopType::Add, const_1, int32_type, - nullptr, nullptr)); + ASR::expr_t* inc_expr = PassUtils::create_binop_add(al, x.base.base.loc, idx_vars_value[i], const_1); ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_value[i], inc_expr, nullptr)); doloop_body.push_back(al, assign_stmt); doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size())); @@ -613,21 +623,33 @@ class ArrayOpVisitor : public PassUtils::PassVisitor ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al); ASR::expr_t* op_el_wise = nullptr; switch( x.class_type ) { - case ASR::exprType::BinOp: - op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t( + case ASR::exprType::IntegerBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t( + al, x.base.base.loc, + ref, (ASR::binopType)x.m_op, other_expr, + x.m_type, nullptr)); + break; + case ASR::exprType::RealBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_RealBinOp_t( + al, x.base.base.loc, + ref, (ASR::binopType)x.m_op, other_expr, + x.m_type, nullptr)); + break; + case ASR::exprType::ComplexBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_ComplexBinOp_t( al, x.base.base.loc, ref, (ASR::binopType)x.m_op, other_expr, - x.m_type, nullptr, nullptr)); + x.m_type, nullptr)); break; case ASR::exprType::Compare: op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_Compare_t( al, x.base.base.loc, ref, (ASR::cmpopType)x.m_op, other_expr, x.m_type, nullptr, nullptr)); break; - case ASR::exprType::BoolOp: - op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_BoolOp_t( + case ASR::exprType::LogicalBinOp: + op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_LogicalBinOp_t( al, x.base.base.loc, - ref, (ASR::boolopType)x.m_op, other_expr, x.m_type, nullptr)); + ref, (ASR::logicalbinopType)x.m_op, other_expr, x.m_type, nullptr)); break; default: throw LFortranException("The desired operation is not supported yet for arrays."); @@ -639,7 +661,7 @@ class ArrayOpVisitor : public PassUtils::PassVisitor doloop_body.push_back(al, set_to_one); doloop_body.push_back(al, doloop); } - ASR::expr_t* inc_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, x.base.base.loc, idx_vars_value[i], ASR::binopType::Add, const_1, int32_type, nullptr, nullptr)); + ASR::expr_t* inc_expr = PassUtils::create_binop_add(al, x.base.base.loc, idx_vars_value[i], const_1); ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, idx_vars_value[i], inc_expr, nullptr)); doloop_body.push_back(al, assign_stmt); doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size())); @@ -650,16 +672,24 @@ class ArrayOpVisitor : public PassUtils::PassVisitor } } - void visit_BinOp(const ASR::BinOp_t &x) { - visit_ArrayOpCommon(x, "_bin_op_res"); + void visit_IntegerBinOp(const ASR::IntegerBinOp_t &x) { + visit_ArrayOpCommon(x, "_bin_op_res"); } - void visit_Compare(const ASR::Compare_t &x) { - visit_ArrayOpCommon(x, "_comp_op_res"); + void visit_RealBinOp(const ASR::RealBinOp_t &x) { + visit_ArrayOpCommon(x, "_bin_op_res"); + } + + void visit_ComplexBinOp(const ASR::ComplexBinOp_t &x) { + visit_ArrayOpCommon(x, "_bin_op_res"); + } + + void visit_LogicalBinOp(const ASR::LogicalBinOp_t &x) { + visit_ArrayOpCommon(x, "_bool_op_res"); } - void visit_BoolOp(const ASR::BoolOp_t &x) { - visit_ArrayOpCommon(x, "_bool_op_res"); + void visit_Compare(const ASR::Compare_t &x) { + visit_ArrayOpCommon(x, "_comp_op_res"); } void visit_ArraySize(const ASR::ArraySize_t& x) { diff --git a/src/libasr/pass/div_to_mul.cpp b/src/libasr/pass/div_to_mul.cpp index 0833e18a00..a4cfeccb0e 100644 --- a/src/libasr/pass/div_to_mul.cpp +++ b/src/libasr/pass/div_to_mul.cpp @@ -46,7 +46,7 @@ class DivToMulVisitor : public PassUtils::PassVisitor pass_result.reserve(al, 1); } - void visit_BinOp(const ASR::BinOp_t& x) { + void visit_RealBinOp(const ASR::RealBinOp_t& x) { visit_expr(*x.m_left); visit_expr(*x.m_right); if( x.m_op == ASR::binopType::Div ) { @@ -66,7 +66,7 @@ class DivToMulVisitor : public PassUtils::PassVisitor break; } if( is_feasible ) { - ASR::BinOp_t& xx = const_cast(x); + ASR::RealBinOp_t& xx = const_cast(x); xx.m_op = ASR::binopType::Mul; xx.m_right = right_inverse; } diff --git a/src/libasr/pass/fma.cpp b/src/libasr/pass/fma.cpp index bffc9b790f..c2127e93c6 100644 --- a/src/libasr/pass/fma.cpp +++ b/src/libasr/pass/fma.cpp @@ -53,23 +53,25 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor } bool is_BinOpMul(ASR::expr_t* expr) { - if( expr->type == ASR::exprType::BinOp ) { - ASR::BinOp_t* expr_binop = ASR::down_cast(expr); + if (ASR::is_a(*expr)) { + ASR::RealBinOp_t* expr_binop = ASR::down_cast(expr); return expr_binop->m_op == ASR::binopType::Mul; } return false; } - void visit_BinOp(const ASR::BinOp_t& x_const) { + void visit_IntegerBinOp(const ASR::IntegerBinOp_t& /*x*/) { } + void visit_ComplexBinOp(const ASR::ComplexBinOp_t& /*x*/) { } + void visit_LogicalBinOp(const ASR::LogicalBinOp_t& /*x*/) { } + + void visit_RealBinOp(const ASR::RealBinOp_t& x_const) { if( !from_fma ) { return ; } from_fma = true; - if( x_const.m_type->type != ASR::ttypeType::Real ) { - return ; - } - ASR::BinOp_t& x = const_cast(x_const); + LFORTRAN_ASSERT(ASRUtils::is_real(*x_const.m_type)) + ASR::RealBinOp_t& x = const_cast(x_const); fma_var = nullptr; visit_expr(*x.m_left); @@ -107,7 +109,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor nullptr)); } - ASR::BinOp_t* mul_binop = ASR::down_cast(mul_expr); + ASR::RealBinOp_t* mul_binop = ASR::down_cast(mul_expr); ASR::expr_t *first_arg = mul_binop->m_left, *second_arg = mul_binop->m_right; if( is_mul_expr_negative ) { diff --git a/src/libasr/pass/implied_do_loops.cpp b/src/libasr/pass/implied_do_loops.cpp index 4047600d38..15a7fb6473 100644 --- a/src/libasr/pass/implied_do_loops.cpp +++ b/src/libasr/pass/implied_do_loops.cpp @@ -62,7 +62,8 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor contains_array = false; } - void visit_BinOp(const ASR::BinOp_t& x) { + template + void visit_BinOp(const T& x) { if( contains_array ) { return ; } @@ -74,6 +75,19 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor contains_array = left_array || right_array; } + void visit_IntegerBinOp(const ASR::IntegerBinOp_t& x) { + visit_BinOp(x); + } + void visit_RealBinOp(const ASR::RealBinOp_t& x) { + visit_BinOp(x); + } + void visit_ComplexBinOp(const ASR::ComplexBinOp_t& x) { + visit_BinOp(x); + } + void visit_LogicalBinOp(const ASR::LogicalBinOp_t& x) { + visit_BinOp(x); + } + void create_do_loop(ASR::ImpliedDoLoop_t* idoloop, ASR::Var_t* arr_var, ASR::expr_t* arr_idx=nullptr) { ASR::do_loop_head_t head; head.m_v = idoloop->m_var; @@ -90,9 +104,9 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor const_n = offset = num_grps = grp_start = nullptr; if( arr_idx == nullptr ) { const_n = LFortran::ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, arr_var->base.base.loc, idoloop->n_values, _type)); - offset = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, idoloop->m_var, ASR::binopType::Sub, idoloop->m_start, _type, nullptr, nullptr)); - num_grps = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, offset, ASR::binopType::Mul, const_n, _type, nullptr, nullptr)); - grp_start = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, num_grps, ASR::binopType::Add, const_1, _type, nullptr, nullptr)); + offset = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, arr_var->base.base.loc, idoloop->m_var, ASR::binopType::Sub, idoloop->m_start, _type, nullptr)); + num_grps = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, arr_var->base.base.loc, offset, ASR::binopType::Mul, const_n, _type, nullptr)); + grp_start = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, arr_var->base.base.loc, num_grps, ASR::binopType::Add, const_1, _type, nullptr)); } for( size_t i = 0; i < idoloop->n_values; i++ ) { Vec args; @@ -101,9 +115,9 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor ai.m_left = nullptr; if( arr_idx == nullptr ) { ASR::expr_t* const_i = LFortran::ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, arr_var->base.base.loc, i, _type)); - ASR::expr_t* idx = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, + ASR::expr_t* idx = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, arr_var->base.base.loc, grp_start, ASR::binopType::Add, const_i, - _type, nullptr, nullptr)); + _type, nullptr)); ai.m_right = idx; } else { ai.m_right = arr_idx; @@ -120,7 +134,7 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor ASR::stmt_t* doloop_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, array_ref, idoloop->m_values[i], nullptr)); doloop_body.push_back(al, doloop_stmt); if( arr_idx != nullptr ) { - ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, arr_idx, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(arr_idx), nullptr, nullptr)); + ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, arr_var->base.base.loc, arr_idx, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(arr_idx), nullptr)); ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, arr_idx, increment, nullptr)); doloop_body.push_back(al, assign_stmt); } @@ -176,7 +190,7 @@ class ImpliedDoLoopVisitor : public PassUtils::PassVisitor LFortran::ASRUtils::expr_type(LFortran::ASRUtils::EXPR((ASR::asr_t*)arr_var)), nullptr)); ASR::stmt_t* assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, array_ref, arr_init->m_args[k], nullptr)); pass_result.push_back(al, assign_stmt); - ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, arr_var->base.base.loc, idx_var, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(idx_var), nullptr, nullptr)); + ASR::expr_t* increment = LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, arr_var->base.base.loc, idx_var, ASR::binopType::Add, const_1, LFortran::ASRUtils::expr_type(idx_var), nullptr)); assign_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, arr_var->base.base.loc, idx_var, increment, nullptr)); pass_result.push_back(al, assign_stmt); } diff --git a/src/libasr/pass/inline_function_calls.cpp b/src/libasr/pass/inline_function_calls.cpp index b6d72627fc..b45ccb1eb9 100644 --- a/src/libasr/pass/inline_function_calls.cpp +++ b/src/libasr/pass/inline_function_calls.cpp @@ -349,8 +349,9 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor(x); + template + void visit_BinOp(const T& x) { + T& xx = const_cast(x); from_inline_function_call = true; function_result_var = nullptr; visit_expr(*x.m_left); @@ -366,6 +367,19 @@ class InlineFunctionCallVisitor : public PassUtils::PassVisitor asr = const_cast(&(x.base)); } - void visit_BinOp(const ASR::BinOp_t& x) { - ASR::BinOp_t& x_unconst = const_cast(x); + template + void visit_BinOp(const T& x) { + T& x_unconst = const_cast(x); asr = nullptr; this->visit_expr(*x.m_left); if( asr != nullptr ) { @@ -89,6 +90,19 @@ class VarVisitor : public ASR::BaseWalkVisitor asr = const_cast(&(x.base)); } + void visit_IntegerBinOp(const ASR::IntegerBinOp_t& x) { + visit_BinOp(x); + } + void visit_RealBinOp(const ASR::RealBinOp_t& x) { + visit_BinOp(x); + } + void visit_ComplexBinOp(const ASR::ComplexBinOp_t& x) { + visit_BinOp(x); + } + void visit_LogicalBinOp(const ASR::LogicalBinOp_t& x) { + visit_BinOp(x); + } + void visit_Cast(const ASR::Cast_t& x) { /* asr = nullptr; diff --git a/src/libasr/pass/pass_utils.cpp b/src/libasr/pass/pass_utils.cpp index 292043197e..62f644a122 100644 --- a/src/libasr/pass/pass_utils.cpp +++ b/src/libasr/pass/pass_utils.cpp @@ -179,6 +179,22 @@ namespace LFortran { return array_ref; } + ASR::expr_t* create_binop_add(Allocator &al, const Location &loc, ASR::expr_t* left, ASR::expr_t* right) { + LFORTRAN_ASSERT(ASRUtils::expr_type(left) == ASRUtils::expr_type(right)) + ASR::ttype_t* type = ASRUtils::expr_type(left); + // TODO: compute `value`: + if (ASRUtils::is_integer(*type)) { + return ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, left, ASR::binopType::Add, right, type, nullptr)); + } else if (ASRUtils::is_real(*type)) { + return ASRUtils::EXPR(ASR::make_RealBinOp_t(al, loc, left, ASR::binopType::Add, right, type, nullptr)); + } else if (ASRUtils::is_complex(*type)) { + return ASRUtils::EXPR(ASR::make_ComplexBinOp_t(al, loc, left, ASR::binopType::Add, right, type, nullptr)); + } else { + LFORTRAN_ASSERT(false); + return nullptr; + } + } + void create_idx_vars(Vec& idx_vars, int n_dims, const Location& loc, Allocator& al, SymbolTable*& current_scope, std::string suffix) { idx_vars.reserve(al, n_dims); @@ -517,15 +533,15 @@ namespace LFortran { ASR::expr_t *target = loop.m_head.m_v; ASR::ttype_t *type = LFortran::ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0)); stmt1 = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target, - LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr, nullptr)), + LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, a, ASR::binopType::Sub, c, type, nullptr)), nullptr)); cond = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, - LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)), + LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr)), cmp_op, b, type, nullptr, nullptr)); inc_stmt = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, loc, target, - LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr, nullptr)), + LFortran::ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, target, ASR::binopType::Add, c, type, nullptr)), nullptr)); } Vec body; diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index e4cab450a0..fd1553822f 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -22,6 +22,8 @@ namespace LFortran { ASR::expr_t* create_array_ref(ASR::symbol_t* arr, Vec& idx_vars, Allocator& al, const Location& loc, ASR::ttype_t* _type); + ASR::expr_t* create_binop_add(Allocator &al, const Location &loc, ASR::expr_t* left, ASR::expr_t* right); + void create_idx_vars(Vec& idx_vars, int n_dims, const Location& loc, Allocator& al, SymbolTable*& current_scope, std::string suffix="_k"); diff --git a/src/libasr/pass/select_case.cpp b/src/libasr/pass/select_case.cpp index ec06ecf599..60068b768f 100644 --- a/src/libasr/pass/select_case.cpp +++ b/src/libasr/pass/select_case.cpp @@ -53,16 +53,16 @@ inline ASR::expr_t* gen_test_expr_CaseStmt(Allocator& al, const Location& loc, A if( Case_Stmt->n_test == 1 ) { test_expr = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[0], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); } else if( Case_Stmt->n_test == 2 ) { - ASR::expr_t* left = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[0], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); - ASR::expr_t* right = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[1], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); - test_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, left, ASR::binopType::Add, right, LFortran::ASRUtils::expr_type(left), nullptr, nullptr)); + ASR::expr_t* left = ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[0], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); + ASR::expr_t* right = ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[1], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); + test_expr = PassUtils::create_binop_add(al, loc, left, right); } else { ASR::expr_t* left = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[0], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); ASR::expr_t* right = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[1], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); - test_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, left, ASR::binopType::Add, right, LFortran::ASRUtils::expr_type(left), nullptr, nullptr)); + test_expr = PassUtils::create_binop_add(al, loc, left, right); for( std::uint32_t j = 2; j < Case_Stmt->n_test; j++ ) { ASR::expr_t* newExpr = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::Eq, Case_Stmt->m_test[j], LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); - test_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, test_expr, ASR::binopType::Add, newExpr, LFortran::ASRUtils::expr_type(newExpr), nullptr, nullptr)); + test_expr = PassUtils::create_binop_add(al, loc, test_expr, newExpr); } } return test_expr; @@ -77,7 +77,7 @@ inline ASR::expr_t* gen_test_expr_CaseStmt_Range(Allocator& al, const Location& } else if( Case_Stmt->m_start != nullptr && Case_Stmt->m_end != nullptr ) { ASR::expr_t* left = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, Case_Stmt->m_start, ASR::cmpopType::LtE, a_test, LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); ASR::expr_t* right = LFortran::ASRUtils::EXPR(ASR::make_Compare_t(al, loc, a_test, ASR::cmpopType::LtE, Case_Stmt->m_end, LFortran::ASRUtils::expr_type(a_test), nullptr, nullptr)); - test_expr = LFortran::ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, left, ASR::binopType::Mul, right, LFortran::ASRUtils::expr_type(left), nullptr, nullptr)); + test_expr = PassUtils::create_binop_add(al, loc, left, right); } return test_expr; } diff --git a/src/libasr/pass/sign_from_value.cpp b/src/libasr/pass/sign_from_value.cpp index a21b2876a7..414653c4ca 100644 --- a/src/libasr/pass/sign_from_value.cpp +++ b/src/libasr/pass/sign_from_value.cpp @@ -80,13 +80,14 @@ class SignFromValueVisitor : public PassUtils::SkipOptimizationFunctionVisitor + void visit_BinOp(const T& x_const) { if( !from_sign_from_value ) { return ; } from_sign_from_value = true; - ASR::BinOp_t& x = const_cast(x_const); + T& x = const_cast(x_const); sign_from_value_var = nullptr; visit_expr(*x.m_left); @@ -124,6 +125,16 @@ class SignFromValueVisitor : public PassUtils::SkipOptimizationFunctionVisitor(x); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index c6109fc979..baa128ecb9 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -789,15 +789,14 @@ class CommonVisitor : public AST::BaseVisitor { ASR::expr_t *index_add_one(const Location &loc, ASR::expr_t *idx) { // Add 1 to the index `idx`, assumes `idx` is of type Integer 4 - ASR::expr_t *overloaded = nullptr; ASR::expr_t *comptime_value = nullptr; ASR::ttype_t *a_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4, nullptr, 0)); ASR::expr_t *constant_one = ASR::down_cast(ASR::make_IntegerConstant_t( al, loc, 1, a_type)); - return ASRUtils::EXPR(ASR::make_BinOp_t(al, loc, idx, + return ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, loc, idx, ASR::binopType::Add, constant_one, a_type, - comptime_value, overloaded)); + comptime_value)); } // Casts `right` if needed to the type of `left` @@ -1160,6 +1159,7 @@ class CommonVisitor : public AST::BaseVisitor { } value = ASR::down_cast(ASR::make_IntegerConstant_t( al, loc, result, dest_type)); + tmp = ASR::make_IntegerBinOp_t(al, loc, left, op, right, dest_type, value); } else if (ASRUtils::is_real(*dest_type)) { double left_value = ASR::down_cast( @@ -1177,6 +1177,7 @@ class CommonVisitor : public AST::BaseVisitor { } value = ASR::down_cast(ASR::make_RealConstant_t( al, loc, result, dest_type)); + tmp = ASR::make_RealBinOp_t(al, loc, left, op, right, dest_type, value); } else if (ASRUtils::is_complex(*dest_type)) { ASR::ComplexConstant_t *left0 = ASR::down_cast( @@ -1196,6 +1197,7 @@ class CommonVisitor : public AST::BaseVisitor { } value = ASR::down_cast(ASR::make_ComplexConstant_t(al, loc, std::real(result), std::imag(result), dest_type)); + tmp = ASR::make_ComplexBinOp_t(al, loc, left, op, right, dest_type, value); } else if (ASRUtils::is_logical(*dest_type)) { int8_t left_value = ASR::down_cast( @@ -1213,12 +1215,9 @@ class CommonVisitor : public AST::BaseVisitor { } value = ASR::down_cast(ASR::make_IntegerConstant_t( al, loc, result, int_type)); - dest_type = int_type; + tmp = ASR::make_IntegerBinOp_t(al, loc, left, op, right, int_type, value); } } - ASR::expr_t *overloaded = nullptr; - tmp = ASR::make_BinOp_t(al, loc, left, op, right, dest_type, - value, overloaded); } void visit_Name(const AST::Name_t &x) { @@ -1280,7 +1279,7 @@ class CommonVisitor : public AST::BaseVisitor { } void visit_BoolOp(const AST::BoolOp_t &x) { - ASR::boolopType op; + ASR::logicalbinopType op; if (x.n_values > 2) { throw SemanticError("Only two operands supported for boolean operations", x.base.base.loc); @@ -1290,8 +1289,8 @@ class CommonVisitor : public AST::BaseVisitor { this->visit_expr(*x.m_values[1]); ASR::expr_t *rhs = ASRUtils::EXPR(tmp); switch (x.m_op) { - case (AST::boolopType::And): { op = ASR::boolopType::And; break; } - case (AST::boolopType::Or): { op = ASR::boolopType::Or; break; } + case (AST::boolopType::And): { op = ASR::logicalbinopType::And; break; } + case (AST::boolopType::Or): { op = ASR::logicalbinopType::Or; break; } default : { throw SemanticError("Boolean operator type not supported", x.base.base.loc); @@ -1311,8 +1310,8 @@ class CommonVisitor : public AST::BaseVisitor { ASRUtils::expr_value(rhs))->m_value; bool result; switch (op) { - case (ASR::boolopType::And): { result = left_value && right_value; break; } - case (ASR::boolopType::Or): { result = left_value || right_value; break; } + case (ASR::logicalbinopType::And): { result = left_value && right_value; break; } + case (ASR::logicalbinopType::Or): { result = left_value || right_value; break; } default : { throw SemanticError("Boolean operator type not supported", x.base.base.loc); @@ -1321,7 +1320,7 @@ class CommonVisitor : public AST::BaseVisitor { value = ASR::down_cast(ASR::make_LogicalConstant_t( al, x.base.base.loc, result, dest_type)); } - tmp = ASR::make_BoolOp_t(al, x.base.base.loc, lhs, op, rhs, dest_type, value); + tmp = ASR::make_LogicalBinOp_t(al, x.base.base.loc, lhs, op, rhs, dest_type, value); } void visit_BinOp(const AST::BinOp_t &x) {