Skip to content

Commit a99fcbc

Browse files
authored
Merge pull request #105 from chdb-io/feat/add-has_error-and-error_message
Add has_error and error_message methods
2 parents af60d3e + b93d44e commit a99fcbc

File tree

8 files changed

+149
-49
lines changed

8 files changed

+149
-49
lines changed

chdb/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
import os
33

44

5+
_arrow_format = set({"dataframe", "arrowtable"})
6+
_process_result_format_funs = {
7+
"dataframe" : lambda x : to_df(x),
8+
"arrowtable": lambda x : to_arrowTable(x)
9+
}
10+
511
# If any UDF is defined, the path of the UDF will be set to this variable
612
# and the path will be deleted when the process exits
713
# UDF config path will be f"{g_udf_path}/udf_config.xml"
@@ -60,9 +66,10 @@ def query(sql, output_format="CSV", path="", udf_path=""):
6066
if udf_path != "":
6167
g_udf_path = udf_path
6268
lower_output_format = output_format.lower()
63-
if lower_output_format == "dataframe":
64-
return to_df(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path))
65-
elif lower_output_format == "arrowtable":
66-
return to_arrowTable(_chdb.query(sql, "Arrow", path=path, udf_path=g_udf_path))
67-
else:
68-
return _chdb.query(sql, output_format, path=path, udf_path=g_udf_path)
69+
result_func = _process_result_format_funs.get(lower_output_format, lambda x : x)
70+
if lower_output_format in _arrow_format:
71+
output_format = "Arrow"
72+
res = _chdb.query(sql, output_format, path=path, udf_path=g_udf_path)
73+
if res.has_error():
74+
raise Exception(res.error_message())
75+
return result_func(res)

chdb/dbapi/connections.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ def _execute_command(self, sql):
123123
if DEBUG:
124124
print("DEBUG: query:", sql)
125125
try:
126-
self._resp = self._session.query(sql, fmt="JSON").data()
126+
res = self._session.query(sql, output_format="JSON")
127+
if res.has_error():
128+
raise err.DatabaseError(res.error_message())
129+
self._resp = res.data()
127130
except Exception as error:
128131
raise err.InterfaceError("query err: %s" % error)
129132

programs/local/LocalChdb.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
extern bool inside_main = true;
77

88

9-
local_result * queryToBuffer(
9+
local_result_v2 * queryToBuffer(
1010
const std::string & queryStr,
1111
const std::string & output_format = "CSV",
1212
const std::string & path = {},
@@ -51,7 +51,7 @@ local_result * queryToBuffer(
5151
for (auto & arg : argv)
5252
argv_char.push_back(const_cast<char *>(arg.c_str()));
5353

54-
return query_stable(argv_char.size(), argv_char.data());
54+
return query_stable_v2(argv_char.size(), argv_char.data());
5555
}
5656

5757
// Pybind11 will take over the ownership of the `query_result` object
@@ -132,7 +132,7 @@ PYBIND11_MODULE(_chdb, m)
132132
.def("view", &memoryview_wrapper::view);
133133

134134
py::class_<query_result>(m, "query_result")
135-
.def(py::init<local_result *>(), py::return_value_policy::take_ownership)
135+
.def(py::init<local_result_v2 *>(), py::return_value_policy::take_ownership)
136136
.def("data", &query_result::data)
137137
.def("bytes", &query_result::bytes)
138138
.def("__str__", &query_result::str)
@@ -142,7 +142,9 @@ PYBIND11_MODULE(_chdb, m)
142142
.def("rows_read", &query_result::rows_read)
143143
.def("bytes_read", &query_result::bytes_read)
144144
.def("elapsed", &query_result::elapsed)
145-
.def("get_memview", &query_result::get_memview);
145+
.def("get_memview", &query_result::get_memview)
146+
.def("has_error", &query_result::has_error)
147+
.def("error_message", &query_result::error_message);
146148

147149

148150
m.def(

programs/local/LocalChdb.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ class __attribute__((visibility("default"))) query_result;
1515
class local_result_wrapper
1616
{
1717
private:
18-
local_result * result;
18+
local_result_v2 * result;
1919

2020
public:
21-
local_result_wrapper(local_result * result) : result(result) { }
21+
local_result_wrapper(local_result_v2 * result) : result(result) { }
2222
~local_result_wrapper()
2323
{
24-
free_result(result);
24+
free_result_v2(result);
2525
delete result;
2626
}
2727
char * data()
@@ -81,6 +81,22 @@ class local_result_wrapper
8181
}
8282
return result->elapsed;
8383
}
84+
bool has_error()
85+
{
86+
if (result == nullptr)
87+
{
88+
return false;
89+
}
90+
return result->error_message != nullptr;
91+
}
92+
py::str error_message()
93+
{
94+
if (has_error())
95+
{
96+
return py::str(result->error_message);
97+
}
98+
return py::str();
99+
}
84100
};
85101

86102
class query_result
@@ -89,7 +105,7 @@ class query_result
89105
std::shared_ptr<local_result_wrapper> result_wrapper;
90106

91107
public:
92-
query_result(local_result * result) : result_wrapper(std::make_shared<local_result_wrapper>(result)) { }
108+
query_result(local_result_v2 * result) : result_wrapper(std::make_shared<local_result_wrapper>(result)) { }
93109
~query_result() { }
94110
char * data() { return result_wrapper->data(); }
95111
py::bytes bytes() { return result_wrapper->bytes(); }
@@ -98,6 +114,8 @@ class query_result
98114
size_t rows_read() { return result_wrapper->rows_read(); }
99115
size_t bytes_read() { return result_wrapper->bytes_read(); }
100116
double elapsed() { return result_wrapper->elapsed(); }
117+
bool has_error() { return result_wrapper->has_error(); }
118+
py::str error_message() { return result_wrapper->error_message(); }
101119
memoryview_wrapper * get_memview();
102120
};
103121

programs/local/LocalServer.cpp

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <Interpreters/ProcessList.h>
2222
#include <Interpreters/loadMetadata.h>
2323
#include <base/getFQDNOrHostName.h>
24-
#include <Common/scope_guard_safe.h>
2524
#include <Interpreters/Session.h>
2625
#include <Access/AccessControl.h>
2726
#include <Common/Exception.h>
@@ -34,15 +33,12 @@
3433
#include <Common/quoteString.h>
3534
#include <Common/randomSeed.h>
3635
#include <Common/ThreadPool.h>
37-
#include <Loggers/Loggers.h>
3836
#include <Loggers/OwnFormattingChannel.h>
3937
#include <Loggers/OwnPatternFormatter.h>
4038
#include <IO/ReadBufferFromFile.h>
41-
#include <IO/ReadBufferFromString.h>
4239
#include <IO/WriteBufferFromFileDescriptor.h>
4340
#include <IO/UseSSL.h>
4441
#include <IO/SharedThreadPools.h>
45-
#include <Parsers/IAST.h>
4642
#include <Parsers/ASTInsertQuery.h>
4743
#include <Common/ErrorHandlers.h>
4844
#include <Functions/UserDefined/IUserDefinedSQLObjectsLoader.h>
@@ -59,6 +55,7 @@
5955
#include <base/argsToConfig.h>
6056
#include <filesystem>
6157
#include <fstream>
58+
#include <memory>
6259

6360
#include "config.h"
6461

@@ -563,14 +560,14 @@ catch (const DB::Exception & e)
563560
cleanup();
564561

565562
bool need_print_stack_trace = config().getBool("stacktrace", false);
566-
std::cerr << getExceptionMessage(e, need_print_stack_trace, true) << std::endl;
563+
error_message_oss << getExceptionMessage(e, need_print_stack_trace, true);
567564
return e.code() ? e.code() : -1;
568565
}
569566
catch (...)
570567
{
571568
cleanup();
572569

573-
std::cerr << getCurrentExceptionMessage(false) << std::endl;
570+
error_message_oss << getCurrentExceptionMessage(false);
574571
return getCurrentExceptionCode();
575572
}
576573

@@ -1024,12 +1021,26 @@ void LocalServer::readArguments(int argc, char ** argv, Arguments & common_argum
10241021
class query_result_
10251022
{
10261023
public:
1027-
uint64_t rows;
1028-
uint64_t bytes;
1029-
double elapsed;
1030-
std::vector<char> * buf;
1024+
explicit query_result_(std::vector<char>* buf, uint64_t rows,
1025+
uint64_t bytes, double elapsed):
1026+
rows_(rows), bytes_(bytes), elapsed_(elapsed),
1027+
buf_(buf) { }
1028+
1029+
explicit query_result_(std::string&& error_msg): error_msg_(error_msg) { }
1030+
1031+
std::string string()
1032+
{
1033+
return std::string(buf_->begin(), buf_->end());
1034+
}
1035+
1036+
uint64_t rows_;
1037+
uint64_t bytes_;
1038+
double elapsed_;
1039+
std::vector<char> * buf_;
1040+
std::string error_msg_;
10311041
};
10321042

1043+
10331044
std::unique_ptr<query_result_> pyEntryClickHouseLocal(int argc, char ** argv)
10341045
{
10351046
try
@@ -1039,18 +1050,13 @@ std::unique_ptr<query_result_> pyEntryClickHouseLocal(int argc, char ** argv)
10391050
int ret = app.run();
10401051
if (ret == 0)
10411052
{
1042-
auto result = std::make_unique<query_result_>();
1043-
result->buf = app.getQueryOutputVector();
1044-
result->rows = app.getProcessedRows();
1045-
result->bytes = app.getProcessedBytes();
1046-
result->elapsed = app.getElapsedTime();
1047-
1048-
// std::cerr << std::string(out->begin(), out->end()) << std::endl;
1049-
return result;
1050-
}
1051-
else
1052-
{
1053-
return nullptr;
1053+
return std::make_unique<query_result_>(
1054+
app.getQueryOutputVector(),
1055+
app.getProcessedRows(),
1056+
app.getProcessedBytes(),
1057+
app.getElapsedTime());
1058+
} else {
1059+
return std::make_unique<query_result_>(app.get_error_msg());
10541060
}
10551061
}
10561062
catch (const DB::Exception & e)
@@ -1072,29 +1078,73 @@ std::unique_ptr<query_result_> pyEntryClickHouseLocal(int argc, char ** argv)
10721078
local_result * query_stable(int argc, char ** argv)
10731079
{
10741080
auto result = pyEntryClickHouseLocal(argc, argv);
1075-
if (!result || !result->buf)
1081+
if (result->error_msg_.empty())
10761082
{
10771083
return nullptr;
10781084
}
10791085
local_result * res = new local_result;
1080-
res->len = result->buf->size();
1081-
res->buf = result->buf->data();
1082-
res->_vec = result->buf;
1083-
res->rows_read = result->rows;
1084-
res->bytes_read = result->bytes;
1085-
res->elapsed = result->elapsed;
1086+
res->len = result->buf_->size();
1087+
res->buf = result->buf_->data();
1088+
res->_vec = result->buf_;
1089+
res->rows_read = result->rows_;
1090+
res->bytes_read = result->bytes_;
1091+
res->elapsed = result->elapsed_;
10861092
return res;
10871093
}
10881094

10891095
void free_result(local_result * result)
10901096
{
1091-
if (!result || !result->_vec)
1097+
if (!result)
1098+
{
1099+
return;
1100+
}
1101+
if (result->_vec)
1102+
{
1103+
std::vector<char> * vec = reinterpret_cast<std::vector<char> *>(result->_vec);
1104+
delete vec;
1105+
result->_vec = nullptr;
1106+
}
1107+
}
1108+
1109+
local_result_v2 * query_stable_v2(int argc, char ** argv)
1110+
{
1111+
auto result = pyEntryClickHouseLocal(argc, argv);
1112+
local_result_v2 * res = new local_result_v2;
1113+
if (!result->error_msg_.empty())
1114+
{
1115+
res->error_message = new char[result->error_msg_.size() + 1];
1116+
memcpy(res->error_message, result->error_msg_.c_str(), result->error_msg_.size() + 1);
1117+
res->_vec = nullptr;
1118+
res->buf = nullptr;
1119+
} else {
1120+
res->len = result->buf_->size();
1121+
res->buf = result->buf_->data();
1122+
res->_vec = result->buf_;
1123+
res->rows_read = result->rows_;
1124+
res->bytes_read = result->bytes_;
1125+
res->elapsed = result->elapsed_;
1126+
res->error_message = nullptr;
1127+
}
1128+
return res;
1129+
}
1130+
1131+
void free_result_v2(local_result_v2 * result)
1132+
{
1133+
if (!result)
10921134
{
10931135
return;
10941136
}
1095-
std::vector<char> * vec = reinterpret_cast<std::vector<char> *>(result->_vec);
1096-
delete vec;
1097-
result->_vec = nullptr;
1137+
if (result->_vec)
1138+
{
1139+
std::vector<char> * vec = reinterpret_cast<std::vector<char> *>(result->_vec);
1140+
delete vec;
1141+
result->_vec = nullptr;
1142+
}
1143+
if (result->error_message)
1144+
{
1145+
delete[] result->error_message;
1146+
result->error_message = nullptr;
1147+
}
10981148
}
10991149

11001150
/**
@@ -1125,7 +1175,7 @@ int mainEntryClickHouseLocal(int argc, char ** argv)
11251175
auto result = pyEntryClickHouseLocal(argc, argv);
11261176
if (result)
11271177
{
1128-
std::cout << std::string(result->buf->begin(), result->buf->end()) << std::endl;
1178+
std::cout << result->string() << std::endl;
11291179
return 0;
11301180
}
11311181
else

programs/local/chdb.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,23 @@ struct CHDB_EXPORT local_result
2020
uint64_t bytes_read;
2121
};
2222

23+
struct CHDB_EXPORT local_result_v2
24+
{
25+
char * buf;
26+
size_t len;
27+
void * _vec; // std::vector<char> *, for freeing
28+
double elapsed;
29+
uint64_t rows_read;
30+
uint64_t bytes_read;
31+
char * error_message;
32+
};
33+
2334
CHDB_EXPORT struct local_result * query_stable(int argc, char ** argv);
2435
CHDB_EXPORT void free_result(struct local_result * result);
2536

37+
CHDB_EXPORT struct local_result_v2 * query_stable_v2(int argc, char ** argv);
38+
CHDB_EXPORT void free_result_v2(struct local_result_v2 * result);
39+
2640
#ifdef __cplusplus
2741
}
2842
#endif

src/Client/ClientBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class ClientBase : public Poco::Util::Application, public IHints<2>
8383
size_t getProcessedRows() const { return processed_rows; }
8484
size_t getProcessedBytes() const { return processed_bytes; }
8585
double getElapsedTime() const { return progress_indication.elapsedSeconds(); }
86+
std::string get_error_msg() const { return error_message_oss.str(); }
8687

8788
std::vector<String> getAllRegisteredNames() const override { return cmd_options; }
8889

@@ -292,6 +293,7 @@ class ClientBase : public Poco::Util::Application, public IHints<2>
292293
size_t processed_rows = 0; /// How many rows have been read or written.
293294
size_t processed_bytes = 0; /// How many bytes have been read or written.
294295
bool print_num_processed_rows = false; /// Whether to print the number of processed rows at
296+
std::stringstream error_message_oss; /// error message stringstream
295297

296298
bool print_stack_trace = false;
297299
/// The last exception that was received from the server. Is used for the

tests/test_basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ class TestBasic(unittest.TestCase):
1010
def test_basic(self):
1111
res = chdb.query("SELECT 1", "CSV")
1212
self.assertEqual(len(res), 2) # "1\n"
13+
self.assertFalse(res.has_error())
14+
self.assertTrue(len(res.error_message()) == 0)
15+
with self.assertRaises(Exception):
16+
res = chdb.query("SELECT 1", "csv")
1317
class TestOutput(unittest.TestCase):
1418
def test_output(self):
1519
for format, output in format_output.items():

0 commit comments

Comments
 (0)