diff --git a/regression/cbmc-incr-smt2/arrays_traces/array_read.c b/regression/cbmc-incr-smt2/arrays_traces/array_read.c new file mode 100644 index 00000000000..0509fd03d23 --- /dev/null +++ b/regression/cbmc-incr-smt2/arrays_traces/array_read.c @@ -0,0 +1,7 @@ +int main() +{ + int example_array[1025]; + unsigned int index; + __CPROVER_assume(index < 1025); + __CPROVER_assert(example_array[index] != 42, "Array condition"); +} diff --git a/regression/cbmc-incr-smt2/arrays_traces/array_read.desc b/regression/cbmc-incr-smt2/arrays_traces/array_read.desc new file mode 100644 index 00000000000..f9d3522ec42 --- /dev/null +++ b/regression/cbmc-incr-smt2/arrays_traces/array_read.desc @@ -0,0 +1,15 @@ +CORE +array_read.c +--trace +Passing problem to incremental SMT2 solving +\[main\.assertion\.1\] line \d+ Array condition: FAILURE +^Trace for main\.assertion\.1 +example_array=\{ (\d+, )*42 +^EXIT=10$ +^SIGNAL=0$ +-- +-- +Test of reading a value at a non-deterministic index of an array. +Similar to the test in ../arrays/array_read.c, but we want to assert +the value of the array that comes back in the trace to make sure we're +observing the correct values. diff --git a/regression/cbmc-incr-smt2/arrays_traces/array_write.c b/regression/cbmc-incr-smt2/arrays_traces/array_write.c new file mode 100644 index 00000000000..2796e847b7a --- /dev/null +++ b/regression/cbmc-incr-smt2/arrays_traces/array_write.c @@ -0,0 +1,9 @@ +int main() +{ + int example_array[1025]; + unsigned int index; + __CPROVER_assume(index < 1025); + example_array[index] = 42; + __CPROVER_assert(example_array[index] == 42, "Array condition"); + __CPROVER_assert(example_array[index] != 42, "Array condition"); +} diff --git a/regression/cbmc-incr-smt2/arrays_traces/array_write.desc b/regression/cbmc-incr-smt2/arrays_traces/array_write.desc new file mode 100644 index 00000000000..19828b6cb11 --- /dev/null +++ b/regression/cbmc-incr-smt2/arrays_traces/array_write.desc @@ -0,0 +1,12 @@ +CORE +array_write.c +--trace +Passing problem to incremental SMT2 solving +^Trace for main\.assertion\.2 +example_array\[\d{1,4}ll?\]=42 +^EXIT=10$ +^SIGNAL=0$ +-- +-- +Test of writing a value at a non-deterministic index of an array, then asserting +the value we expect. diff --git a/src/solvers/smt2_incremental/response_or_error.h b/src/solvers/smt2_incremental/response_or_error.h new file mode 100644 index 00000000000..48d8eab5b66 --- /dev/null +++ b/src/solvers/smt2_incremental/response_or_error.h @@ -0,0 +1,64 @@ +// Author: Diffblue Ltd. + +#ifndef CPROVER_SOLVERS_SMT2_INCREMENTAL_RESPONSE_OR_ERROR_H +#define CPROVER_SOLVERS_SMT2_INCREMENTAL_RESPONSE_OR_ERROR_H + +#include +#include + +#include +#include + +/// Holds either a valid parsed response or response sub-tree of type \tparam +/// smtt or a collection of message strings explaining why the given input was +/// not valid. +template +class response_or_errort final +{ +public: + explicit response_or_errort(smtt smt) : smt{std::move(smt)} + { + } + + explicit response_or_errort(std::string message) + : messages{std::move(message)} + { + } + + explicit response_or_errort(std::vector messages) + : messages{std::move(messages)} + { + } + + /// \brief Gets the smt response if the response is valid, or returns nullptr + /// otherwise. + const smtt *get_if_valid() const + { + INVARIANT( + smt.has_value() == messages.empty(), + "The response_or_errort class must be in the valid state or error state, " + "exclusively."); + return smt.has_value() ? &smt.value() : nullptr; + } + + /// \brief Gets the error messages if the response is invalid, or returns + /// nullptr otherwise. + const std::vector *get_if_error() const + { + INVARIANT( + smt.has_value() == messages.empty(), + "The response_or_errort class must be in the valid state or error state, " + "exclusively."); + return smt.has_value() ? nullptr : &messages; + } + +private: + // The below two fields could be a single `std::variant` field, if there was + // an implementation of it available in the cbmc repository. However at the + // time of writing we are targeting C++11, `std::variant` was introduced in + // C++17 and we have no backported version. + optionalt smt; + std::vector messages; +}; + +#endif // CPROVER_SOLVERS_SMT2_INCREMENTAL_RESPONSE_OR_ERROR_H diff --git a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp index 4d6a1241a58..9e62784a405 100644 --- a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp +++ b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp @@ -27,12 +27,13 @@ /// return a success status followed by the actual response of interest. static smt_responset get_response_to_command( smt_base_solver_processt &solver_process, - const smt_commandt &command) + const smt_commandt &command, + const std::unordered_map &identifier_table) { solver_process.send(command); - auto response = solver_process.receive_response(); + auto response = solver_process.receive_response(identifier_table); if(response.cast()) - return solver_process.receive_response(); + return solver_process.receive_response(identifier_table); else return response; } @@ -84,6 +85,7 @@ void smt2_incremental_decision_proceduret::initialize_array_elements( const array_exprt &array, const smt_identifier_termt &array_identifier) { + identifier_table.emplace(array_identifier.identifier(), array_identifier); const std::vector &elements = array.operands(); const typet &index_type = array.type().index_type(); for(std::size_t i = 0; i < elements.size(); ++i) @@ -135,13 +137,15 @@ void send_function_definition( const irep_idt &symbol_identifier, const std::unique_ptr &solver_process, std::unordered_map - &expression_identifiers) + &expression_identifiers, + std::unordered_map &identifier_table) { const smt_declare_function_commandt function{ smt_identifier_termt( symbol_identifier, convert_type_to_smt_sort(expr.type())), {}}; expression_identifiers.emplace(expr, function.identifier()); + identifier_table.emplace(symbol_identifier, function.identifier()); solver_process->send(function); } @@ -183,7 +187,8 @@ void smt2_incremental_decision_proceduret::define_dependent_functions( *symbol_expr, symbol_expr->get_identifier(), solver_process, - expression_identifiers); + expression_identifiers, + identifier_table); } else { @@ -192,6 +197,7 @@ void smt2_incremental_decision_proceduret::define_dependent_functions( const smt_define_function_commandt function{ symbol->name, {}, convert_expr_to_smt(symbol->value)}; expression_identifiers.emplace(*symbol_expr, function.identifier()); + identifier_table.emplace(identifier, function.identifier()); solver_process->send(function); } } @@ -210,7 +216,8 @@ void smt2_incremental_decision_proceduret::define_dependent_functions( *nondet_symbol, nondet_symbol->get_identifier(), solver_process, - expression_identifiers); + expression_identifiers, + identifier_table); } to_be_defined.pop(); } @@ -263,12 +270,36 @@ void smt2_incremental_decision_proceduret::ensure_handle_for_expr_defined( smt_define_function_commandt function{ "B" + std::to_string(handle_sequence()), {}, convert_expr_to_smt(expr)}; expression_handle_identifiers.emplace(expr, function.identifier()); + identifier_table.emplace( + function.identifier().identifier(), function.identifier()); solver_process->send(function); } +void smt2_incremental_decision_proceduret::define_index_identifiers( + const exprt &expr) +{ + expr.visit_pre([&](const exprt &expr_node) { + if(!can_cast_type(expr_node.type())) + return; + if(const auto with_expr = expr_try_dynamic_cast(expr_node)) + { + const auto index_expr = with_expr->where(); + const auto index_term = convert_expr_to_smt(index_expr); + const auto index_identifier = "index_" + std::to_string(index_sequence()); + const auto index_definition = + smt_define_function_commandt{index_identifier, {}, index_term}; + expression_identifiers.emplace(index_expr, index_definition.identifier()); + identifier_table.emplace(index_identifier, index_definition.identifier()); + solver_process->send( + smt_define_function_commandt{index_identifier, {}, index_term}); + } + }); +} + smt_termt smt2_incremental_decision_proceduret::convert_expr_to_smt(const exprt &expr) { + define_index_identifiers(expr); const exprt substituted = substitute_identifiers(expr, expression_identifiers); track_expression_objects(substituted, ns, object_map); @@ -310,13 +341,82 @@ static optionalt get_identifier( return {}; } +array_exprt smt2_incremental_decision_proceduret::get_expr( + const smt_termt &array, + const array_typet &type) const +{ + INVARIANT( + type.is_complete(), "Array size is required for getting array values."); + const auto size = numeric_cast(get(type.size())); + INVARIANT( + size, + "Size of array must be convertible to std::size_t for getting array value"); + std::vector elements; + const auto index_type = type.index_type(); + elements.reserve(*size); + for(std::size_t index = 0; index < size; ++index) + { + elements.push_back(get_expr( + smt_array_theoryt::select( + array, + ::convert_expr_to_smt( + from_integer(index, index_type), + object_map, + pointer_sizes_map, + object_size_function.make_application)), + type.element_type())); + } + return array_exprt{elements, type}; +} + +exprt smt2_incremental_decision_proceduret::get_expr( + const smt_termt &descriptor, + const typet &type) const +{ + const smt_get_value_commandt get_value_command{descriptor}; + const smt_responset response = get_response_to_command( + *solver_process, get_value_command, identifier_table); + const auto get_value_response = response.cast(); + if(!get_value_response) + { + throw analysis_exceptiont{ + "Expected get-value response from solver, but received - " + + response.pretty()}; + } + if(get_value_response->pairs().size() > 1) + { + throw analysis_exceptiont{ + "Expected single valuation pair in get-value response from solver, but " + "received multiple pairs - " + + response.pretty()}; + } + return construct_value_expr_from_smt( + get_value_response->pairs()[0].get().value(), type); +} + exprt smt2_incremental_decision_proceduret::get(const exprt &expr) const { log.conditional_output(log.debug(), [&](messaget::mstreamt &debug) { debug << "`get` - \n " + expr.pretty(2, 0) << messaget::eom; }); - optionalt descriptor = - get_identifier(expr, expression_handle_identifiers, expression_identifiers); + auto descriptor = [&]() -> optionalt { + if(const auto index_expr = expr_try_dynamic_cast(expr)) + { + const auto array = get_identifier( + index_expr->array(), + expression_handle_identifiers, + expression_identifiers); + const auto index = get_identifier( + index_expr->index(), + expression_handle_identifiers, + expression_identifiers); + if(!array || !index) + return {}; + return smt_array_theoryt::select(*array, *index); + } + return get_identifier( + expr, expression_handle_identifiers, expression_identifiers); + }(); if(!descriptor) { if(gather_dependent_expressions(expr).empty()) @@ -349,25 +449,13 @@ exprt smt2_incremental_decision_proceduret::get(const exprt &expr) const return expr; } } - const smt_get_value_commandt get_value_command{*descriptor}; - const smt_responset response = - get_response_to_command(*solver_process, get_value_command); - const auto get_value_response = response.cast(); - if(!get_value_response) - { - throw analysis_exceptiont{ - "Expected get-value response from solver, but received - " + - response.pretty()}; - } - if(get_value_response->pairs().size() > 1) + if(const auto array_type = type_try_dynamic_cast(expr.type())) { - throw analysis_exceptiont{ - "Expected single valuation pair in get-value response from solver, but " - "received multiple pairs - " + - response.pretty()}; + if(array_type->is_incomplete()) + return expr; + return get_expr(*descriptor, *array_type); } - return construct_value_expr_from_smt( - get_value_response->pairs()[0].get().value(), expr.type()); + return get_expr(*descriptor, expr.type()); } void smt2_incremental_decision_proceduret::print_assignment( @@ -464,8 +552,8 @@ decision_proceduret::resultt smt2_incremental_decision_proceduret::dec_solve() { ++number_of_solver_calls; define_object_sizes(); - const smt_responset result = - get_response_to_command(*solver_process, smt_check_sat_commandt{}); + const smt_responset result = get_response_to_command( + *solver_process, smt_check_sat_commandt{}, identifier_table); if(const auto check_sat_response = result.cast()) { if(check_sat_response->kind().cast()) diff --git a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.h b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.h index f89fee0b8ab..67fcbbbd919 100644 --- a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.h +++ b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.h @@ -52,6 +52,12 @@ class smt2_incremental_decision_proceduret final void push() override; void pop() override; + /// Gets the value of \p descriptor from the solver and returns the solver + /// response expressed as an exprt of type \p type. This is an implementation + /// detail of the `get(exprt)` member function. + exprt get_expr(const smt_termt &descriptor, const typet &type) const; + array_exprt get_expr(const smt_termt &array, const array_typet &type) const; + protected: // Implementation of protected decision_proceduret member function. resultt dec_solve() override; @@ -85,6 +91,7 @@ class smt2_incremental_decision_proceduret final /// \brief Add objects in \p expr to object_map if needed and convert to smt. /// \note This function is non-const because it mutates the object_map. smt_termt convert_expr_to_smt(const exprt &expr); + void define_index_identifiers(const exprt &expr); /// Sends the solver the definitions of the object sizes. void define_object_sizes(); @@ -112,7 +119,7 @@ class smt2_incremental_decision_proceduret final { return next_id++; } - } handle_sequence, array_sequence; + } handle_sequence, array_sequence, index_sequence; /// When the `handle(exprt)` member function is called, the decision procedure /// commands the SMT solver to define a new function corresponding to the /// given expression. The mapping of the expressions to the function @@ -131,6 +138,7 @@ class smt2_incremental_decision_proceduret final /// array expressions when support for them is implemented. std::unordered_map expression_identifiers; + std::unordered_map identifier_table; /// This map is used to track object related state. See documentation in /// object_tracking.h for details. smt_object_mapt object_map; diff --git a/src/solvers/smt2_incremental/smt_array_theory.cpp b/src/solvers/smt2_incremental/smt_array_theory.cpp index 94cb6befb3c..bf9c8349e3b 100644 --- a/src/solvers/smt2_incremental/smt_array_theory.cpp +++ b/src/solvers/smt2_incremental/smt_array_theory.cpp @@ -14,15 +14,24 @@ smt_sortt smt_array_theoryt::selectt::return_sort( return array.get_sort().cast()->element_sort(); } -void smt_array_theoryt::selectt::validate( +std::vector smt_array_theoryt::selectt::validation_errors( const smt_termt &array, const smt_termt &index) { const auto array_sort = array.get_sort().cast(); - INVARIANT(array_sort, "\"select\" may only select from an array."); - INVARIANT( - array_sort->index_sort() == index.get_sort(), - "Sort of arrays index must match the sort of the index supplied."); + if(!array_sort) + return {"\"select\" may only select from an array."}; + if(array_sort->index_sort() != index.get_sort()) + return {"Sort of arrays index must match the sort of the index supplied."}; + return {}; +} + +void smt_array_theoryt::selectt::validate( + const smt_termt &array, + const smt_termt &index) +{ + const auto validation_errors = selectt::validation_errors(array, index); + INVARIANT(validation_errors.empty(), validation_errors[0]); } const smt_function_application_termt::factoryt diff --git a/src/solvers/smt2_incremental/smt_array_theory.h b/src/solvers/smt2_incremental/smt_array_theory.h index d56f049d745..ed48c09df3d 100644 --- a/src/solvers/smt2_incremental/smt_array_theory.h +++ b/src/solvers/smt2_incremental/smt_array_theory.h @@ -13,6 +13,8 @@ class smt_array_theoryt static const char *identifier(); static smt_sortt return_sort(const smt_termt &array, const smt_termt &index); + static std::vector + validation_errors(const smt_termt &array, const smt_termt &index); static void validate(const smt_termt &array, const smt_termt &index); }; static const smt_function_application_termt::factoryt select; diff --git a/src/solvers/smt2_incremental/smt_response_validation.cpp b/src/solvers/smt2_incremental/smt_response_validation.cpp index 9634eb79339..c81026ede29 100644 --- a/src/solvers/smt2_incremental/smt_response_validation.cpp +++ b/src/solvers/smt2_incremental/smt_response_validation.cpp @@ -21,46 +21,13 @@ #include #include -#include - -template -response_or_errort::response_or_errort(smtt smt) : smt{std::move(smt)} -{ -} - -template -response_or_errort::response_or_errort(std::string message) - : messages{std::move(message)} -{ -} - -template -response_or_errort::response_or_errort(std::vector messages) - : messages{std::move(messages)} -{ -} +#include "smt_array_theory.h" -template -const smtt *response_or_errort::get_if_valid() const -{ - INVARIANT( - smt.has_value() == messages.empty(), - "The response_or_errort class must be in the valid state or error state, " - "exclusively."); - return smt.has_value() ? &smt.value() : nullptr; -} - -template -const std::vector *response_or_errort::get_if_error() const -{ - INVARIANT( - smt.has_value() == messages.empty(), - "The response_or_errort class must be in the valid state or error state, " - "exclusively."); - return smt.has_value() ? nullptr : &messages; -} +#include -template class response_or_errort; +static response_or_errort validate_term( + const irept &parse_tree, + const std::unordered_map &identifier_table); // Implementation detail of `collect_messages` below. template @@ -271,52 +238,62 @@ valid_smt_bit_vector_constant(const irept &parse_tree) return {}; } -static optionalt valid_term(const irept &parse_tree) +static optionalt> try_select_validation( + const irept &parse_tree, + const std::unordered_map &identifier_table) { - if(const auto smt_bool = valid_smt_bool(parse_tree)) - return {*smt_bool}; - if(const auto bit_vector_constant = valid_smt_bit_vector_constant(parse_tree)) - return {*bit_vector_constant}; - return {}; + if(parse_tree.get_sub().empty()) + return {}; + if(parse_tree.get_sub()[0].id() != "select") + return {}; + if(parse_tree.get_sub().size() != 3) + { + return response_or_errort{ + "\"select\" is expected to have 2 arguments, but " + + std::to_string(parse_tree.get_sub().size()) + + " arguments were found - \"" + print_parse_tree(parse_tree) + "\"."}; + } + const auto array = validate_term(parse_tree.get_sub()[1], identifier_table); + const auto index = validate_term(parse_tree.get_sub()[2], identifier_table); + const auto messages = collect_messages(array, index); + if(!messages.empty()) + return response_or_errort{messages}; + return {smt_array_theoryt::select.validation( + *array.get_if_valid(), *index.get_if_valid())}; } -static response_or_errort validate_term(const irept &parse_tree) +static response_or_errort validate_term( + const irept &parse_tree, + const std::unordered_map &identifier_table) { - if(const auto term = valid_term(parse_tree)) - return response_or_errort{*term}; + if(const auto smt_bool = valid_smt_bool(parse_tree)) + return response_or_errort{*smt_bool}; + if(const auto bit_vector_constant = valid_smt_bit_vector_constant(parse_tree)) + return response_or_errort{*bit_vector_constant}; + const auto find_result = identifier_table.find(parse_tree.id()); + if(find_result != identifier_table.end()) + return response_or_errort{find_result->second}; + if( + const auto select_validation = + try_select_validation(parse_tree, identifier_table)) + { + return *select_validation; + } return response_or_errort{"Unrecognised SMT term - \"" + print_parse_tree(parse_tree) + "\"."}; } -static response_or_errort -validate_smt_descriptor(const irept &parse_tree, const smt_sortt &sort) -{ - if(const auto term = valid_term(parse_tree)) - return response_or_errort{*term}; - const auto id = parse_tree.id(); - if(!id.empty()) - return response_or_errort{smt_identifier_termt{id, sort}}; - return response_or_errort{ - "Expected descriptor SMT term, found - \"" + print_parse_tree(parse_tree) + - "\"."}; -} - static response_or_errort -validate_valuation_pair(const irept &pair_parse_tree) +validate_valuation_pair( + const irept &pair_parse_tree, + const std::unordered_map &identifier_table) { PRECONDITION(pair_parse_tree.get_sub().size() == 2); const auto &descriptor = pair_parse_tree.get_sub()[0]; const auto &value = pair_parse_tree.get_sub()[1]; - const response_or_errort value_validation = validate_term(value); - if(const auto value_errors = value_validation.get_if_error()) - { - return response_or_errort{ - *value_errors}; - } - const smt_termt value_term = *value_validation.get_if_valid(); return validation_propagating( - validate_smt_descriptor(descriptor, value_term.get_sort()), - validate_term(value)); + validate_term(descriptor, identifier_table), + validate_term(value, identifier_table)); } /// \returns: A response or error in the case where the parse tree appears to be @@ -325,7 +302,9 @@ validate_valuation_pair(const irept &pair_parse_tree) /// keyword, it will be considered that the response is intended to be a /// get-value response if it is composed of a collection of one or more pairs. static optionalt> -valid_smt_get_value_response(const irept &parse_tree) +valid_smt_get_value_response( + const irept &parse_tree, + const std::unordered_map &identifier_table) { // Shape matching for does this look like a get value response? if(!parse_tree.id().empty()) @@ -338,7 +317,8 @@ valid_smt_get_value_response(const irept &parse_tree) std::vector valuation_pairs; for(const auto &pair : parse_tree.get_sub()) { - const auto pair_validation_result = validate_valuation_pair(pair); + const auto pair_validation_result = + validate_valuation_pair(pair, identifier_table); if(const auto error = pair_validation_result.get_if_error()) error_messages.insert(error_messages.end(), error->begin(), error->end()); if(const auto valid_pair = pair_validation_result.get_if_valid()) @@ -353,7 +333,9 @@ valid_smt_get_value_response(const irept &parse_tree) } } -response_or_errort validate_smt_response(const irept &parse_tree) +response_or_errort validate_smt_response( + const irept &parse_tree, + const std::unordered_map &identifier_table) { if(parse_tree.id() == "sat") return response_or_errort{ @@ -370,8 +352,12 @@ response_or_errort validate_smt_response(const irept &parse_tree) return response_or_errort{smt_success_responset{}}; if(parse_tree.id() == "unsupported") return response_or_errort{smt_unsupported_responset{}}; - if(const auto get_value_response = valid_smt_get_value_response(parse_tree)) + if( + const auto get_value_response = + valid_smt_get_value_response(parse_tree, identifier_table)) + { return *get_value_response; + } return response_or_errort{"Invalid SMT response \"" + id2string(parse_tree.id()) + "\""}; } diff --git a/src/solvers/smt2_incremental/smt_response_validation.h b/src/solvers/smt2_incremental/smt_response_validation.h index 81b3d8f411b..306e0bdaa5c 100644 --- a/src/solvers/smt2_incremental/smt_response_validation.h +++ b/src/solvers/smt2_incremental/smt_response_validation.h @@ -3,42 +3,13 @@ #ifndef CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_RESPONSE_VALIDATION_H #define CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_RESPONSE_VALIDATION_H -#include -#include #include -#include - -#include -#include - -/// Holds either a valid parsed response or response sub-tree of type \tparam -/// smtt or a collection of message strings explaining why the given input was -/// not valid. -template -class response_or_errort final -{ -public: - explicit response_or_errort(smtt smt); - explicit response_or_errort(std::string message); - explicit response_or_errort(std::vector messages); - /// \brief Gets the smt response if the response is valid, or returns nullptr - /// otherwise. - const smtt *get_if_valid() const; - /// \brief Gets the error messages if the response is invalid, or returns - /// nullptr otherwise. - const std::vector *get_if_error() const; - -private: - // The below two fields could be a single `std::variant` field, if there was - // an implementation of it available in the cbmc repository. However at the - // time of writing we are targeting C++11, `std::variant` was introduced in - // C++17 and we have no backported version. - optionalt smt; - std::vector messages; -}; +#include +#include -NODISCARD response_or_errort -validate_smt_response(const irept &parse_tree); +NODISCARD response_or_errort validate_smt_response( + const irept &parse_tree, + const std::unordered_map &identifier_table); #endif // CPROVER_SOLVERS_SMT2_INCREMENTAL_SMT_RESPONSE_VALIDATION_H diff --git a/src/solvers/smt2_incremental/smt_solver_process.cpp b/src/solvers/smt2_incremental/smt_solver_process.cpp index 1b0079ac33e..9039341653e 100644 --- a/src/solvers/smt2_incremental/smt_solver_process.cpp +++ b/src/solvers/smt2_incremental/smt_solver_process.cpp @@ -53,7 +53,8 @@ static void handle_invalid_smt( throw analysis_exceptiont{"Invalid SMT response received from solver."}; } -smt_responset smt_piped_solver_processt::receive_response() +smt_responset smt_piped_solver_processt::receive_response( + const std::unordered_map &identifier_table) { const auto response_text = process.wait_receive(); log.debug() << "Solver response - " << response_text << messaget::eom; @@ -61,7 +62,8 @@ smt_responset smt_piped_solver_processt::receive_response() const auto parse_tree = smt2irep(response_stream, log.get_message_handler()); if(!parse_tree) throw deserialization_exceptiont{"Incomplete SMT response."}; - const auto validation_result = validate_smt_response(*parse_tree); + const auto validation_result = + validate_smt_response(*parse_tree, identifier_table); if(const auto validation_errors = validation_result.get_if_error()) handle_invalid_smt(*validation_errors, log); return *validation_result.get_if_valid(); diff --git a/src/solvers/smt2_incremental/smt_solver_process.h b/src/solvers/smt2_incremental/smt_solver_process.h index 3c174b486b0..0b677688228 100644 --- a/src/solvers/smt2_incremental/smt_solver_process.h +++ b/src/solvers/smt2_incremental/smt_solver_process.h @@ -21,7 +21,9 @@ class smt_base_solver_processt /// solver process. virtual void send(const smt_commandt &command) = 0; - virtual smt_responset receive_response() = 0; + virtual smt_responset + receive_response(const std::unordered_map + &identifier_table) = 0; virtual ~smt_base_solver_processt() = default; }; @@ -41,7 +43,9 @@ class smt_piped_solver_processt : public smt_base_solver_processt void send(const smt_commandt &smt_command) override; - smt_responset receive_response() override; + smt_responset receive_response( + const std::unordered_map &identifier_table) + override; ~smt_piped_solver_processt() override = default; diff --git a/src/solvers/smt2_incremental/smt_terms.h b/src/solvers/smt2_incremental/smt_terms.h index 53faa567c5d..fe7e51f5f94 100644 --- a/src/solvers/smt2_incremental/smt_terms.h +++ b/src/solvers/smt2_incremental/smt_terms.h @@ -5,6 +5,7 @@ #include +#include #include #include #include @@ -204,6 +205,20 @@ class smt_function_application_termt : public smt_termt function.identifier(), std::move(return_sort), indices(function)}, {std::forward(arguments)...}}; } + + template + response_or_errort + validation(argument_typest &&... arguments) const + { + const auto validation_errors = function.validation_errors(arguments...); + if(!validation_errors.empty()) + return response_or_errort{validation_errors}; + auto return_sort = function.return_sort(arguments...); + return response_or_errort{smt_function_application_termt{ + smt_identifier_termt{ + function.identifier(), std::move(return_sort), indices(function)}, + {std::forward(arguments)...}}}; + } }; }; diff --git a/unit/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp b/unit/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp index a5e5c564edc..e6f5a1ba1b8 100644 --- a/unit/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp +++ b/unit/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp @@ -81,7 +81,9 @@ class smt_mock_solver_processt : public smt_base_solver_processt _send(smt_command); } - smt_responset receive_response() override + smt_responset receive_response( + const std::unordered_map &identifier_table) + override { return _receive(); } @@ -614,6 +616,27 @@ TEST_CASE( smt_assert_commandt{smt_core_theoryt::equal( foo_term, smt_array_theoryt::select(array_term, index_term))}}; REQUIRE(test.sent_commands == expected_commands); + + SECTION("Get values of array literal") + { + test.sent_commands.clear(); + test.mock_responses = { + // get-value response for array_size + smt_get_value_responset{ + {{{smt_bit_vector_constant_termt{2, 32}}, + smt_bit_vector_constant_termt{2, 32}}}}, + // get-value response for first element + smt_get_value_responset{ + {{{smt_array_theoryt::select( + array_term, smt_bit_vector_constant_termt{0, 32})}, + smt_bit_vector_constant_termt{9, 8}}}}, + // get-value response for second element + smt_get_value_responset{ + {{{smt_array_theoryt::select( + array_term, smt_bit_vector_constant_termt{1, 32})}, + smt_bit_vector_constant_termt{12, 8}}}}}; + REQUIRE(test.procedure.get(array_literal) == array_literal); + } } SECTION("array_of_exprt - all elements set to a given value") { diff --git a/unit/solvers/smt2_incremental/smt_response_validation.cpp b/unit/solvers/smt2_incremental/smt_response_validation.cpp index 0892b31ed07..5fe8344f6bb 100644 --- a/unit/solvers/smt2_incremental/smt_response_validation.cpp +++ b/unit/solvers/smt2_incremental/smt_response_validation.cpp @@ -1,10 +1,11 @@ // Author: Diffblue Ltd. -#include -#include +#include +#include #include -#include +#include +#include // Debug printer for `smt_responset`. This will be used by the catch framework // for printing in the case of failed checks / requirements. @@ -35,27 +36,27 @@ TEST_CASE("response_or_errort storage", "[core][smt2_incremental]") TEST_CASE("Validation of check-sat repsonses", "[core][smt2_incremental]") { CHECK( - *validate_smt_response(*smt2irep("sat").parsed_output).get_if_valid() == + *validate_smt_response(*smt2irep("sat").parsed_output, {}).get_if_valid() == smt_check_sat_responset{smt_sat_responset{}}); CHECK( - *validate_smt_response(*smt2irep("unsat").parsed_output).get_if_valid() == - smt_check_sat_responset{smt_unsat_responset{}}); + *validate_smt_response(*smt2irep("unsat").parsed_output, {}) + .get_if_valid() == smt_check_sat_responset{smt_unsat_responset{}}); CHECK( - *validate_smt_response(*smt2irep("unknown").parsed_output).get_if_valid() == - smt_check_sat_responset{smt_unknown_responset{}}); + *validate_smt_response(*smt2irep("unknown").parsed_output, {}) + .get_if_valid() == smt_check_sat_responset{smt_unknown_responset{}}); } TEST_CASE("Validation of SMT success response", "[core][smt2_incremental]") { CHECK( - *validate_smt_response(*smt2irep("success").parsed_output).get_if_valid() == - smt_success_responset{}); + *validate_smt_response(*smt2irep("success").parsed_output, {}) + .get_if_valid() == smt_success_responset{}); } TEST_CASE("Validation of SMT unsupported response", "[core][smt2_incremental]") { CHECK( - *validate_smt_response(*smt2irep("unsupported").parsed_output) + *validate_smt_response(*smt2irep("unsupported").parsed_output, {}) .get_if_valid() == smt_unsupported_responset{}); } @@ -66,12 +67,13 @@ TEST_CASE( SECTION("Parse tree produced is not a valid SMT-LIB version 2.6 response") { const response_or_errort validation_response = - validate_smt_response(*smt2irep("foobar").parsed_output); + validate_smt_response(*smt2irep("foobar").parsed_output, {}); CHECK( *validation_response.get_if_error() == std::vector{"Invalid SMT response \"foobar\""}); CHECK( - *validate_smt_response(*smt2irep("()").parsed_output).get_if_error() == + *validate_smt_response(*smt2irep("()").parsed_output, {}) + .get_if_error() == std::vector{"Invalid SMT response \"\""}); } } @@ -80,15 +82,17 @@ TEST_CASE("Validation of SMT error response", "[core][smt2_incremental]") { CHECK( *validate_smt_response( - *smt2irep("(error \"Test error message.\")").parsed_output) + *smt2irep("(error \"Test error message.\")").parsed_output, {}) .get_if_valid() == smt_error_responset{"Test error message."}); CHECK( - *validate_smt_response(*smt2irep("(error)").parsed_output).get_if_error() == + *validate_smt_response(*smt2irep("(error)").parsed_output, {}) + .get_if_error() == std::vector{"Error response is missing the error message."}); CHECK( *validate_smt_response( *smt2irep("(error \"Test error message1.\" \"Test error message2.\")") - .parsed_output) + .parsed_output, + {}) .get_if_error() == std::vector{"Error response has multiple error messages - \"\n" "0: error\n" @@ -98,17 +102,32 @@ TEST_CASE("Validation of SMT error response", "[core][smt2_incremental]") TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") { + const auto table_with_identifiers = + [](const std::vector> &identifiers) { + std::unordered_map table; + for(auto &identifier_pair : identifiers) + { + auto &identifier = identifier_pair.first; + auto &sort = identifier_pair.second; + table.insert({identifier, smt_identifier_termt{identifier, sort}}); + } + return table; + }; SECTION("Boolean sorted values.") { + const auto identifier_table = + table_with_identifiers({{"a", smt_bool_sortt{}}}); const response_or_errort true_response = - validate_smt_response(*smt2irep("((a true))").parsed_output); + validate_smt_response( + *smt2irep("((a true))").parsed_output, identifier_table); CHECK( *true_response.get_if_valid() == smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ smt_identifier_termt{"a", smt_bool_sortt{}}, smt_bool_literal_termt{true}}}}); const response_or_errort false_response = - validate_smt_response(*smt2irep("((a false))").parsed_output); + validate_smt_response( + *smt2irep("((a false))").parsed_output, identifier_table); CHECK( *false_response.get_if_valid() == smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ @@ -117,10 +136,13 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") } SECTION("Bit vector sorted values.") { + const auto identifier_table = + table_with_identifiers({{"a", smt_bit_vector_sortt{8}}}); SECTION("Hex value") { const response_or_errort response_255 = - validate_smt_response(*smt2irep("((a #xff))").parsed_output); + validate_smt_response( + *smt2irep("((a #xff))").parsed_output, identifier_table); CHECK( *response_255.get_if_valid() == smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ @@ -130,17 +152,70 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") SECTION("Binary value") { const response_or_errort response_42 = - validate_smt_response(*smt2irep("((a #b00101010))").parsed_output); + validate_smt_response( + *smt2irep("((a #b00101010))").parsed_output, identifier_table); CHECK( *response_42.get_if_valid() == smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ smt_identifier_termt{"a", smt_bit_vector_sortt{8}}, smt_bit_vector_constant_termt{42, 8}}}}); } + SECTION("Array values") + { + // Construction of select part of response + const smt_bit_vector_constant_termt index_term( + 0xA, smt_bit_vector_sortt(32)); + const smt_sortt value_sort(smt_bit_vector_sortt(32)); + const smt_identifier_termt array_term( + "b", smt_array_sortt(index_term.get_sort(), value_sort)); + const smt_function_application_termt select = + smt_array_theoryt::select(array_term, index_term); + // Identifier table needs to contain identifier of array + const auto identifier_table = table_with_identifiers( + {{"b", + smt_array_sortt{ + smt_bit_vector_sortt{32}, smt_bit_vector_sortt{32}}}}); + + SECTION("Valid application of smt_array_theoryt::select") + { + const response_or_errort response_get_select = + validate_smt_response( + *smt2irep("(((select |b| (_ bv10 32)) #x0000002a))").parsed_output, + identifier_table); + CHECK( + *response_get_select.get_if_valid() == + smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ + select, smt_bit_vector_constant_termt{0x2A, 32}}}}); + } + SECTION("Invalid due to selecting from non-array") + { + const response_or_errort response_get_select = + validate_smt_response( + *smt2irep("(((select (_ bv10 32) (_ bv10 32)) #x0000002a))") + .parsed_output, + identifier_table); + CHECK( + *response_get_select.get_if_error() == + std::vector{ + "\"select\" may only select from an array."}); + } + SECTION("Invalid due to selecting invalid index sort") + { + const response_or_errort response_get_select = + validate_smt_response( + *smt2irep("(((select |b| (_ bv10 16)) #x0000002a))").parsed_output, + identifier_table); + CHECK( + *response_get_select.get_if_error() == + std::vector{ + "Sort of arrays index must match the sort of the index supplied."}); + } + } SECTION("Descriptors which are bit vector constants") { const response_or_errort response_descriptor = - validate_smt_response(*smt2irep("(((_ bv255 8) #x2A))").parsed_output); + validate_smt_response( + *smt2irep("(((_ bv255 8) #x2A))").parsed_output, {}); CHECK( *response_descriptor.get_if_valid() == smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ @@ -152,60 +227,60 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") { const response_or_errort pair_value_response = validate_smt_response( - *smt2irep("(((_ bv256 8) #xff))").parsed_output); + *smt2irep("(((_ bv256 8) #xff))").parsed_output, {}); CHECK( *pair_value_response.get_if_error() == - std::vector{ - "Expected descriptor SMT term, found - \"\n" - "0: _\n" - "1: bv256\n" - "2: 8\"."}); + std::vector{"Unrecognised SMT term - \"\n" + "0: _\n" + "1: bv256\n" + "2: 8\"."}); } SECTION("Value missing bv prefix.") { const response_or_errort pair_value_response = - validate_smt_response(*smt2irep("(((_ 42 8) #xff))").parsed_output); + validate_smt_response( + *smt2irep("(((_ 42 8) #xff))").parsed_output, {}); CHECK( *pair_value_response.get_if_error() == - std::vector{ - "Expected descriptor SMT term, found - \"\n" - "0: _\n" - "1: 42\n" - "2: 8\"."}); + std::vector{"Unrecognised SMT term - \"\n" + "0: _\n" + "1: 42\n" + "2: 8\"."}); } SECTION("Hex value.") { const response_or_errort pair_value_response = validate_smt_response( - *smt2irep("(((_ bv2A 8) #xff))").parsed_output); + *smt2irep("(((_ bv2A 8) #xff))").parsed_output, {}); CHECK( *pair_value_response.get_if_error() == - std::vector{ - "Expected descriptor SMT term, found - \"\n" - "0: _\n" - "1: bv2A\n" - "2: 8\"."}); + std::vector{"Unrecognised SMT term - \"\n" + "0: _\n" + "1: bv2A\n" + "2: 8\"."}); } SECTION("Zero width.") { const response_or_errort pair_value_response = validate_smt_response( - *smt2irep("(((_ bv0 0) #xff))").parsed_output); + *smt2irep("(((_ bv0 0) #xff))").parsed_output, {}); CHECK( *pair_value_response.get_if_error() == - std::vector{ - "Expected descriptor SMT term, found - \"\n" - "0: _\n" - "1: bv0\n" - "2: 0\"."}); + std::vector{"Unrecognised SMT term - \"\n" + "0: _\n" + "1: bv0\n" + "2: 0\"."}); } } } } SECTION("Multiple valuation pairs.") { + const auto identifier_table = table_with_identifiers( + {{"a", smt_bool_sortt{}}, {"b", smt_bool_sortt{}}}); const response_or_errort two_pair_response = - validate_smt_response(*smt2irep("((a true) (b false))").parsed_output); + validate_smt_response( + *smt2irep("((a true) (b false))").parsed_output, identifier_table); CHECK( *two_pair_response.get_if_valid() == smt_get_value_responset{{smt_get_value_responset::valuation_pairt{ @@ -217,13 +292,25 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") } SECTION("Invalid terms.") { + const auto identifier_table = table_with_identifiers( + {{"a", smt_bit_vector_sortt{16}}, {"b", smt_bit_vector_sortt{16}}}); const response_or_errort empty_value_response = - validate_smt_response(*smt2irep("((a ())))").parsed_output); + validate_smt_response( + *smt2irep("((a ())))").parsed_output, identifier_table); CHECK( *empty_value_response.get_if_error() == std::vector{"Unrecognised SMT term - \"\"."}); + const response_or_errort unknown_identifier_response = + validate_smt_response( + *smt2irep("((foo bar)))").parsed_output, identifier_table); + CHECK( + *unknown_identifier_response.get_if_error() == + std::vector{ + "Unrecognised SMT term - \"foo\".", + "Unrecognised SMT term - \"bar\"."}); const response_or_errort pair_value_response = - validate_smt_response(*smt2irep("((a (#xF00D #xBAD))))").parsed_output); + validate_smt_response( + *smt2irep("((a (#xF00D #xBAD))))").parsed_output, identifier_table); CHECK( *pair_value_response.get_if_error() == std::vector{"Unrecognised SMT term - \"\n" @@ -231,7 +318,8 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") "1: #xBAD\"."}); const response_or_errort two_pair_value_response = validate_smt_response( - *smt2irep("((a (#xF00D #xBAD)) (b (#xDEAD #xFA11)))").parsed_output); + *smt2irep("((a (#xF00D #xBAD)) (b (#xDEAD #xFA11)))").parsed_output, + identifier_table); CHECK( *two_pair_value_response.get_if_error() == std::vector{"Unrecognised SMT term - \"\n" @@ -241,14 +329,15 @@ TEST_CASE("smt get-value response validation", "[core][smt2_incremental]") "0: #xDEAD\n" "1: #xFA11\"."}); const response_or_errort empty_descriptor_response = - validate_smt_response(*smt2irep("((() true))").parsed_output); + validate_smt_response(*smt2irep("((() true))").parsed_output, {}); CHECK( *empty_descriptor_response.get_if_error() == - std::vector{"Expected descriptor SMT term, found - \"\"."}); + std::vector{"Unrecognised SMT term - \"\"."}); const response_or_errort empty_pair = - validate_smt_response(*smt2irep("((() ())))").parsed_output); + validate_smt_response(*smt2irep("((() ())))").parsed_output, {}); CHECK( *empty_pair.get_if_error() == - std::vector{"Unrecognised SMT term - \"\"."}); + std::vector{ + "Unrecognised SMT term - \"\".", "Unrecognised SMT term - \"\"."}); } }