Skip to content

Commit 5f3de98

Browse files
committed
chore: force shutdown connections on takeover
Signed-off-by: Kostas Kyrimis <[email protected]>
1 parent 308ee6b commit 5f3de98

File tree

5 files changed

+104
-40
lines changed

5 files changed

+104
-40
lines changed

src/facade/dragonfly_connection.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ void LogTraffic(uint32_t id, bool has_more, absl::Span<RespExpr> resp,
272272

273273
// Write the data itself.
274274
array<iovec, 16> blobs;
275-
unsigned index = 0;
275+
uint64_t index = 0;
276276
if (next != stack_buf) {
277277
blobs[index++] = iovec{.iov_base = stack_buf, .iov_len = size_t(next - stack_buf)};
278278
}
@@ -538,7 +538,7 @@ void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
538538
return;
539539
}
540540

541-
unsigned i = 0;
541+
uint64_t i = 0;
542542
array<string_view, 4> arr;
543543
if (pub_msg.pattern.empty()) {
544544
arr[i++] = pub_msg.is_sharded ? "smessage" : "message";
@@ -609,11 +609,11 @@ void UpdateLibNameVerMap(const string& name, const string& ver, int delta) {
609609
}
610610
} // namespace
611611

612-
void Connection::Init(unsigned io_threads) {
612+
void Connection::Init(uint64_t io_threads) {
613613
CHECK(thread_queue_backpressure == nullptr);
614614
thread_queue_backpressure = new QueueBackpressure[io_threads];
615615

616-
for (unsigned i = 0; i < io_threads; ++i) {
616+
for (uint64_t i = 0; i < io_threads; ++i) {
617617
auto& qbp = thread_queue_backpressure[i];
618618
qbp.publish_buffer_limit = GetFlag(FLAGS_publish_buffer_limit);
619619
qbp.pipeline_cache_limit = GetFlag(FLAGS_request_cache_limit);
@@ -861,7 +861,7 @@ void Connection::HandleRequests() {
861861
}
862862
}
863863

864-
unsigned Connection::GetSendWaitTimeSec() const {
864+
uint64_t Connection::GetSendWaitTimeSec() const {
865865
if (reply_builder_ && reply_builder_->IsSendActive()) {
866866
return (util::fb2::ProactorBase::GetMonotonicTimeNs() - reply_builder_->GetLastSendTimeNs()) /
867867
1'000'000'000;
@@ -877,11 +877,11 @@ void Connection::RegisterBreakHook(BreakerCb breaker_cb) {
877877
pair<string, string> Connection::GetClientInfoBeforeAfterTid() const {
878878
if (!socket_) {
879879
LOG(DFATAL) << "unexpected null socket_ "
880-
<< " phase " << unsigned(phase_) << ", is_http: " << unsigned(is_http_);
880+
<< " phase " << uint64_t(phase_) << ", is_http: " << unsigned(is_http_);
881881
return {};
882882
}
883883

884-
CHECK_LT(unsigned(phase_), NUM_PHASES);
884+
CHECK_LT(uint64_t(phase_), NUM_PHASES);
885885

886886
string before;
887887
auto le = LocalBindStr();
@@ -947,7 +947,7 @@ pair<string, string> Connection::GetClientInfoBeforeAfterTid() const {
947947
return {std::move(before), std::move(after)};
948948
}
949949

950-
string Connection::GetClientInfo(unsigned thread_id) const {
950+
string Connection::GetClientInfo(uint64_t thread_id) const {
951951
auto [before, after] = GetClientInfoBeforeAfterTid();
952952
absl::StrAppend(&before, " tid=", thread_id);
953953
absl::StrAppend(&before, after);
@@ -1213,7 +1213,7 @@ void Connection::DispatchSingle(bool has_more, absl::FunctionRef<void()> invoke_
12131213
}
12141214
}
12151215

1216-
Connection::ParserStatus Connection::ParseRedis(unsigned max_busy_cycles) {
1216+
Connection::ParserStatus Connection::ParseRedis(uint64_t max_busy_cycles) {
12171217
uint32_t consumed = 0;
12181218
RedisParser::Result result = RedisParser::OK;
12191219

@@ -1827,7 +1827,7 @@ bool Connection::Migrate(util::fb2::ProactorBase* dest) {
18271827
Connection::WeakRef Connection::Borrow() {
18281828
DCHECK(self_);
18291829

1830-
return {self_, unsigned(socket_->proactor()->GetPoolIndex()), id_};
1830+
return {self_, uint64_t(socket_->proactor()->GetPoolIndex()), id_};
18311831
}
18321832

18331833
void Connection::ShutdownThreadLocal() {
@@ -2095,7 +2095,7 @@ bool Connection::IsReplySizeOverLimit() const {
20952095
}
20962096

20972097
void Connection::UpdateFromFlags() {
2098-
unsigned tid = fb2::ProactorBase::me()->GetPoolIndex();
2098+
uint64_t tid = fb2::ProactorBase::me()->GetPoolIndex();
20992099
thread_queue_backpressure[tid].pipeline_queue_max_len = GetFlag(FLAGS_pipeline_queue_limit);
21002100
thread_queue_backpressure[tid].pipeline_buffer_limit = GetFlag(FLAGS_pipeline_buffer_limit);
21012101
thread_queue_backpressure[tid].pipeline_cnd.notify_all();
@@ -2126,11 +2126,11 @@ void Connection::TrackRequestSize(bool enable) {
21262126
}
21272127
}
21282128

2129-
void Connection::EnsureMemoryBudget(unsigned tid) {
2129+
void Connection::EnsureMemoryBudget(uint64_t tid) {
21302130
thread_queue_backpressure[tid].EnsureBelowLimit();
21312131
}
21322132

2133-
Connection::WeakRef::WeakRef(const std::shared_ptr<Connection>& ptr, unsigned thread_id,
2133+
Connection::WeakRef::WeakRef(const std::shared_ptr<Connection>& ptr, uint64_t thread_id,
21342134
uint32_t client_id)
21352135
: ptr_{ptr}, last_known_thread_id_{thread_id}, client_id_{client_id} {
21362136
}

src/facade/dragonfly_connection.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class SinkReplyBuilder;
5353
// a separate dispatch queue that is processed on a separate fiber.
5454
class Connection : public util::Connection {
5555
public:
56-
static void Init(unsigned io_threads);
56+
static void Init(uint64_t io_threads);
5757
static void Shutdown();
5858
static void ShutdownThreadLocal();
5959

@@ -186,7 +186,7 @@ class Connection : public util::Connection {
186186
struct WeakRef {
187187
public:
188188
// Get residing thread of connection. Thread-safe.
189-
unsigned LastKnownThreadId() const {
189+
uint64_t LastKnownThreadId() const {
190190
return last_known_thread_id_;
191191
}
192192
// Get pointer to connection if still valid, nullptr if expired.
@@ -206,10 +206,10 @@ class Connection : public util::Connection {
206206
private:
207207
friend class Connection;
208208

209-
WeakRef(const std::shared_ptr<Connection>& ptr, unsigned thread_id, uint32_t client_id);
209+
WeakRef(const std::shared_ptr<Connection>& ptr, uint64_t thread_id, uint32_t client_id);
210210

211211
std::weak_ptr<Connection> ptr_;
212-
unsigned last_known_thread_id_;
212+
uint64_t last_known_thread_id_;
213213
uint32_t client_id_;
214214
};
215215

@@ -247,7 +247,7 @@ class Connection : public util::Connection {
247247

248248
bool IsCurrentlyDispatching() const;
249249

250-
std::string GetClientInfo(unsigned thread_id) const;
250+
std::string GetClientInfo(uint64_t thread_id) const;
251251
std::string GetClientInfo() const;
252252

253253
virtual std::string RemoteEndpointStr() const; // virtual because overwritten in test_utils
@@ -313,14 +313,14 @@ class Connection : public util::Connection {
313313
static std::vector<std::string> GetMutableFlagNames(); // Triggers UpdateFromFlags
314314

315315
static void TrackRequestSize(bool enable);
316-
static void EnsureMemoryBudget(unsigned tid);
316+
static void EnsureMemoryBudget(uint64_t tid);
317317
static void GetRequestSizeHistogramThreadLocal(std::string* hist);
318318

319-
unsigned idle_time() const {
319+
uint64_t idle_time() const {
320320
return time(nullptr) - last_interaction_;
321321
}
322322

323-
unsigned GetSendWaitTimeSec() const;
323+
uint64_t GetSendWaitTimeSec() const;
324324

325325
Phase phase() const {
326326
return phase_;
@@ -370,7 +370,7 @@ class Connection : public util::Connection {
370370
// Create new pipeline request, re-use from pool when possible.
371371
PipelineMessagePtr FromArgs(const RespVec& args);
372372

373-
ParserStatus ParseRedis(unsigned max_busy_cycles);
373+
ParserStatus ParseRedis(uint64_t max_busy_cycles);
374374
ParserStatus ParseMemcache();
375375

376376
void OnBreakCb(int32_t mask);
@@ -456,7 +456,7 @@ class Connection : public util::Connection {
456456
std::string lib_name_;
457457
std::string lib_ver_;
458458

459-
unsigned parser_error_ = 0;
459+
uint64_t parser_error_ = 0;
460460

461461
BreakerCb breaker_cb_;
462462

src/server/dflycmd.cc

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ ABSL_RETIRED_FLAG(uint32_t, allow_partial_sync_with_lsn_diff, 0,
4040
ABSL_DECLARE_FLAG(bool, info_replication_valkey_compatible);
4141
ABSL_DECLARE_FLAG(uint32_t, replication_timeout);
4242
ABSL_DECLARE_FLAG(uint32_t, shard_repl_backlog_len);
43+
ABSL_FLAG(bool, experimental_force_takeover, false,
44+
"Attempts to force takeover in case of stuck connections");
4345

4446
namespace dfly {
4547

@@ -466,6 +468,45 @@ std::optional<LSN> DflyCmd::ParseLsnVec(std::string_view last_master_lsn,
466468
return {lsn_vec[flow_id]};
467469
}
468470

471+
void DflyCmd::ForceShutdownStuckConnections(uint64_t timeout) {
472+
vector<facade::Connection::WeakRef> conn_refs;
473+
auto cb = [&](unsigned thread_index, util::Connection* conn) {
474+
facade::Connection* dcon = static_cast<facade::Connection*>(conn);
475+
LOG(INFO) << dcon->DebugInfo();
476+
// Kill Connection here
477+
facade::Connection* dfly_conn = static_cast<facade::Connection*>(conn);
478+
using Phase = facade::Connection::Phase;
479+
auto phase = dfly_conn->phase();
480+
if (dfly_conn->cntx() && dfly_conn->cntx()->replica_conn) {
481+
return;
482+
}
483+
484+
bool idle_read = phase == Phase::READ_SOCKET && dfly_conn->idle_time() > timeout;
485+
486+
bool stuck_sending = dfly_conn->IsSending() && dfly_conn->GetSendWaitTimeSec() > timeout;
487+
488+
if (idle_read || stuck_sending) {
489+
LOG(INFO) << "Connection check: " << dfly_conn->GetClientInfo()
490+
<< ", phase=" << static_cast<int>(phase) << ", idle_time=" << dfly_conn->idle_time()
491+
<< ", is_sending=" << dfly_conn->IsSending() << ", idle_read=" << idle_read
492+
<< ", stuck_sending=" << stuck_sending;
493+
}
494+
conn_refs.push_back(dfly_conn->Borrow());
495+
};
496+
497+
for (auto* listener : sf_->GetListeners()) {
498+
listener->TraverseConnections(cb);
499+
}
500+
501+
VLOG(1) << "Found " << conn_refs.size() << " stucked connections ";
502+
for (auto& ref : conn_refs) {
503+
facade::Connection* conn = ref.Get();
504+
if (conn) {
505+
conn->ShutdownSelfBlocking();
506+
}
507+
}
508+
}
509+
469510
// DFLY TAKEOVER <timeout_sec> [SAVE] <sync_id>
470511
// timeout_sec - number of seconds to wait for TAKEOVER to converge.
471512
// SAVE option is used only by tests.
@@ -506,7 +547,7 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext
506547
LOG(INFO) << "Takeover initiated, locking down the database.";
507548
absl::Duration timeout_dur = absl::Seconds(timeout);
508549
absl::Time end_time = absl::Now() + timeout_dur;
509-
AggregateStatus status;
550+
OpStatus status = OpStatus::OK;
510551

511552
// We need to await for all dispatches to finish: Otherwise a transaction might be scheduled
512553
// after this function exits but before the actual shutdown.
@@ -520,13 +561,22 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext
520561
LOG(WARNING) << "Couldn't wait for commands to finish dispatching. " << timeout_dur;
521562
status = OpStatus::TIMED_OUT;
522563

523-
auto cb = [&](unsigned thread_index, util::Connection* conn) {
524-
facade::Connection* dcon = static_cast<facade::Connection*>(conn);
525-
LOG(INFO) << dcon->DebugInfo();
526-
};
527-
528-
for (auto* listener : sf_->GetListeners()) {
529-
listener->TraverseConnections(cb);
564+
// Force takeover on the same duration if flag is set
565+
if (absl::GetFlag(FLAGS_experimental_force_takeover)) {
566+
ForceShutdownStuckConnections(uint64_t(timeout));
567+
568+
// Safety net.
569+
facade::DispatchTracker tracker{sf_->GetNonPriviligedListeners(), cntx->conn(), false, false};
570+
shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto* pb) {
571+
sf_->CancelBlockingOnThread();
572+
tracker.TrackOnThread();
573+
});
574+
575+
status = OpStatus::OK;
576+
if (!tracker.Wait(timeout_dur)) {
577+
LOG(ERROR) << "Could not force execute takeover";
578+
status = OpStatus::TIMED_OUT;
579+
}
530580
}
531581
}
532582

@@ -540,10 +590,11 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext
540590
});
541591

542592
atomic_bool catchup_success = true;
543-
if (*status == OpStatus::OK) {
593+
if (status == OpStatus::OK) {
544594
dfly::SharedLock lk{replica_ptr->shared_mu};
545-
auto cb = [replica_ptr = replica_ptr, end_time, &catchup_success](EngineShard* shard) {
546-
if (!WaitReplicaFlowToCatchup(end_time, replica_ptr.get(), shard)) {
595+
auto time = end_time + timeout_dur;
596+
auto cb = [replica_ptr = replica_ptr, time, &catchup_success](EngineShard* shard) {
597+
if (!WaitReplicaFlowToCatchup(time, replica_ptr.get(), shard)) {
547598
catchup_success.store(false);
548599
}
549600
};
@@ -552,8 +603,9 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext
552603

553604
VLOG(1) << "WaitReplicaFlowToCatchup done";
554605

555-
if (*status != OpStatus::OK || !catchup_success.load()) {
606+
if (status != OpStatus::OK || !catchup_success.load()) {
556607
sf_->service().SwitchState(GlobalState::TAKEN_OVER, GlobalState::ACTIVE);
608+
LOG(INFO) << status << " " << catchup_success.load() << " " << &status;
557609
return rb->SendError("Takeover failed!");
558610
}
559611

src/server/dflycmd.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ class DflyCmd {
183183
// Switch to stable state replication.
184184
void StartStable(CmdArgList args, Transaction* tx, RedisReplyBuilder* rb);
185185

186+
// Helper for takeover flow. Sometimes connections get stuck on send (because of pipelines)
187+
// and this causes the takeover flow to fail because checkpoint messages are not processed.
188+
// This function forcibly shuts down those connections and allows the node to complete the takeover.
189+
void ForceShutdownStuckConnections(uint64_t timeout);
190+
186191
// TAKEOVER <syncid>
187192
// Shut this master down atomically with replica promotion.
188193
void TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext* cntx);

tests/dragonfly/replication_test.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3668,7 +3668,7 @@ async def test_replica_of_self(async_client):
36683668
await async_client.execute_command(f"replicaof 127.0.0.1 {port}")
36693669

36703670

3671-
@dfly_args({"proactor_threads": 2})
3671+
@dfly_args({"proactor_threads": 2, "experimental_force_takeover": True})
36723672
async def test_takeover_with_stuck_connections(df_factory: DflyInstanceFactory):
36733673
master = df_factory.create()
36743674
master.start()
@@ -3686,8 +3686,13 @@ async def test_takeover_with_stuck_connections(df_factory: DflyInstanceFactory):
36863686

36873687
async def get_task():
36883688
while True:
3689-
writer.write(f"GET a\n".encode())
3690-
await writer.drain()
3689+
# Will get killed by takeover because it's stuck
3690+
try:
3691+
writer.write(f"GET a\n".encode())
3692+
await writer.drain()
3693+
except:
3694+
return
3695+
36913696
await asyncio.sleep(0.1)
36923697

36933698
get = asyncio.create_task(get_task())
@@ -3715,5 +3720,7 @@ async def wait_for_stuck_on_send():
37153720
async with async_timeout.timeout(240):
37163721
await wait_for_replicas_state(replica_cl)
37173722

3718-
with pytest.raises(redis.exceptions.ResponseError) as e:
3719-
await replica_cl.execute_command("REPLTAKEOVER 5")
3723+
res = await replica_cl.execute_command("REPLTAKEOVER 5")
3724+
assert res == "OK"
3725+
3726+
await get

0 commit comments

Comments
 (0)