diff --git a/chdb/__init__.py b/chdb/__init__.py index b34b162494c..645b4535ac6 100644 --- a/chdb/__init__.py +++ b/chdb/__init__.py @@ -2,6 +2,12 @@ import os +_arrow_format = set({"dataframe", "arrowtable"}) +_process_result_format_funs = { + "dataframe" : lambda x : to_df(x), + "arrowtable": lambda x : to_arrowTable(x) + } + # If any UDF is defined, the path of the UDF will be set to this variable # and the path will be deleted when the process exits # UDF config path will be f"{g_udf_path}/udf_config.xml" @@ -60,9 +66,10 @@ def query(sql, output_format="CSV", path="", udf_path=""): if udf_path != "": g_udf_path = udf_path lower_output_format = output_format.lower() - if lower_output_format == "dataframe": - return to_df(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path)) - elif lower_output_format == "arrowtable": - return to_arrowTable(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path)) - else: - return _chdb.query(sql, output_format, path=path, udf_path=g_udf_path) + result_func = _process_result_format_funs.get(lower_output_format, lambda x : x) + if lower_output_format in _arrow_format: + output_format = "Arrow" + res = _chdb.query(sql, output_format, path=path, udf_path=g_udf_path) + if res.has_error(): + raise Exception(res.error_message()) + return result_func(res) diff --git a/chdb/dbapi/connections.py b/chdb/dbapi/connections.py index 0d719894fe3..2cbbc082ce9 100644 --- a/chdb/dbapi/connections.py +++ b/chdb/dbapi/connections.py @@ -123,7 +123,10 @@ def _execute_command(self, sql): if DEBUG: print("DEBUG: query:", sql) try: - self._resp = self._session.query(sql, fmt="JSON").data() + res = self._session.query(sql, output_format="JSON") + if res.has_error(): + raise err.DatabaseError(res.error_message()) + self._resp = res.data() except Exception as error: raise err.InterfaceError("query err: %s" % error) diff --git a/programs/local/LocalChdb.cpp b/programs/local/LocalChdb.cpp index e332c15dc1a..195aaa4cfb0 100644 --- a/programs/local/LocalChdb.cpp +++ b/programs/local/LocalChdb.cpp @@ -6,7 +6,7 @@ extern bool inside_main = true; -local_result * queryToBuffer( +local_result_v2 * queryToBuffer( const std::string & queryStr, const std::string & output_format = "CSV", const std::string & path = {}, @@ -51,7 +51,7 @@ local_result * queryToBuffer( for (auto & arg : argv) argv_char.push_back(const_cast(arg.c_str())); - return query_stable(argv_char.size(), argv_char.data()); + return query_stable_v2(argv_char.size(), argv_char.data()); } // Pybind11 will take over the ownership of the `query_result` object @@ -132,7 +132,7 @@ PYBIND11_MODULE(_chdb, m) .def("view", &memoryview_wrapper::view); py::class_(m, "query_result") - .def(py::init(), py::return_value_policy::take_ownership) + .def(py::init(), py::return_value_policy::take_ownership) .def("data", &query_result::data) .def("bytes", &query_result::bytes) .def("__str__", &query_result::str) @@ -142,7 +142,9 @@ PYBIND11_MODULE(_chdb, m) .def("rows_read", &query_result::rows_read) .def("bytes_read", &query_result::bytes_read) .def("elapsed", &query_result::elapsed) - .def("get_memview", &query_result::get_memview); + .def("get_memview", &query_result::get_memview) + .def("has_error", &query_result::has_error) + .def("error_message", &query_result::error_message); m.def( diff --git a/programs/local/LocalChdb.h b/programs/local/LocalChdb.h index 08c64bf5f1b..a6f3818b733 100644 --- a/programs/local/LocalChdb.h +++ b/programs/local/LocalChdb.h @@ -15,13 +15,13 @@ class __attribute__((visibility("default"))) query_result; class local_result_wrapper { private: - local_result * result; + local_result_v2 * result; public: - local_result_wrapper(local_result * result) : result(result) { } + local_result_wrapper(local_result_v2 * result) : result(result) { } ~local_result_wrapper() { - free_result(result); + free_result_v2(result); delete result; } char * data() @@ -81,6 +81,22 @@ class local_result_wrapper } return result->elapsed; } + bool has_error() + { + if (result == nullptr) + { + return false; + } + return result->error_message != nullptr; + } + py::str error_message() + { + if (has_error()) + { + return py::str(result->error_message); + } + return py::str(); + } }; class query_result @@ -89,7 +105,7 @@ class query_result std::shared_ptr result_wrapper; public: - query_result(local_result * result) : result_wrapper(std::make_shared(result)) { } + query_result(local_result_v2 * result) : result_wrapper(std::make_shared(result)) { } ~query_result() { } char * data() { return result_wrapper->data(); } py::bytes bytes() { return result_wrapper->bytes(); } @@ -98,6 +114,8 @@ class query_result size_t rows_read() { return result_wrapper->rows_read(); } size_t bytes_read() { return result_wrapper->bytes_read(); } double elapsed() { return result_wrapper->elapsed(); } + bool has_error() { return result_wrapper->has_error(); } + py::str error_message() { return result_wrapper->error_message(); } memoryview_wrapper * get_memview(); }; diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index e7922c62f71..d5beb4e762b 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -21,7 +21,6 @@ #include #include #include -#include #include #include #include @@ -34,15 +33,12 @@ #include #include #include -#include #include #include #include -#include #include #include #include -#include #include #include #include @@ -59,6 +55,7 @@ #include #include #include +#include #include "config.h" @@ -563,14 +560,14 @@ catch (const DB::Exception & e) cleanup(); bool need_print_stack_trace = config().getBool("stacktrace", false); - std::cerr << getExceptionMessage(e, need_print_stack_trace, true) << std::endl; + error_message_oss << getExceptionMessage(e, need_print_stack_trace, true); return e.code() ? e.code() : -1; } catch (...) { cleanup(); - std::cerr << getCurrentExceptionMessage(false) << std::endl; + error_message_oss << getCurrentExceptionMessage(false); return getCurrentExceptionCode(); } @@ -1024,12 +1021,26 @@ void LocalServer::readArguments(int argc, char ** argv, Arguments & common_argum class query_result_ { public: - uint64_t rows; - uint64_t bytes; - double elapsed; - std::vector * buf; + explicit query_result_(std::vector* buf, uint64_t rows, + uint64_t bytes, double elapsed): + rows_(rows), bytes_(bytes), elapsed_(elapsed), + buf_(buf) { } + + explicit query_result_(std::string&& error_msg): error_msg_(error_msg) { } + + std::string string() + { + return std::string(buf_->begin(), buf_->end()); + } + + uint64_t rows_; + uint64_t bytes_; + double elapsed_; + std::vector * buf_; + std::string error_msg_; }; + std::unique_ptr pyEntryClickHouseLocal(int argc, char ** argv) { try @@ -1039,18 +1050,13 @@ std::unique_ptr pyEntryClickHouseLocal(int argc, char ** argv) int ret = app.run(); if (ret == 0) { - auto result = std::make_unique(); - result->buf = app.getQueryOutputVector(); - result->rows = app.getProcessedRows(); - result->bytes = app.getProcessedBytes(); - result->elapsed = app.getElapsedTime(); - - // std::cerr << std::string(out->begin(), out->end()) << std::endl; - return result; - } - else - { - return nullptr; + return std::make_unique( + app.getQueryOutputVector(), + app.getProcessedRows(), + app.getProcessedBytes(), + app.getElapsedTime()); + } else { + return std::make_unique(app.get_error_msg()); } } catch (const DB::Exception & e) @@ -1072,29 +1078,73 @@ std::unique_ptr pyEntryClickHouseLocal(int argc, char ** argv) local_result * query_stable(int argc, char ** argv) { auto result = pyEntryClickHouseLocal(argc, argv); - if (!result || !result->buf) + if (result->error_msg_.empty()) { return nullptr; } local_result * res = new local_result; - res->len = result->buf->size(); - res->buf = result->buf->data(); - res->_vec = result->buf; - res->rows_read = result->rows; - res->bytes_read = result->bytes; - res->elapsed = result->elapsed; + res->len = result->buf_->size(); + res->buf = result->buf_->data(); + res->_vec = result->buf_; + res->rows_read = result->rows_; + res->bytes_read = result->bytes_; + res->elapsed = result->elapsed_; return res; } void free_result(local_result * result) { - if (!result || !result->_vec) + if (!result) + { + return; + } + if (result->_vec) + { + std::vector * vec = reinterpret_cast *>(result->_vec); + delete vec; + result->_vec = nullptr; + } +} + +local_result_v2 * query_stable_v2(int argc, char ** argv) +{ + auto result = pyEntryClickHouseLocal(argc, argv); + local_result_v2 * res = new local_result_v2; + if (!result->error_msg_.empty()) + { + res->error_message = new char[result->error_msg_.size() + 1]; + memcpy(res->error_message, result->error_msg_.c_str(), result->error_msg_.size() + 1); + res->_vec = nullptr; + res->buf = nullptr; + } else { + res->len = result->buf_->size(); + res->buf = result->buf_->data(); + res->_vec = result->buf_; + res->rows_read = result->rows_; + res->bytes_read = result->bytes_; + res->elapsed = result->elapsed_; + res->error_message = nullptr; + } + return res; +} + +void free_result_v2(local_result_v2 * result) +{ + if (!result) { return; } - std::vector * vec = reinterpret_cast *>(result->_vec); - delete vec; - result->_vec = nullptr; + if (result->_vec) + { + std::vector * vec = reinterpret_cast *>(result->_vec); + delete vec; + result->_vec = nullptr; + } + if (result->error_message) + { + delete[] result->error_message; + result->error_message = nullptr; + } } /** @@ -1125,7 +1175,7 @@ int mainEntryClickHouseLocal(int argc, char ** argv) auto result = pyEntryClickHouseLocal(argc, argv); if (result) { - std::cout << std::string(result->buf->begin(), result->buf->end()) << std::endl; + std::cout << result->string() << std::endl; return 0; } else diff --git a/programs/local/chdb.h b/programs/local/chdb.h index 48a83804591..157ddc0f06d 100644 --- a/programs/local/chdb.h +++ b/programs/local/chdb.h @@ -20,9 +20,23 @@ struct CHDB_EXPORT local_result uint64_t bytes_read; }; +struct CHDB_EXPORT local_result_v2 +{ + char * buf; + size_t len; + void * _vec; // std::vector *, for freeing + double elapsed; + uint64_t rows_read; + uint64_t bytes_read; + char * error_message; +}; + CHDB_EXPORT struct local_result * query_stable(int argc, char ** argv); CHDB_EXPORT void free_result(struct local_result * result); +CHDB_EXPORT struct local_result_v2 * query_stable_v2(int argc, char ** argv); +CHDB_EXPORT void free_result_v2(struct local_result_v2 * result); + #ifdef __cplusplus } #endif diff --git a/src/Client/ClientBase.h b/src/Client/ClientBase.h index 25e8a273fc9..f9dee8a0509 100644 --- a/src/Client/ClientBase.h +++ b/src/Client/ClientBase.h @@ -83,6 +83,7 @@ class ClientBase : public Poco::Util::Application, public IHints<2> size_t getProcessedRows() const { return processed_rows; } size_t getProcessedBytes() const { return processed_bytes; } double getElapsedTime() const { return progress_indication.elapsedSeconds(); } + std::string get_error_msg() const { return error_message_oss.str(); } std::vector getAllRegisteredNames() const override { return cmd_options; } @@ -292,6 +293,7 @@ class ClientBase : public Poco::Util::Application, public IHints<2> size_t processed_rows = 0; /// How many rows have been read or written. size_t processed_bytes = 0; /// How many bytes have been read or written. bool print_num_processed_rows = false; /// Whether to print the number of processed rows at + std::stringstream error_message_oss; /// error message stringstream bool print_stack_trace = false; /// The last exception that was received from the server. Is used for the diff --git a/tests/test_basic.py b/tests/test_basic.py index cf95634170b..d1d27d86b90 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -10,6 +10,10 @@ class TestBasic(unittest.TestCase): def test_basic(self): res = chdb.query("SELECT 1", "CSV") self.assertEqual(len(res), 2) # "1\n" + self.assertFalse(res.has_error()) + self.assertTrue(len(res.error_message()) == 0) + with self.assertRaises(Exception): + res = chdb.query("SELECT 1", "csv") class TestOutput(unittest.TestCase): def test_output(self): for format, output in format_output.items():