diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc
index 784fa21da33e..496e4cd38049 100644
--- a/src/server/dflycmd.cc
+++ b/src/server/dflycmd.cc
@@ -40,6 +40,8 @@ ABSL_RETIRED_FLAG(uint32_t, allow_partial_sync_with_lsn_diff, 0,
 ABSL_DECLARE_FLAG(bool, info_replication_valkey_compatible);
 ABSL_DECLARE_FLAG(uint32_t, replication_timeout);
 ABSL_DECLARE_FLAG(uint32_t, shard_repl_backlog_len);
+ABSL_FLAG(bool, experimental_force_takeover, false,
+          "Attempts to force a takeover in case of stuck connections");

 namespace dfly {

@@ -466,6 +468,49 @@ std::optional DflyCmd::ParseLsnVec(std::string_view last_master_lsn,
   return {lsn_vec[flow_id]};
 }

+void DflyCmd::ForceShutdownStuckConnections(uint64_t timeout) {
+  // Weak references to this proactor thread's connections that will be shut down.
+  vector<facade::Connection::WeakRef> conn_refs;
+
+  auto cb = [&](unsigned thread_index, util::Connection* conn) {
+    facade::Connection* dfly_conn = static_cast<facade::Connection*>(conn);
+    LOG(INFO) << dfly_conn->DebugInfo();
+    // Skip replica connections. All other connections on this thread are collected for shutdown;
+    // the ones that look stuck on a read or a send are logged with extra detail below.
+    using Phase = facade::Connection::Phase;
+    auto phase = dfly_conn->phase();
+    if (dfly_conn->cntx() && dfly_conn->cntx()->replica_conn) {
+      return;
+    }
+
+    bool idle_read = phase == Phase::READ_SOCKET && dfly_conn->idle_time() > timeout;
+
+    bool stuck_sending = dfly_conn->IsSending() && dfly_conn->GetSendWaitTimeSec() > timeout;
+
+    if (idle_read || stuck_sending) {
+      LOG(INFO) << "Connection check: " << dfly_conn->GetClientInfo()
+                << ", phase=" << static_cast<int>(phase) << ", idle_time=" << dfly_conn->idle_time()
+                << ", is_sending=" << dfly_conn->IsSending() << ", idle_read=" << idle_read
+                << ", stuck_sending=" << stuck_sending;
+    }
+    conn_refs.push_back(dfly_conn->Borrow());
+  };
+
+  for (auto* listener : sf_->GetListeners()) {
+    if (listener->IsMainInterface()) {
+      listener->TraverseConnectionsOnThread(cb, UINT32_MAX, nullptr);
+    }
+  }
+
+  VLOG(1) << "Found " << conn_refs.size() << " stuck connections";
+  for (auto& ref : conn_refs) {
+    facade::Connection* conn = ref.Get();
+    if (conn) {
+      conn->ShutdownSelfBlocking();
+    }
+  }
+}
+
 // DFLY TAKEOVER [SAVE]
 // timeout_sec - number of seconds to wait for TAKEOVER to converge.
 // SAVE option is used only by tests.
@@ -506,7 +551,7 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext
   LOG(INFO) << "Takeover initiated, locking down the database.";
   absl::Duration timeout_dur = absl::Seconds(timeout);
   absl::Time end_time = absl::Now() + timeout_dur;
-  AggregateStatus status;
+  OpStatus status = OpStatus::OK;

   // We need to await for all dispatches to finish: Otherwise a transaction might be scheduled
   // after this function exits but before the actual shutdown.
@@ -520,13 +565,20 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext
     LOG(WARNING) << "Couldn't wait for commands to finish dispatching. " << timeout_dur;
     status = OpStatus::TIMED_OUT;

-    auto cb = [&](unsigned thread_index, util::Connection* conn) {
-      facade::Connection* dcon = static_cast<facade::Connection*>(conn);
-      LOG(INFO) << dcon->DebugInfo();
-    };
-
-    for (auto* listener : sf_->GetListeners()) {
-      listener->TraverseConnections(cb);
+    // If the flag is set, force the takeover within the same timeout.
+    if (absl::GetFlag(FLAGS_experimental_force_takeover)) {
+      facade::DispatchTracker tracker{sf_->GetNonPriviligedListeners(), cntx->conn(), false, false};
+      shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto* pb) {
+        ForceShutdownStuckConnections(uint64_t(timeout));
+        sf_->CancelBlockingOnThread();
+        tracker.TrackOnThread();
+      });
+
+      status = OpStatus::OK;
+      if (!tracker.Wait(timeout_dur)) {
+        LOG(ERROR) << "Could not force execute takeover";
+        status = OpStatus::TIMED_OUT;
+      }
     }
   }

@@ -540,10 +592,11 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext
   });

   atomic_bool catchup_success = true;
-  if (*status == OpStatus::OK) {
+  if (status == OpStatus::OK) {
     dfly::SharedLock lk{replica_ptr->shared_mu};
-    auto cb = [replica_ptr = replica_ptr, end_time, &catchup_success](EngineShard* shard) {
-      if (!WaitReplicaFlowToCatchup(end_time, replica_ptr.get(), shard)) {
+    auto time = end_time + timeout_dur;
+    auto cb = [replica_ptr = replica_ptr, time, &catchup_success](EngineShard* shard) {
+      if (!WaitReplicaFlowToCatchup(time, replica_ptr.get(), shard)) {
         catchup_success.store(false);
       }
     };
@@ -552,8 +605,9 @@ void DflyCmd::TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext

   VLOG(1) << "WaitReplicaFlowToCatchup done";

-  if (*status != OpStatus::OK || !catchup_success.load()) {
+  if (status != OpStatus::OK || !catchup_success.load()) {
     sf_->service().SwitchState(GlobalState::TAKEN_OVER, GlobalState::ACTIVE);
+    LOG(INFO) << "Takeover failed: status=" << status << ", catchup=" << catchup_success.load();
     return rb->SendError("Takeover failed!");
   }

diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h
index 530d176eea8d..bd35a70dbac8 100644
--- a/src/server/dflycmd.h
+++ b/src/server/dflycmd.h
@@ -183,6 +183,11 @@ class DflyCmd {
   // Switch to stable state replication.
   void StartStable(CmdArgList args, Transaction* tx, RedisReplyBuilder* rb);

+  // Helper for the takeover flow. Sometimes connections get stuck on send (because of pipelining),
+  // which causes the takeover to fail because checkpoint messages are not processed.
+  // This function force-closes those connections so that the node can complete the takeover.
+  void ForceShutdownStuckConnections(uint64_t timeout);
+
   // TAKEOVER
   // Shut this master down atomically with replica promotion.
   void TakeOver(CmdArgList args, RedisReplyBuilder* rb, ConnectionContext* cntx);
diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py
index cdbbfaaef1af..14ce752b5767 100644
--- a/tests/dragonfly/replication_test.py
+++ b/tests/dragonfly/replication_test.py
@@ -3666,3 +3666,61 @@ async def test_replica_of_self(async_client):

     with pytest.raises(redis.exceptions.ResponseError):
         await async_client.execute_command(f"replicaof 127.0.0.1 {port}")
+
+
+@dfly_args({"proactor_threads": 2, "experimental_force_takeover": True})
+async def test_takeover_with_stuck_connections(df_factory: DflyInstanceFactory):
+    master = df_factory.create()
+    master.start()
+
+    async_client = master.client()
+    await async_client.execute_command("debug populate 2000")
+
+    reader, writer = await asyncio.open_connection("127.0.0.1", master.port)
+    writer.write(f"client setname writer_test\n".encode())
+    await writer.drain()
+    assert "OK" in (await reader.readline()).decode()
+    size = 1024 * 1024
+    writer.write(f"SET a {'v'*size}\n".encode())
+    await writer.drain()
+
+    async def get_task():
+        while True:
+            # Will get killed by the takeover because it is stuck
+            try:
+                writer.write(f"GET a\n".encode())
+                await writer.drain()
+            except:
+                return
+
+            await asyncio.sleep(0.1)
+
+    get = asyncio.create_task(get_task())
+
+    @assert_eventually(times=600)
+    async def wait_for_stuck_on_send():
+        clients = await async_client.client_list()
+        logging.info("wait_for_stuck_on_send clients: %s", clients)
+        phase = next(
+            (client["phase"] for client in clients if client["name"] == "writer_test"), None
+        )
+        assert phase == "send"
+
+    await wait_for_stuck_on_send()
+
+    replica = df_factory.create()
+    replica.start()
+
+    replica_cl = replica.client()
+
+    res = await replica_cl.execute_command(f"replicaof localhost {master.port}")
+    assert res == "OK"
+
+    # Wait for the replica to transition into stable sync
+    async with async_timeout.timeout(240):
+        await wait_for_replicas_state(replica_cl)
+
+    res = await replica_cl.execute_command("REPLTAKEOVER 5")
+    assert res == "OK"
+
+    await get
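For context, not part of the patch: the new test manufactures a connection stuck in the send phase by pipelining reads of a large value over a raw socket and never consuming the replies. A minimal standalone sketch of that pattern, assuming a Dragonfly (Redis-compatible) server on 127.0.0.1:6379; the key name "big" and the function name are placeholders:

import asyncio


async def make_stuck_connection(host: str = "127.0.0.1", port: int = 6379) -> None:
    # Sketch only: drive a connection into the "send" phase by never reading replies.
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(f"SET big {'v' * (1024 * 1024)}\n".encode())  # store a ~1 MiB value
    await writer.drain()
    assert "OK" in (await reader.readline()).decode()
    while True:
        # Replies accumulate unread; once the socket buffers fill up, the server
        # blocks in its send path and CLIENT LIST shows this connection in phase=send.
        writer.write(b"GET big\n")
        await writer.drain()
        await asyncio.sleep(0.1)


# asyncio.run(make_stuck_connection())

Once such a connection exists, a REPLTAKEOVER on the replica would normally time out waiting for the master's dispatches to drain; with --experimental_force_takeover the master instead force-closes the stuck client connections and retries the takeover within the same timeout.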