Skip to content

Add list.index #1703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ RUN(NAME test_list_09 LABELS cpython llvm c)
RUN(NAME test_list_10 LABELS cpython llvm c)
RUN(NAME test_list_section LABELS cpython llvm c)
RUN(NAME test_list_count LABELS cpython llvm)
RUN(NAME test_list_index LABELS cpython llvm)
RUN(NAME test_tuple_01 LABELS cpython llvm c)
RUN(NAME test_tuple_02 LABELS cpython llvm c)
RUN(NAME test_tuple_03 LABELS cpython llvm c)
Expand Down
40 changes: 40 additions & 0 deletions integration_tests/test_list_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from lpython import i32, f64

def test_list_index():
i: i32
x: list[i32] = []
y: list[str] = []
z: list[tuple[i32, str, f64]] = []

for i in range(-5, 0):
x.append(i)
assert x.index(i) == len(x)-1
x.append(i)
assert x.index(i) == len(x)-2
x.remove(i)
assert x.index(i) == len(x)-1

assert x == [-5, -4, -3, -2, -1]

for i in range(-5, 0):
x.append(i)
assert x.index(i) == 0
x.remove(i)
assert x.index(i) == len(x)-1

# str
y = ['a', 'abc', 'a', 'b', 'abc']
assert y.index('a') == 0
assert y.index('abc') == 1

# tuple, float
z = [(i32(1), 'a', f64(2.01)), (i32(-1), 'b', f64(2)), (i32(1), 'a', f64(2.02))]
assert z.index((i32(1), 'a', f64(2.01))) == 0
z.insert(0, (i32(1), 'a', f64(2)))
assert z.index((i32(1), 'a', f64(2.00))) == 0
z.append((i32(1), 'a', f64(2.00)))
assert z.index((i32(1), 'a', f64(2))) == 0
z.remove((i32(1), 'a', f64(2)))
assert z.index((i32(1), 'a', f64(2.00))) == 3

test_list_index()
41 changes: 41 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
#include <libasr/codegen/llvm_utils.h>
#include <libasr/codegen/llvm_array_utils.h>

#include <libasr/pass/intrinsic_function_registry.h>

#if LLVM_VERSION_MAJOR >= 11
# define FIXED_VECTOR_TYPE llvm::FixedVectorType
#else
Expand Down Expand Up @@ -1929,6 +1931,45 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
tmp = list_api->count(plist, item, asr_el_type, *module);
}

// Lowers a list.index call: visits the list expression and the element
// expression, then stores the result of list_api->index(...) in `tmp`.
void generate_ListIndex(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
    // Keep the caller's ptr_loads so it can be put back once both operands
    // have been visited.
    const int64_t saved_ptr_loads = ptr_loads;
    ASR::ttype_t* element_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));

    // The list operand is visited with ptr_loads set to 0.
    ptr_loads = 0;
    this->visit_expr(*m_arg);
    llvm::Value* list_value = tmp;

    // For the element, ptr_loads is 0 when its type is an LLVM struct and 1 otherwise.
    ptr_loads = !LLVM::is_llvm_struct(element_type);
    this->visit_expr_wrapper(m_ele, true);
    llvm::Value* element_value = tmp;
    ptr_loads = saved_ptr_loads;

    tmp = list_api->index(list_value, element_value, element_type, *module);
}

// Emits code for an IntrinsicFunction node. Only ListIndex with overload id 0
// is handled here; every other intrinsic or overload raises a CodeGenError.
void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
    const auto intrinsic_kind = static_cast<ASRUtils::IntrinsicFunctions>(x.m_intrinsic_id);
    if (intrinsic_kind != ASRUtils::IntrinsicFunctions::ListIndex) {
        throw CodeGenError( ASRUtils::IntrinsicFunctionRegistry::
            get_intrinsic_function_name(x.m_intrinsic_id) +
            " is not implemented by LLVM backend.", x.base.base.loc);
    }
    if (x.m_overload_id != 0) {
        throw CodeGenError("list.index only accepts one argument",
            x.base.base.loc);
    }
    // m_args[0] is the list expression, m_args[1] the element to look up.
    generate_ListIndex(x.m_args[0], x.m_args[1]);
}

void visit_ListClear(const ASR::ListClear_t& x) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
Expand Down
5 changes: 5 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2498,6 +2498,11 @@ namespace LCompilers {
return LLVM::CreateLoad(*builder, i);
}

// index() forwards all four arguments unchanged to find_item_position and
// returns its result.
// NOTE(review): behaviour for an absent item is decided inside
// find_item_position, which is not visible here — confirm before relying on it.
llvm::Value* LLVMList::index(llvm::Value* list, llvm::Value* item,
ASR::ttype_t* item_type, llvm::Module& module) {
return LLVMList::find_item_position(list, item, item_type, module);
}

llvm::Value* LLVMList::count(llvm::Value* list, llvm::Value* item,
ASR::ttype_t* item_type, llvm::Module& module) {
llvm::Type* pos_type = llvm::Type::getInt32Ty(context);
Expand Down
3 changes: 3 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ namespace LCompilers {
llvm::Value* item, ASR::ttype_t* item_type,
llvm::Module& module);

// Looks up `item` (of ASR type `item_type`) in `list`; the definition in
// llvm_utils.cpp delegates to find_item_position.
llvm::Value* index(llvm::Value* list, llvm::Value* item,
ASR::ttype_t* item_type, llvm::Module& module);

llvm::Value* count(llvm::Value* list, llvm::Value* item,
ASR::ttype_t* item_type, llvm::Module& module);

Expand Down
23 changes: 12 additions & 11 deletions src/libasr/pass/intrinsic_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,31 @@ class ReplaceIntrinsicFunction: public ASR::BaseExprReplacer<ReplaceIntrinsicFun


void replace_IntrinsicFunction(ASR::IntrinsicFunction_t* x) {
LCOMPILERS_ASSERT(x->n_args == 1)
Vec<ASR::call_arg_t> new_args;
// Replace any IntrinsicFunctions in the argument first:
{
ASR::expr_t** current_expr_copy_ = current_expr;
current_expr = &(x->m_args[0]);
replace_expr(x->m_args[0]);
new_args.reserve(al, x->n_args);
ASR::call_arg_t arg0;
arg0.m_value = *current_expr; // Use the converted arg
new_args.push_back(al, arg0);
current_expr = current_expr_copy_;
for( size_t i = 0; i < x->n_args; i++ ) {
ASR::expr_t** current_expr_copy_ = current_expr;
current_expr = &(x->m_args[i]);
replace_expr(x->m_args[i]);
ASR::call_arg_t arg0;
arg0.m_value = *current_expr; // Use the converted arg
new_args.push_back(al, arg0);
current_expr = current_expr_copy_;
}
}
// TODO: currently we always instantiate a new function.
// Rather we should reuse the old instantiation if it has
// exactly the same arguments. For that we could use the
// overload_id, and uniquely encode the argument types.
// We could maintain a mapping of type -> id and look it up.
if( !ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(x->m_intrinsic_id) ) {
throw LCompilersException("Intrinsic function not implemented");
}

ASRUtils::impl_function instantiate_function =
ASRUtils::IntrinsicFunctionRegistry::get_instantiate_function(x->m_intrinsic_id);
if( instantiate_function == nullptr ) {
return ;
}
Vec<ASR::ttype_t*> arg_types;
arg_types.reserve(al, x->n_args);
for( size_t i = 0; i < x->n_args; i++ ) {
Expand Down
76 changes: 76 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ enum class IntrinsicFunctions : int64_t {
Gamma,
LogGamma,
Abs,

ListIndex,
// ...
};

Expand Down Expand Up @@ -453,6 +455,54 @@ namespace Abs {

} // namespace Abs

namespace ListIndex {

// Compile-time evaluator for list.index. Currently a stub: all three
// parameters are ignored and nullptr is always returned.
static inline ASR::expr_t *eval_list_index(Allocator &/*al*/,
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
// TODO: To be implemented for ListConstant expression
return nullptr;
}

static inline ASR::asr_t* create_ListIndex(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
const std::function<void (const std::string &, const Location &)> err) {
if (args.size() != 2) {
// Support start and end arguments by overloading ListIndex
// intrinsic. We need 3 overload IDs,
// 0 - only list and element
// 1 - list, element and start
// 2 - list, element, start and end
// list, element and end case is not possible as list.index
// doesn't accept keyword arguments
err("For now index() takes exactly one argument", loc);
}

ASR::expr_t* list_expr = args[0];
ASR::ttype_t *type = ASRUtils::expr_type(list_expr);
ASR::ttype_t *list_type = ASR::down_cast<ASR::List_t>(type)->m_type;
ASR::ttype_t *ele_type = ASRUtils::expr_type(args[1]);
if (!ASRUtils::check_equal_type(ele_type, list_type)) {
std::string fnd = ASRUtils::get_type_code(ele_type);
std::string org = ASRUtils::get_type_code(list_type);
err(
"Type mismatch in 'index', the types must be compatible "
"(found: '" + fnd + "', expected: '" + org + "')", loc);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit worried that we have a frontend dependency here (ASRUtils::type_to_str_python).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! Yes. I would do get_type_code here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now use get_type_code: since it's just an error message, our frontend-agnostic type codes should convey the error with the same clarity.

}
Vec<ASR::expr_t*> arg_values;
arg_values.reserve(al, args.size());
for( size_t i = 0; i < args.size(); i++ ) {
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
}
ASR::expr_t* compile_time_value = eval_list_index(al, loc, arg_values);
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc,
4, nullptr, 0));
return ASR::make_IntrinsicFunction_t(al, loc,
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::ListIndex),
args.p, args.size(), 0, to_type, compile_time_value);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea that create_ListIndex is to be used when we need to check the argument types and compute compile time value (if available), but if we already have everything ready (such as in some transformation passes), we just use make_IntrinsicFunction_t directly?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. All the error checking related to an intrinsic function resides in one place. Also, there is no restriction on using make_IntrinsicFunction_t directly (it's a public API, and if you feel you can avoid calling create_ListIndex or any other create function, then feel free to do so).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later we might need to add some checks to verify() to ensure all arguments to IntrinsicFunctions are correct. For now this is good.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can easily do this by registering a verify_args function in IntrinsicFunctionRegistry. We can define LogGamma::verify_args, Abs::verify_args, etc. Then, in asr_verify.cpp, we can call these methods in the same way as we call instantiate_function in intrinsic_function.cpp.

}

} // namespace ListIndex


namespace IntrinsicFunctionRegistry {

Expand All @@ -468,13 +518,28 @@ namespace IntrinsicFunctionRegistry {
&Abs::instantiate_Abs}
};

// Maps an intrinsic's enum value (cast to int64_t) to its registered name.
// Read by get_intrinsic_function_name; an id missing from this table makes
// that function throw.
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::LogGamma),
"log_gamma"},

{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Sin),
"sin"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Cos),
"cos"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Abs),
"abs"},
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::ListIndex),
"list.index"}
};

// Maps an intrinsic's name to its pair of
// {create function, compile-time eval function}.
static const std::map<std::string,
std::pair<create_intrinsic_function,
eval_intrinsic_function>>& intrinsic_function_by_name_db = {
{"log_gamma", {&LogGamma::create_LogGamma, &LogGamma::eval_log_gamma}},
{"sin", {&Sin::create_Sin, &Sin::eval_Sin}},
{"cos", {&Cos::create_Cos, &Cos::eval_Cos}},
{"abs", {&Abs::create_Abs, &Abs::eval_Abs}},
{"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
};

static inline bool is_intrinsic_function(const std::string& name) {
Expand All @@ -490,9 +555,20 @@ namespace IntrinsicFunctionRegistry {
}

// Returns the instantiate function registered for `id`, or nullptr when
// intrinsic_function_by_id_db has no entry for it.
static inline impl_function get_instantiate_function(int64_t id) {
    auto entry = intrinsic_function_by_id_db.find(id);
    if( entry != intrinsic_function_by_id_db.end() ) {
        return entry->second;
    }
    return nullptr;
}

// Returns the name registered for intrinsic `id`.
// Throws LCompilersException when intrinsic_function_id_to_name has no entry
// for `id`.
static inline std::string get_intrinsic_function_name(int64_t id) {
    auto entry = intrinsic_function_id_to_name.find(id);
    if( entry != intrinsic_function_id_to_name.end() ) {
        return entry->second;
    }
    throw LCompilersException("IntrinsicFunction with ID " + std::to_string(id) +
        " has no name registered for it");
}

} // namespace IntrinsicFunctionRegistry

#define INTRINSIC_NAME_CASE(X) \
Expand Down
16 changes: 16 additions & 0 deletions src/lpython/semantics/python_attribute_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <libasr/string_utils.h>
#include <lpython/utils.h>
#include <lpython/semantics/semantic_exception.h>
#include <libasr/pass/intrinsic_function_registry.h>

namespace LCompilers::LPython {

Expand All @@ -22,6 +23,7 @@ struct AttributeHandler {
{"list@append", &eval_list_append},
{"list@remove", &eval_list_remove},
{"list@count", &eval_list_count},
{"list@index", &eval_list_index},
{"list@clear", &eval_list_clear},
{"list@insert", &eval_list_insert},
{"list@pop", &eval_list_pop},
Expand Down Expand Up @@ -149,6 +151,20 @@ struct AttributeHandler {
return make_ListCount_t(al, loc, s, args[0], to_type, nullptr);
}

// Attribute handler for `list.index`: builds an argument vector with the list
// expression `s` first, followed by the call arguments, and passes it to the
// create function registered under "list.index". Errors reported through the
// callback are raised as SemanticError.
static ASR::asr_t* eval_list_index(ASR::expr_t *s, Allocator &al, const Location &loc,
        Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
    const size_t n_call_args = args.size();
    Vec<ASR::expr_t*> create_args;
    create_args.reserve(al, n_call_args + 1);
    create_args.push_back(al, s);
    for (size_t k = 0; k < n_call_args; k++) {
        create_args.push_back(al, args[k]);
    }

    auto raise_semantic_error = [](const std::string &msg, const Location &err_loc) {
        throw SemanticError(msg, err_loc);
    };
    ASRUtils::create_intrinsic_function create_list_index =
        ASRUtils::IntrinsicFunctionRegistry::get_create_function("list.index");
    return create_list_index(al, loc, create_args, raise_semantic_error);
}

static ASR::asr_t* eval_list_insert(ASR::expr_t *s, Allocator &al, const Location &loc,
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
if (args.size() != 2) {
Expand Down
9 changes: 9 additions & 0 deletions tests/errors/test_list_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from lpython import i32

def test_list_index_error():
a: list[i32]
a = [1, 2, 3]
# a.index(1.0) # type mismatch
print(a.index(0)) # no error?

test_list_index_error()
13 changes: 13 additions & 0 deletions tests/reference/asr-test_list_index-6b8f30b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-test_list_index-6b8f30b",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/errors/test_list_index.py",
"infile_hash": "9c629bc9805dde8f11fc465dc8b35f2c16aef28634c7f105b2e5cfa3",
"outfile": null,
"outfile_hash": null,
"stdout": "asr-test_list_index-6b8f30b.stdout",
"stdout_hash": "d91e269b3ba3053cbecf4e08960e027063d78db443c1bc6440cdf3c6",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}
1 change: 1 addition & 0 deletions tests/reference/asr-test_list_index-6b8f30b.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(TranslationUnit (SymbolTable 1 {main_program: (Program (SymbolTable 3 {}) main_program [] []), test_list_index_error: (Function (SymbolTable 2 {a: (Variable 2 a [] Local () () Default (List (Integer 4 [])) Source Public Required .false.)}) test_list_index_error (FunctionType [] () Source Implementation () .false. .false. .false. .false. .false. [] [] .false.) [] [] [(= (Var 2 a) (ListConstant [(IntegerConstant 1 (Integer 4 [])) (IntegerConstant 2 (Integer 4 [])) (IntegerConstant 3 (Integer 4 []))] (List (Integer 4 []))) ()) (Print () [(ListIndex (Var 2 a) (IntegerConstant 0 (Integer 4 [])) (Integer 4 []) ())] () ())] () Public .false. .false.)}) [])
13 changes: 13 additions & 0 deletions tests/reference/runtime-test_list_index-0483808.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "runtime-test_list_index-0483808",
"cmd": "lpython {infile}",
"infile": "tests/errors/test_list_index.py",
"infile_hash": "991dc5eddf2579b7c6b3bc31bb07656cec73bdf91c3196a761985682",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "runtime-test_list_index-0483808.stderr",
"stderr_hash": "dd3d49b5f2f97ed8f1d27cd73ebca7a8740483660dd4ae702e2048b2",
"returncode": 1
}
1 change: 1 addition & 0 deletions tests/reference/runtime-test_list_index-0483808.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ValueError: The list does not contain the element: 0
4 changes: 4 additions & 0 deletions tests/tests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,10 @@ asr = true
filename = "errors/test_list_count.py"
asr = true

[[test]]
filename = "errors/test_list_index.py"
run = true

[[test]]
filename = "errors/test_list1.py"
asr = true
Expand Down