diff --git a/src/solvers/smt2_incremental/construct_value_expr_from_smt.cpp b/src/solvers/smt2_incremental/construct_value_expr_from_smt.cpp index 560df31c7c2..409f1e9997c 100644 --- a/src/solvers/smt2_incremental/construct_value_expr_from_smt.cpp +++ b/src/solvers/smt2_incremental/construct_value_expr_from_smt.cpp @@ -87,6 +87,18 @@ class value_expr_from_smt_factoryt : public smt_term_const_downcast_visitort "Unexpected conversion of function application to value expression."); } + void visit(const smt_forall_termt &forall) override + { + INVARIANT( + false, "Unexpected conversion of forall quantifier to value expression."); + } + + void visit(const smt_exists_termt &exists) override + { + INVARIANT( + false, "Unexpected conversion of exists quantifier to value expression."); + } + public: /// \brief This function is complete the external interface to this class. All /// construction of this class and construction of expressions should be diff --git a/src/solvers/smt2_incremental/smt_terms.cpp b/src/solvers/smt2_incremental/smt_terms.cpp index e3022cb3160..1c81c0f9f4d 100644 --- a/src/solvers/smt2_incremental/smt_terms.cpp +++ b/src/solvers/smt2_incremental/smt_terms.cpp @@ -151,6 +151,74 @@ smt_function_application_termt::arguments() const }); } +smt_forall_termt::smt_forall_termt( + std::vector bound_variables, + smt_termt predicate) + : smt_termt{ID_smt_forall_term, smt_bool_sortt{}} +{ + INVARIANT( + !bound_variables.empty(), + "A forall term should bind at least one variable."); + std::transform( + std::make_move_iterator(bound_variables.begin()), + std::make_move_iterator(bound_variables.end()), + std::back_inserter(get_sub()), + [](smt_identifier_termt &&bound_variable) { + return irept{std::move(bound_variable)}; + }); + INVARIANT( + predicate.get_sort().cast(), + "Predicate of forall quantifier is expected to have bool sort."); + set(ID_body, std::move(predicate)); +} + +const smt_termt &smt_forall_termt::predicate() const +{ + return static_cast(find(ID_body)); +} + +std::vector> +smt_forall_termt::bound_variables() const +{ + return make_range(get_sub()).map([](const irept &variable) { + return std::cref(static_cast(variable)); + }); +} + +smt_exists_termt::smt_exists_termt( + std::vector bound_variables, + smt_termt predicate) + : smt_termt{ID_smt_exists_term, smt_bool_sortt{}} +{ + INVARIANT( + !bound_variables.empty(), + "A exists term should bind at least one variable."); + std::transform( + std::make_move_iterator(bound_variables.begin()), + std::make_move_iterator(bound_variables.end()), + std::back_inserter(get_sub()), + [](smt_identifier_termt &&bound_variable) { + return irept{std::move(bound_variable)}; + }); + INVARIANT( + predicate.get_sort().cast(), + "Predicate of exists quantifier is expected to have bool sort."); + set(ID_body, std::move(predicate)); +} + +const smt_termt &smt_exists_termt::predicate() const +{ + return static_cast(find(ID_body)); +} + +std::vector> +smt_exists_termt::bound_variables() const +{ + return make_range(get_sub()).map([](const irept &variable) { + return std::cref(static_cast(variable)); + }); +} + template void accept(const smt_termt &term, const irep_idt &id, visitort &&visitor) { diff --git a/src/solvers/smt2_incremental/smt_terms.def b/src/solvers/smt2_incremental/smt_terms.def index fbd46311e03..33fc99248e9 100644 --- a/src/solvers/smt2_incremental/smt_terms.def +++ b/src/solvers/smt2_incremental/smt_terms.def @@ -10,3 +10,5 @@ TERM_ID(bool_literal) TERM_ID(identifier) TERM_ID(bit_vector_constant) TERM_ID(function_application) +TERM_ID(forall) +TERM_ID(exists) diff --git a/src/solvers/smt2_incremental/smt_terms.h b/src/solvers/smt2_incremental/smt_terms.h index b8cfc4b4fa8..53faa567c5d 100644 --- a/src/solvers/smt2_incremental/smt_terms.h +++ b/src/solvers/smt2_incremental/smt_terms.h @@ -207,6 +207,28 @@ class smt_function_application_termt : public smt_termt }; }; +class smt_forall_termt : public smt_termt +{ +public: + smt_forall_termt( + std::vector bound_variables, + smt_termt predicate); + std::vector> + bound_variables() const; + const smt_termt &predicate() const; +}; + +class smt_exists_termt : public smt_termt +{ +public: + smt_exists_termt( + std::vector bound_variables, + smt_termt predicate); + std::vector> + bound_variables() const; + const smt_termt &predicate() const; +}; + class smt_term_const_downcast_visitort { public: diff --git a/src/solvers/smt2_incremental/smt_to_smt2_string.cpp b/src/solvers/smt2_incremental/smt_to_smt2_string.cpp index 8cfb44c31f3..246baff557d 100644 --- a/src/solvers/smt2_incremental/smt_to_smt2_string.cpp +++ b/src/solvers/smt2_incremental/smt_to_smt2_string.cpp @@ -96,6 +96,11 @@ std::string smt_to_smt2_string(const smt_sortt &sort) return ss.str(); } +struct sorted_variablest final +{ + std::vector> identifiers; +}; + /// \note The printing algorithm in the `smt_term_to_string_convertert` class is /// implemented using an explicit `std::stack` rather than using recursion /// and the call stack. This is done in order to ensure we can print smt terms @@ -124,6 +129,7 @@ class smt_term_to_string_convertert : private smt_term_const_downcast_visitort template output_functiont make_output_function( const std::vector> &output); + output_functiont make_output_function(const sorted_variablest &output); /// \brief Single argument version of `push_outputs`. template @@ -153,6 +159,8 @@ class smt_term_to_string_convertert : private smt_term_const_downcast_visitort void visit(const smt_bit_vector_constant_termt &bit_vector_constant) override; void visit(const smt_function_application_termt &function_application) override; + void visit(const smt_forall_termt &forall) override; + void visit(const smt_exists_termt &exists) override; public: /// \brief This function is complete the external interface to this class. All @@ -199,6 +207,25 @@ smt_term_to_string_convertert::make_output_function( }; } +smt_term_to_string_convertert::output_functiont +smt_term_to_string_convertert::make_output_function( + const sorted_variablest &output) +{ + return [=](std::ostream &os) { + const auto push_sorted_variable = + [&](const smt_identifier_termt &identifier) { + push_outputs("(", identifier, " ", identifier.get_sort(), ")"); + }; + for(const auto &bound_variable : + make_range(output.identifiers.rbegin(), --output.identifiers.rend())) + { + push_sorted_variable(bound_variable); + push_output(" "); + } + push_sorted_variable(output.identifiers.front()); + }; +} + template void smt_term_to_string_convertert::push_output(outputt &&output) { @@ -255,6 +282,20 @@ void smt_term_to_string_convertert::visit( push_outputs("(", id, std::move(arguments), ")"); } +void smt_term_to_string_convertert::visit(const smt_forall_termt &forall) +{ + sorted_variablest bound_variables{forall.bound_variables()}; + auto predicate = forall.predicate(); + push_outputs("(forall (", bound_variables, ") ", std::move(predicate), ")"); +} + +void smt_term_to_string_convertert::visit(const smt_exists_termt &exists) +{ + sorted_variablest bound_variables{exists.bound_variables()}; + auto predicate = exists.predicate(); + push_outputs("(exists (", bound_variables, ") ", std::move(predicate), ")"); +} + std::ostream & smt_term_to_string_convertert::convert(std::ostream &os, const smt_termt &term) { diff --git a/unit/solvers/smt2_incremental/construct_value_expr_from_smt.cpp b/unit/solvers/smt2_incremental/construct_value_expr_from_smt.cpp index d0a570277ce..9f151531441 100644 --- a/unit/solvers/smt2_incremental/construct_value_expr_from_smt.cpp +++ b/unit/solvers/smt2_incremental/construct_value_expr_from_smt.cpp @@ -127,6 +127,18 @@ TEST_CASE( smt_core_theoryt::make_not(smt_bool_literal_termt{true}), unsignedbv_typet{16}, "Unexpected conversion of function application to value expression."}, + rowt{ + smt_forall_termt{ + {smt_identifier_termt{"i", smt_bool_sortt{}}}, + smt_bool_literal_termt{true}}, + bool_typet{}, + "Unexpected conversion of forall quantifier to value expression."}, + rowt{ + smt_exists_termt{ + {smt_identifier_termt{"j", smt_bool_sortt{}}}, + smt_bool_literal_termt{true}}, + bool_typet{}, + "Unexpected conversion of exists quantifier to value expression."}, rowt{ smt_bit_vector_constant_termt{0, 16}, pointer_typet{unsigned_int_type(), 0}, diff --git a/unit/solvers/smt2_incremental/smt_terms.cpp b/unit/solvers/smt2_incremental/smt_terms.cpp index 0405b5fb109..847cadd987c 100644 --- a/unit/solvers/smt2_incremental/smt_terms.cpp +++ b/unit/solvers/smt2_incremental/smt_terms.cpp @@ -1,11 +1,11 @@ // Author: Diffblue Ltd. -#include +#include #include #include - -#include +#include +#include #include @@ -85,6 +85,82 @@ TEST_CASE("smt_termt equality.", "[core][smt2_incremental]") smt_bit_vector_constant_termt{12, 8}); } +template +std::string term_description(); + +template <> +std::string term_description() +{ + return "forall"; +} + +template <> +std::string term_description() +{ + return "exists"; +} + +TEMPLATE_TEST_CASE( + "smt quantifier terms", + "[core][smt2_incremental]", + smt_forall_termt, + smt_exists_termt) +{ + using quantifiert = TestType; + const smt_identifier_termt i{"i", smt_bit_vector_sortt{8}}; + const smt_identifier_termt j{"j", smt_bit_vector_sortt{8}}; + SECTION("Getters") + { + SECTION("One bound variable") + { + const auto predicate = smt_core_theoryt::equal(i, i); + const quantifiert quantifier{{i}, predicate}; + CHECK(quantifier.get_sort() == smt_bool_sortt{}); + const auto variables = quantifier.bound_variables(); + CHECK(quantifier.predicate() == predicate); + REQUIRE(variables.size() == 1); + CHECK(variables[0].get() == i); + } + SECTION("Two bound variables") + { + const auto predicate = smt_core_theoryt::distinct(i, j); + const quantifiert quantifier{{i, j}, predicate}; + CHECK(quantifier.get_sort() == smt_bool_sortt{}); + const auto variables = quantifier.bound_variables(); + CHECK(quantifier.predicate() == predicate); + REQUIRE(variables.size() == 2); + CHECK(variables[0].get() == i); + CHECK(variables[1].get() == j); + } + } + SECTION("Constructor validation") + { + cbmc_invariants_should_throwt invariants_throw; + SECTION("Empty variables") + { + const auto generate_error = [&]() { + quantifiert{{}, smt_core_theoryt::equal(i, i)}; + }; + REQUIRE_THROWS_MATCHES( + generate_error(), + invariant_failedt, + invariant_failure_containing( + "A " + term_description() + + " term should bind at least one variable.")); + } + SECTION("Non bool predicate") + { + const auto generate_error = [&]() { quantifiert{{i}, i}; }; + REQUIRE_THROWS_MATCHES( + generate_error(), + invariant_failedt, + invariant_failure_containing( + "Predicate of " + term_description() + + " quantifier is expected to have bool sort.")); + } + } +} + template class term_visit_type_checkert final : public smt_term_const_downcast_visitort { @@ -139,6 +215,30 @@ class term_visit_type_checkert final : public smt_term_const_downcast_visitort unexpected_term_visited = true; } } + + void visit(const smt_forall_termt &) override + { + if(std::is_same::value) + { + expected_term_visited = true; + } + else + { + unexpected_term_visited = true; + } + } + + void visit(const smt_exists_termt &) override + { + if(std::is_same::value) + { + expected_term_visited = true; + } + else + { + unexpected_term_visited = true; + } + } }; template @@ -168,13 +268,31 @@ smt_function_application_termt make_test_term() return smt_core_theoryt::make_not(smt_bool_literal_termt{true}); } +template <> +smt_forall_termt make_test_term() +{ + const smt_identifier_termt identifier{"i", smt_bit_vector_sortt{8}}; + return smt_forall_termt{ + {identifier}, smt_core_theoryt::equal(identifier, identifier)}; +} + +template <> +smt_exists_termt make_test_term() +{ + const smt_identifier_termt identifier{"i", smt_bit_vector_sortt{8}}; + return smt_exists_termt{ + {identifier}, smt_core_theoryt::equal(identifier, identifier)}; +} + TEMPLATE_TEST_CASE( "smt_termt::accept(visitor)", "[core][smt2_incremental]", smt_bool_literal_termt, smt_identifier_termt, smt_bit_vector_constant_termt, - smt_function_application_termt) + smt_function_application_termt, + smt_forall_termt, + smt_exists_termt) { term_visit_type_checkert checker; make_test_term().accept(checker); diff --git a/unit/solvers/smt2_incremental/smt_to_smt2_string.cpp b/unit/solvers/smt2_incremental/smt_to_smt2_string.cpp index 60c74b5aad6..7f2b9575dc1 100644 --- a/unit/solvers/smt2_incremental/smt_to_smt2_string.cpp +++ b/unit/solvers/smt2_incremental/smt_to_smt2_string.cpp @@ -202,3 +202,45 @@ TEST_CASE( CHECK( smt_to_smt2_string(smt_set_logic_commandt{qf_bv}) == "(set-logic QF_BV)"); } + +TEST_CASE("SMT forall term to string conversion", "[core][smt2_incremental]") +{ + const smt_identifier_termt i{"i", smt_bit_vector_sortt{8}}; + const smt_identifier_termt j{"j", smt_bool_sortt{}}; + SECTION("One bound variable") + { + const auto predicate = smt_core_theoryt::equal(i, i); + const smt_forall_termt forall{{i}, predicate}; + CHECK(smt_to_smt2_string(forall) == "(forall ((i (_ BitVec 8))) (= i i))"); + } + SECTION("Two bound variables") + { + const auto predicate = + smt_core_theoryt::make_or(smt_core_theoryt::equal(i, i), j); + const smt_forall_termt forall{{i, j}, predicate}; + CHECK( + smt_to_smt2_string(forall) == + "(forall ((i (_ BitVec 8)) (j Bool)) (or (= i i) j))"); + } +} + +TEST_CASE("SMT exists term to string conversion", "[core][smt2_incremental]") +{ + const smt_identifier_termt i{"i", smt_bit_vector_sortt{8}}; + const smt_identifier_termt j{"j", smt_bool_sortt{}}; + SECTION("One bound variable") + { + const auto predicate = smt_core_theoryt::equal(i, i); + const smt_exists_termt exists{{i}, predicate}; + CHECK(smt_to_smt2_string(exists) == "(exists ((i (_ BitVec 8))) (= i i))"); + } + SECTION("Two bound variables") + { + const auto predicate = + smt_core_theoryt::make_or(smt_core_theoryt::equal(i, i), j); + const smt_exists_termt exists{{i, j}, predicate}; + CHECK( + smt_to_smt2_string(exists) == + "(exists ((i (_ BitVec 8)) (j Bool)) (or (= i i) j))"); + } +}