diff --git a/.gitignore b/.gitignore
index f91793b9e..ffd2bab41 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,5 @@ __pycache__
 confluent?kafka.egg-info
 *.pyc
 .cache
+*.log
+confluent-kafka-0.*.*
diff --git a/confluent_kafka/src/Consumer.c b/confluent_kafka/src/Consumer.c
index 722ca80c2..6f1bb6d2a 100644
--- a/confluent_kafka/src/Consumer.c
+++ b/confluent_kafka/src/Consumer.c
@@ -366,20 +366,17 @@ static PyObject *Consumer_poll (Handle *self, PyObject *args,
         static char *kws[] = { "timeout", NULL };
         rd_kafka_message_t *rkm;
         PyObject *msgobj;
+        CallState cs;
 
         if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|d", kws, &tmout))
                 return NULL;
 
-        self->callback_crashed = 0;
-        self->thread_state = PyEval_SaveThread();
+        CallState_begin(self, &cs);
 
         rkm = rd_kafka_consumer_poll(self->rk, tmout >= 0 ?
                                      (int)(tmout * 1000.0f) : -1);
 
-        PyEval_RestoreThread(self->thread_state);
-        self->thread_state = NULL;
-
-        if (self->callback_crashed)
+        if (!CallState_end(self, &cs))
                 return NULL;
 
         if (!rkm)
@@ -393,9 +390,15 @@ static PyObject *Consumer_poll (Handle *self, PyObject *args,
 
 
 static PyObject *Consumer_close (Handle *self, PyObject *ignore) {
-        self->thread_state = PyEval_SaveThread();
+        CallState cs;
+
+        CallState_begin(self, &cs);
+
         rd_kafka_consumer_close(self->rk);
-        PyEval_RestoreThread(self->thread_state);
+
+        if (!CallState_end(self, &cs))
+                return NULL;
+
         Py_RETURN_NONE;
 }
 
@@ -593,8 +596,9 @@ static void Consumer_rebalance_cb (rd_kafka_t *rk, rd_kafka_resp_err_t err,
                                    rd_kafka_topic_partition_list_t *c_parts,
                                    void *opaque) {
         Handle *self = opaque;
+        CallState *cs;
 
-        PyEval_RestoreThread(self->thread_state);
+        cs = CallState_get(self);
 
         self->u.Consumer.rebalance_assigned = 0;
 
@@ -615,8 +619,8 @@ static void Consumer_rebalance_cb (rd_kafka_t *rk, rd_kafka_resp_err_t err,
                 if (!args) {
                         cfl_PyErr_Format(RD_KAFKA_RESP_ERR__FAIL,
                                          "Unable to build callback args");
-                        self->thread_state = PyEval_SaveThread();
-                        self->callback_crashed++;
+                        CallState_crash(cs);
+                        CallState_resume(cs);
                         return;
                 }
 
@@ -630,7 +634,7 @@ static void Consumer_rebalance_cb (rd_kafka_t *rk, rd_kafka_resp_err_t err,
                 if (result)
                         Py_DECREF(result);
                 else {
-                        self->callback_crashed++;
+                        CallState_crash(cs);
                         rd_kafka_yield(rk);
                 }
         }
@@ -646,7 +650,7 @@ static void Consumer_rebalance_cb (rd_kafka_t *rk, rd_kafka_resp_err_t err,
                 rd_kafka_assign(rk, NULL);
         }
 
-        self->thread_state = PyEval_SaveThread();
+        CallState_resume(cs);
 }
 
 
@@ -655,11 +659,12 @@ static void Consumer_offset_commit_cb (rd_kafka_t *rk, rd_kafka_resp_err_t err,
                                        void *opaque) {
         Handle *self = opaque;
         PyObject *parts, *k_err, *args, *result;
+        CallState *cs;
 
         if (!self->u.Consumer.on_commit)
                 return;
 
-        PyEval_RestoreThread(self->thread_state);
+        cs = CallState_get(self);
 
         /* Insantiate error object */
         k_err = KafkaError_new_or_None(err, NULL);
@@ -675,8 +680,8 @@ static void Consumer_offset_commit_cb (rd_kafka_t *rk, rd_kafka_resp_err_t err,
         if (!args) {
                 cfl_PyErr_Format(RD_KAFKA_RESP_ERR__FAIL,
                                  "Unable to build callback args");
-                self->thread_state = PyEval_SaveThread();
-                self->callback_crashed++;
+                CallState_crash(cs);
+                CallState_resume(cs);
                 return;
         }
 
@@ -687,11 +692,11 @@ static void Consumer_offset_commit_cb (rd_kafka_t *rk, rd_kafka_resp_err_t err,
         if (result)
                 Py_DECREF(result);
         else {
-                self->callback_crashed++;
+                CallState_crash(cs);
                 rd_kafka_yield(rk);
         }
 
-        self->thread_state = PyEval_SaveThread();
+        CallState_resume(cs);
 }
 
 
diff --git a/confluent_kafka/src/Producer.c b/confluent_kafka/src/Producer.c
index bca70568f..ec61197e8 100644
--- a/confluent_kafka/src/Producer.c
+++ b/confluent_kafka/src/Producer.c
@@ -131,6 +131,7 @@ static void dr_msg_cb (rd_kafka_t *rk, const rd_kafka_message_t *rkm,
                        void *opaque) {
         struct Producer_msgstate *msgstate = rkm->_private;
         Handle *self = opaque;
+        CallState *cs;
         PyObject *args;
         PyObject *result;
         PyObject *msgobj;
@@ -138,7 +139,7 @@ static void dr_msg_cb (rd_kafka_t *rk, const rd_kafka_message_t *rkm,
         if (!msgstate)
                 return;
 
-        PyEval_RestoreThread(self->thread_state);
+        cs = CallState_get(self);
 
         if (!msgstate->dr_cb) {
                 /* No callback defined */
@@ -156,7 +157,7 @@ static void dr_msg_cb (rd_kafka_t *rk, const rd_kafka_message_t *rkm,
         if (!args) {
                 cfl_PyErr_Format(RD_KAFKA_RESP_ERR__FAIL,
                                  "Unable to build callback args");
-                self->callback_crashed++;
+                CallState_crash(cs);
                 goto done;
         }
 
@@ -166,13 +167,13 @@ static void dr_msg_cb (rd_kafka_t *rk, const rd_kafka_message_t *rkm,
         if (result)
                 Py_DECREF(result);
         else {
-                self->callback_crashed++;
+                CallState_crash(cs);
                 rd_kafka_yield(rk);
         }
 
 done:
         Producer_msgstate_destroy(msgstate);
-        self->thread_state = PyEval_SaveThread();
+        CallState_resume(cs);
 }
 
 
@@ -279,9 +280,9 @@ static PyObject *Producer_produce (Handle *self, PyObject *args,
                 return NULL;
         }
 
-        if (!dr_cb)
+        if (!dr_cb || dr_cb == Py_None)
                 dr_cb = self->u.Producer.default_dr_cb;
-        if (!partitioner_cb)
+        if (!partitioner_cb || partitioner_cb == Py_None)
                 partitioner_cb = self->u.Producer.partitioner_cb;
 
         /* Create msgstate if necessary, may return NULL if no callbacks
@@ -321,20 +322,15 @@ static PyObject *Producer_produce (Handle *self, PyObject *args,
  */
 static int Producer_poll0 (Handle *self, int tmout) {
         int r;
+        CallState cs;
 
-        self->callback_crashed = 0;
-        self->thread_state = PyEval_SaveThread();
+        CallState_begin(self, &cs);
 
         r = rd_kafka_poll(self->rk, tmout);
 
-        PyEval_RestoreThread(self->thread_state);
-        self->thread_state = NULL;
-
-        if (PyErr_CheckSignals() == -1)
-                return -1;
-
-        if (self->callback_crashed)
+        if (!CallState_end(self, &cs)) {
                 return -1;
+        }
 
         return r;
 }
diff --git a/confluent_kafka/src/confluent_kafka.c b/confluent_kafka/src/confluent_kafka.c
index 60d6de10b..9c960df33 100644
--- a/confluent_kafka/src/confluent_kafka.c
+++ b/confluent_kafka/src/confluent_kafka.c
@@ -813,8 +813,9 @@ rd_kafka_topic_partition_list_t *py_to_c_parts (PyObject *plist) {
 static void error_cb (rd_kafka_t *rk, int err, const char *reason, void *opaque) {
         Handle *h = opaque;
         PyObject *eo, *result;
+        CallState *cs;
-        PyEval_RestoreThread(h->thread_state);
+        cs = CallState_get(h);
         if (!h->error_cb) {
                 /* No callback defined */
                 goto done;
         }
@@ -827,12 +828,12 @@ static void error_cb (rd_kafka_t *rk, int err, const char *reason, void *opaque)
         if (result) {
                 Py_DECREF(result);
         } else {
-                h->callback_crashed++;
+                CallState_crash(cs);
                 rd_kafka_yield(h->rk);
         }
 
 done:
-        h->thread_state = PyEval_SaveThread();
+        CallState_resume(cs);
 }
 
 
@@ -855,6 +856,8 @@ void Handle_clear (Handle *h) {
         if (h->error_cb) {
                 Py_DECREF(h->error_cb);
         }
+
+        PyThread_delete_key(h->tlskey);
 }
 
 /**
@@ -1110,10 +1113,14 @@ rd_kafka_conf_t *common_conf_setup (rd_kafka_type_t ktype,
                         continue;
 
                 } else if (!strcmp(k, "error_cb")) {
-                        if (h->error_cb)
+                        if (h->error_cb) {
                                 Py_DECREF(h->error_cb);
-                        h->error_cb = vo;
-                        Py_INCREF(h->error_cb);
+                                h->error_cb = NULL;
+                        }
+                        if (vo != Py_None) {
+                                h->error_cb = vo;
+                                Py_INCREF(h->error_cb);
+                        }
                         Py_DECREF(ks);
                         continue;
                 }
@@ -1172,12 +1179,67 @@ rd_kafka_conf_t *common_conf_setup (rd_kafka_type_t ktype,
 
         rd_kafka_conf_set_opaque(conf, h);
 
+        h->tlskey = PyThread_create_key();
+
         return conf;
 }
 
 
+/**
+ * @brief Initialise a CallState and unlock the GIL prior to a
+ *        possibly blocking external call.
+ */
+void CallState_begin (Handle *h, CallState *cs) {
+        cs->thread_state = PyEval_SaveThread();
+        cs->crashed = 0;
+        PyThread_set_key_value(h->tlskey, cs);
+}
+
+/**
+ * @brief Relock the GIL after the external call is done.
+ * @returns 0 if a Python signal was raised or a callback crashed, else 1.
+ */
+int CallState_end (Handle *h, CallState *cs) {
+        PyThread_delete_key_value(h->tlskey);
+
+        PyEval_RestoreThread(cs->thread_state);
+
+        if (PyErr_CheckSignals() == -1 || cs->crashed)
+                return 0;
+
+        return 1;
+}
+
+
+/**
+ * @brief Get the current thread's CallState and re-lock the GIL.
+ */
+CallState *CallState_get (Handle *h) {
+        CallState *cs = PyThread_get_key_value(h->tlskey);
+        assert(cs != NULL);
+        PyEval_RestoreThread(cs->thread_state);
+        cs->thread_state = NULL;
+        return cs;
+}
+
+/**
+ * @brief Unlock the GIL to resume the blocking external call.
+ */
+void CallState_resume (CallState *cs) {
+        assert(cs->thread_state == NULL);
+        cs->thread_state = PyEval_SaveThread();
+}
+
+/**
+ * @brief Indicate that the current call crashed.
+ */
+void CallState_crash (CallState *cs) {
+        cs->crashed++;
+}
+
+
 /****************************************************************************
  *
diff --git a/confluent_kafka/src/confluent_kafka.h b/confluent_kafka/src/confluent_kafka.h
index b007dc79a..5e8eb7989 100644
--- a/confluent_kafka/src/confluent_kafka.h
+++ b/confluent_kafka/src/confluent_kafka.h
@@ -16,6 +16,7 @@
 
 #include <Python.h>
 #include <structmember.h>
+#include <pythread.h>
 
 #include <librdkafka/rdkafka.h>
 
@@ -112,9 +113,8 @@ PyObject *KafkaError_new_or_None (rd_kafka_resp_err_t err, const char *str);
 typedef struct {
         PyObject_HEAD
         rd_kafka_t *rk;
-        int callback_crashed;
-        PyThreadState *thread_state;
         PyObject *error_cb;
+        int tlskey;   /* Thread-Local-Storage key */
 
         union {
                 /**
@@ -147,6 +147,40 @@ void Handle_clear (Handle *h);
 int Handle_traverse (Handle *h, visitproc visit, void *arg);
 
 
+/**
+ * @brief Current thread's state for "blocking" calls to librdkafka.
+ */
+typedef struct {
+        PyThreadState *thread_state;
+        int crashed;   /* Callback crashed */
+} CallState;
+
+/**
+ * @brief Initialise a CallState and unlock the GIL prior to a
+ *        possibly blocking external call.
+ */
+void CallState_begin (Handle *h, CallState *cs);
+/**
+ * @brief Relock the GIL after the external call is done and remove the TLS state.
+ * @returns 0 if a Python signal was raised or a callback crashed, else 1.
+ */
+int CallState_end (Handle *h, CallState *cs);
+
+/**
+ * @brief Get the current thread's CallState and re-lock the GIL.
+ */
+CallState *CallState_get (Handle *h);
+/**
+ * @brief Unlock the GIL to resume the blocking external call.
+ */
+void CallState_resume (CallState *cs);
+
+/**
+ * @brief Indicate that the current call crashed.
+ */
+void CallState_crash (CallState *cs);
+
+
 /****************************************************************************
  *
  *
diff --git a/tests/test_threads.py b/tests/test_threads.py
new file mode 100644
index 000000000..09f646722
--- /dev/null
+++ b/tests/test_threads.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+
+from confluent_kafka import Producer, KafkaError, KafkaException
+import threading
+import time
+try:
+    from queue import Queue, Empty
+except ImportError:
+    from Queue import Queue, Empty
+
+
+class IntendedException (Exception):
+    pass
+
+
+def thread_run(myid, p, q):
+    def do_crash(err, msg):
+        raise IntendedException()
+
+    for i in range(1, 3):
+        cb = None
+        if i == 2:
+            cb = do_crash
+        p.produce('mytopic', value='hi', callback=cb)
+        t = time.time()
+        try:
+            p.flush()
+            print(myid, 'Flush took %.3f' % (time.time() - t))
+        except IntendedException:
+            print(myid, "Intentional callback crash: ok")
+            continue
+
+    print(myid, 'Done')
+    q.put(myid)
+
+
+def test_thread_safety():
+    """ Basic thread safety tests. """
+
+    q = Queue()
+    p = Producer({'socket.timeout.ms': 10,
+                  'socket.blocking.max.ms': 10,
+                  'default.topic.config': {'message.timeout.ms': 10}})
+
+    threads = list()
+    for i in range(1, 5):
+        thr = threading.Thread(target=thread_run, name=str(i), args=[i, p, q])
+        thr.start()
+        threads.append(thr)
+
+    for thr in threads:
+        thr.join()
+
+    # Count the number of threads that exited cleanly
+    cnt = 0
+    try:
+        for x in iter(q.get_nowait, None):
+            cnt += 1
+    except Empty:
+        pass
+
+    if cnt != len(threads):
+        raise Exception('Only %d/%d threads succeeded' % (cnt, len(threads)))
+
+    print('Done')
+
+
+if __name__ == '__main__':
+    test_thread_safety()
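For reference, below is a minimal sketch of the locking protocol this patch introduces, as a new blocking entry point and a librdkafka-triggered callback would use it. The names example_blocking_method() and example_cb() are hypothetical and not part of the patch; the pattern simply mirrors Producer_poll0(), error_cb() and dr_msg_cb() above.

#include "confluent_kafka.h"

/* Hypothetical Python-visible method: drop the GIL around a blocking
 * librdkafka call so other Python threads can run, and so callbacks
 * fired on this thread can find the CallState through the handle's
 * thread-local-storage key. */
static PyObject *example_blocking_method (Handle *self, PyObject *ignore) {
        CallState cs;
        int r;

        CallState_begin(self, &cs);          /* save thread state, release GIL,
                                              * publish &cs in TLS for callbacks */

        r = rd_kafka_poll(self->rk, 1000);   /* may invoke callbacks on this thread */

        if (!CallState_end(self, &cs))       /* re-acquire GIL; returns 0 on a
                                              * pending signal or callback crash */
                return NULL;                 /* a Python exception is already set */

        return PyLong_FromLong(r);
}

/* Hypothetical librdkafka callback running inside the blocking call above. */
static void example_cb (rd_kafka_t *rk, void *opaque) {
        Handle *self = opaque;
        CallState *cs = CallState_get(self); /* TLS lookup + re-lock the GIL */
        PyObject *result;

        if (!self->error_cb) {               /* nothing registered on the Python side */
                CallState_resume(cs);
                return;
        }

        /* Call into Python (the real error_cb() builds a KafkaError argument first). */
        result = PyObject_CallObject(self->error_cb, NULL);
        if (result)
                Py_DECREF(result);
        else {
                CallState_crash(cs);         /* make CallState_end() report failure */
                rd_kafka_yield(rk);          /* break out of the blocking call early */
        }

        CallState_resume(cs);                /* drop the GIL again */
}

The invariant the patch relies on is that every path through a callback ends in CallState_resume(), and every blocking entry point pairs CallState_begin() with CallState_end(), so the GIL and the per-thread TLS slot stay balanced even when a Python callback raises.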