Skip to content

Implement string comparison #744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ RUN(NAME test_platform LABELS cpython llvm)
RUN(NAME test_vars_01 LABELS cpython llvm)
RUN(NAME test_version LABELS cpython llvm)
RUN(NAME vec_01 LABELS cpython llvm)
RUN(NAME test_str_comparison LABELS cpython llvm)

# Just CPython
RUN(NAME test_builtin_bin LABELS cpython)
36 changes: 36 additions & 0 deletions integration_tests/test_str_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
def f():
s1: str = "abcd"
s2: str = "abcd"
assert s1 == s2
assert s1 <= s2
assert s1 >= s2
s1 = "abcde"
assert s1 >= s2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test for the != operator would be nice too!

assert s1 > s2
s1 = "abc"
assert s1 < s2
assert s1 <= s2
s1 = "Abcd"
s2 = "abcd"
assert s1 < s2
s1 = "orange"
s2 = "apple"
assert s1 >= s2
assert s1 > s2
s1 = "albatross"
s2 = "albany"
assert s1 >= s2
assert s1 > s2
assert s1 != s2
s1 = "maple"
s2 = "morning"
assert s1 <= s2
assert s1 < s2
assert s1 != s2
s1 = "Zebra"
s2 = "ant"
assert s1 <= s2
assert s1 < s2
assert s1 != s2

f()
40 changes: 31 additions & 9 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,29 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
return CreateLoad(presult);
}

llvm::Value* lfortran_str_cmp(llvm::Value* left_arg, llvm::Value* right_arg,
std::string runtime_func_name)
{
llvm::Function *fn = module->getFunction(runtime_func_name);
if(!fn) {
llvm::FunctionType *function_type = llvm::FunctionType::get(
llvm::Type::getInt1Ty(context), {
character_type->getPointerTo(),
character_type->getPointerTo()
}, false);
fn = llvm::Function::Create(function_type,
llvm::Function::ExternalLinkage, runtime_func_name, *module);
}
llvm::AllocaInst *pleft_arg = builder->CreateAlloca(character_type,
nullptr);
builder->CreateStore(left_arg, pleft_arg);
llvm::AllocaInst *pright_arg = builder->CreateAlloca(character_type,
nullptr);
builder->CreateStore(right_arg, pright_arg);
std::vector<llvm::Value*> args = {pleft_arg, pright_arg};
return builder->CreateCall(fn, args);
}

llvm::Value* lfortran_strrepeat(llvm::Value* left_arg, llvm::Value* right_arg)
{
std::string runtime_func_name = "_lfortran_strrepeat";
Expand Down Expand Up @@ -3036,39 +3059,38 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm::Value *left = tmp;
this->visit_expr_wrapper(x.m_right, true);
llvm::Value *right = tmp;
// TODO: For now we only compare the first character of the strings
left = CreateLoad(left);
right = CreateLoad(right);
std::string fn;
switch (x.m_op) {
case (ASR::cmpopType::Eq) : {
tmp = builder->CreateICmpEQ(left, right);
fn = "_lpython_str_compare_eq";
break;
}
case (ASR::cmpopType::NotEq) : {
tmp = builder->CreateICmpNE(left, right);
fn = "_lpython_str_compare_noteq";
break;
}
case (ASR::cmpopType::Gt) : {
tmp = builder->CreateICmpUGT(left, right);
fn = "_lpython_str_compare_gt";
break;
}
case (ASR::cmpopType::GtE) : {
tmp = builder->CreateICmpUGE(left, right);
fn = "_lpython_str_compare_gte";
break;
}
case (ASR::cmpopType::Lt) : {
tmp = builder->CreateICmpULT(left, right);
fn = "_lpython_str_compare_lt";
break;
}
case (ASR::cmpopType::LtE) : {
tmp = builder->CreateICmpULE(left, right);
fn = "_lpython_str_compare_lte";
break;
}
default : {
throw CodeGenError("Comparison operator not implemented",
x.base.base.loc);
}
}
tmp = lfortran_str_cmp(left, right, fn);
}

void visit_LogicalCompare(const ASR::LogicalCompare_t &x) {
Expand Down
47 changes: 47 additions & 0 deletions src/runtime/impure/lfortran_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,53 @@ LFORTRAN_API void _lfortran_strcat(char** s1, char** s2, char** dest)
*dest = &(dest_char[0]);
}

#define MIN(x, y) ((x < y) ? x : y)

int str_compare(char **s1, char **s2)
{
int s1_len = strlen(*s1);
int s2_len = strlen(*s2);
int lim = MIN(s1_len, s2_len);
int res = 0;
int i ;
for (i = 0; i < lim; i++) {
if ((*s1)[i] != (*s2)[i]) {
res = (*s1)[i] - (*s2)[i];
break;
}
}
res = (i == lim)? s1_len - s2_len : res;
return res;
}
LFORTRAN_API bool _lpython_str_compare_eq(char **s1, char **s2)
{
return str_compare(s1, s2) == 0;
}

LFORTRAN_API bool _lpython_str_compare_noteq(char **s1, char **s2)
{
return str_compare(s1, s2) != 0;
}

LFORTRAN_API bool _lpython_str_compare_gt(char **s1, char **s2)
{
return str_compare(s1, s2) > 0;
}

LFORTRAN_API bool _lpython_str_compare_lte(char **s1, char **s2)
{
return str_compare(s1, s2) <= 0;
}

LFORTRAN_API bool _lpython_str_compare_lt(char **s1, char **s2)
{
return str_compare(s1, s2) < 0;
}

LFORTRAN_API bool _lpython_str_compare_gte(char **s1, char **s2)
{
return str_compare(s1, s2) >= 0;
}
//repeat str for n time
LFORTRAN_API void _lfortran_strrepeat(char** s, int32_t n, char** dest)
{
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/impure/lfortran_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stdarg.h>
#include <complex.h>
#include <inttypes.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -136,6 +137,12 @@ LFORTRAN_API float _lfortran_satanh(float x);
LFORTRAN_API double _lfortran_datanh(double x);
LFORTRAN_API float_complex_t _lfortran_catanh(float_complex_t x);
LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x);
LFORTRAN_API bool _lpython_str_compare_eq(char** s1, char** s2);
LFORTRAN_API bool _lpython_str_compare_noteq(char** s1, char** s2);
LFORTRAN_API bool _lpython_str_compare_gt(char** s1, char** s2);
LFORTRAN_API bool _lpython_str_compare_lte(char** s1, char** s2);
LFORTRAN_API bool _lpython_str_compare_lt(char** s1, char** s2);
LFORTRAN_API bool _lpython_str_compare_gte(char** s1, char** s2);
LFORTRAN_API void _lfortran_strrepeat(char** s, int32_t n, char** dest);
LFORTRAN_API void _lfortran_strcat(char** s1, char** s2, char** dest);
LFORTRAN_API int _lfortran_str_len(char** s);
Expand Down