From 42d6939fd5afa97caea6df56c8887093f1840124 Mon Sep 17 00:00:00 2001
From: Abdelrahman Khaled
Date: Thu, 7 Jul 2022 04:49:39 +0200
Subject: [PATCH] Implement string comparison

---
Reviewer notes (this section is ignored by git-am):
 * str_compare() now delegates to strcmp(), which compares bytes as
   unsigned char. The previous `(*s1)[i] - (*s2)[i]` took its sign from
   plain char, so bytes >= 0x80 ordered differently on signed-char and
   unsigned-char targets. The MIN() helper macro is no longer needed.
 * Line breaks, indentation and stripped <...> tokens were reconstructed
   from a whitespace-collapsed copy of this patch. The three pre-existing
   #include context lines in lfortran_intrinsics.h and any column alignment
   inside context lines are assumptions -- confirm against the tree, or
   apply with --ignore-whitespace.
 * The `index` blob ids are carried over from the original patch and no
   longer describe the revised lfortran_intrinsics.c post-image.

 integration_tests/CMakeLists.txt         |  1 +
 integration_tests/test_str_comparison.py | 36 +++++++++++++++++++++
 src/libasr/codegen/asr_to_llvm.cpp       | 40 ++++++++++++++++++------
 src/runtime/impure/lfortran_intrinsics.c | 39 +++++++++++++++++++++++
 src/runtime/impure/lfortran_intrinsics.h |  7 +++++
 5 files changed, 114 insertions(+), 9 deletions(-)
 create mode 100644 integration_tests/test_str_comparison.py

diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt
index d0ce1ce315..cd02133bd0 100644
--- a/integration_tests/CMakeLists.txt
+++ b/integration_tests/CMakeLists.txt
@@ -191,6 +191,7 @@ RUN(NAME test_platform LABELS cpython llvm)
 RUN(NAME test_vars_01 LABELS cpython llvm)
 RUN(NAME test_version LABELS cpython llvm)
 RUN(NAME vec_01 LABELS cpython llvm)
+RUN(NAME test_str_comparison LABELS cpython llvm)
 
 # Just CPython
 RUN(NAME test_builtin_bin LABELS cpython)
diff --git a/integration_tests/test_str_comparison.py b/integration_tests/test_str_comparison.py
new file mode 100644
index 0000000000..4ac40aaa78
--- /dev/null
+++ b/integration_tests/test_str_comparison.py
@@ -0,0 +1,36 @@
+def f():
+    s1: str = "abcd"
+    s2: str = "abcd"
+    assert s1 == s2
+    assert s1 <= s2
+    assert s1 >= s2
+    s1 = "abcde"
+    assert s1 >= s2
+    assert s1 > s2
+    s1 = "abc"
+    assert s1 < s2
+    assert s1 <= s2
+    s1 = "Abcd"
+    s2 = "abcd"
+    assert s1 < s2
+    s1 = "orange"
+    s2 = "apple"
+    assert s1 >= s2
+    assert s1 > s2
+    s1 = "albatross"
+    s2 = "albany"
+    assert s1 >= s2
+    assert s1 > s2
+    assert s1 != s2
+    s1 = "maple"
+    s2 = "morning"
+    assert s1 <= s2
+    assert s1 < s2
+    assert s1 != s2
+    s1 = "Zebra"
+    s2 = "ant"
+    assert s1 <= s2
+    assert s1 < s2
+    assert s1 != s2
+
+f()
diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp
index 650d3f10f0..2c08609d2e 100644
--- a/src/libasr/codegen/asr_to_llvm.cpp
+++ b/src/libasr/codegen/asr_to_llvm.cpp
@@ -738,6 +738,29 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
         return CreateLoad(presult);
     }
 
+    llvm::Value* lfortran_str_cmp(llvm::Value* left_arg, llvm::Value* right_arg,
+                                             std::string runtime_func_name)
+    {
+        llvm::Function *fn = module->getFunction(runtime_func_name);
+        if(!fn) {
+            llvm::FunctionType *function_type = llvm::FunctionType::get(
+                    llvm::Type::getInt1Ty(context), {
+                        character_type->getPointerTo(),
+                        character_type->getPointerTo()
+                    }, false);
+            fn = llvm::Function::Create(function_type,
+                    llvm::Function::ExternalLinkage, runtime_func_name, *module);
+        }
+        llvm::AllocaInst *pleft_arg = builder->CreateAlloca(character_type,
+            nullptr);
+        builder->CreateStore(left_arg, pleft_arg);
+        llvm::AllocaInst *pright_arg = builder->CreateAlloca(character_type,
+            nullptr);
+        builder->CreateStore(right_arg, pright_arg);
+        std::vector<llvm::Value*> args = {pleft_arg, pright_arg};
+        return builder->CreateCall(fn, args);
+    }
+
     llvm::Value* lfortran_strrepeat(llvm::Value* left_arg, llvm::Value* right_arg)
     {
         std::string runtime_func_name = "_lfortran_strrepeat";
@@ -3036,32 +3059,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
         llvm::Value *left = tmp;
         this->visit_expr_wrapper(x.m_right, true);
         llvm::Value *right = tmp;
-        // TODO: For now we only compare the first character of the strings
-        left = CreateLoad(left);
-        right = CreateLoad(right);
+        std::string fn;
         switch (x.m_op) {
             case (ASR::cmpopType::Eq) : {
-                tmp = builder->CreateICmpEQ(left, right);
+                fn = "_lpython_str_compare_eq";
                 break;
             }
             case (ASR::cmpopType::NotEq) : {
-                tmp = builder->CreateICmpNE(left, right);
+                fn = "_lpython_str_compare_noteq";
                 break;
             }
             case (ASR::cmpopType::Gt) : {
-                tmp = builder->CreateICmpUGT(left, right);
+                fn = "_lpython_str_compare_gt";
                 break;
             }
             case (ASR::cmpopType::GtE) : {
-                tmp = builder->CreateICmpUGE(left, right);
+                fn = "_lpython_str_compare_gte";
                 break;
             }
             case (ASR::cmpopType::Lt) : {
-                tmp = builder->CreateICmpULT(left, right);
+                fn = "_lpython_str_compare_lt";
                 break;
             }
             case (ASR::cmpopType::LtE) : {
-                tmp = builder->CreateICmpULE(left, right);
+                fn = "_lpython_str_compare_lte";
                 break;
             }
             default : {
@@ -3069,6 +3090,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
                         x.base.base.loc);
             }
         }
+        tmp = lfortran_str_cmp(left, right, fn);
     }
 
     void visit_LogicalCompare(const ASR::LogicalCompare_t &x) {
diff --git a/src/runtime/impure/lfortran_intrinsics.c b/src/runtime/impure/lfortran_intrinsics.c
index db8c864166..54e7c5ff27 100644
--- a/src/runtime/impure/lfortran_intrinsics.c
+++ b/src/runtime/impure/lfortran_intrinsics.c
@@ -614,6 +614,45 @@ LFORTRAN_API void _lfortran_strcat(char** s1, char** s2, char** dest)
     *dest = &(dest_char[0]);
 }
 
+// Three-way comparison of two NUL-terminated strings: the result is negative,
+// zero or positive when *s1 sorts before, equal to or after *s2. strcmp()
+// compares bytes as unsigned char, so the ordering of bytes >= 0x80 does not
+// depend on whether plain char is signed on the target.
+int str_compare(char **s1, char **s2)
+{
+    return strcmp(*s1, *s2);
+}
+
+LFORTRAN_API bool _lpython_str_compare_eq(char **s1, char **s2)
+{
+    return str_compare(s1, s2) == 0;
+}
+
+LFORTRAN_API bool _lpython_str_compare_noteq(char **s1, char **s2)
+{
+    return str_compare(s1, s2) != 0;
+}
+
+LFORTRAN_API bool _lpython_str_compare_gt(char **s1, char **s2)
+{
+    return str_compare(s1, s2) > 0;
+}
+
+LFORTRAN_API bool _lpython_str_compare_lte(char **s1, char **s2)
+{
+    return str_compare(s1, s2) <= 0;
+}
+
+LFORTRAN_API bool _lpython_str_compare_lt(char **s1, char **s2)
+{
+    return str_compare(s1, s2) < 0;
+}
+
+LFORTRAN_API bool _lpython_str_compare_gte(char **s1, char **s2)
+{
+    return str_compare(s1, s2) >= 0;
+}
+
 //repeat str for n time
 LFORTRAN_API void _lfortran_strrepeat(char** s, int32_t n, char** dest)
 {
diff --git a/src/runtime/impure/lfortran_intrinsics.h b/src/runtime/impure/lfortran_intrinsics.h
index 617ec824b8..4e8573cc74 100644
--- a/src/runtime/impure/lfortran_intrinsics.h
+++ b/src/runtime/impure/lfortran_intrinsics.h
@@ -4,6 +4,7 @@
 #include <stdarg.h>
 #include <complex.h>
 #include <inttypes.h>
+#include <stdbool.h>
 
 #ifdef __cplusplus
 extern "C" {
@@ -136,6 +137,12 @@ LFORTRAN_API float _lfortran_satanh(float x);
 LFORTRAN_API double _lfortran_datanh(double x);
 LFORTRAN_API float_complex_t _lfortran_catanh(float_complex_t x);
 LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x);
+LFORTRAN_API bool _lpython_str_compare_eq(char** s1, char** s2);
+LFORTRAN_API bool _lpython_str_compare_noteq(char** s1, char** s2);
+LFORTRAN_API bool _lpython_str_compare_gt(char** s1, char** s2);
+LFORTRAN_API bool _lpython_str_compare_lte(char** s1, char** s2);
+LFORTRAN_API bool _lpython_str_compare_lt(char** s1, char** s2);
+LFORTRAN_API bool _lpython_str_compare_gte(char** s1, char** s2);
 LFORTRAN_API void _lfortran_strrepeat(char** s, int32_t n, char** dest);
 LFORTRAN_API void _lfortran_strcat(char** s1, char** s2, char** dest);
 LFORTRAN_API int _lfortran_str_len(char** s);