Stop reusing stream ids of requests that have timed out due to client-side timeout #1114


Merged 4 commits on Nov 23, 2021
11 changes: 9 additions & 2 deletions cassandra/cluster.py
@@ -4361,10 +4361,17 @@ def _on_timeout(self, _attempts=0):

pool = self.session._pools.get(self._current_host)
if pool and not pool.is_shutdown:
# Do not return the stream ID to the pool yet. We cannot reuse it
# because the node might still be processing the query and will
# return a late response to that query - if we used such a stream
# before the response to the previous query arrived, the new
# query could get a response meant for the old query
with self._connection.lock:
self._connection.request_ids.append(self._req_id)
self._connection.orphaned_request_ids.add(self._req_id)
if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold:
self._connection.orphaned_threshold_reached = True

pool.return_connection(self._connection)
pool.return_connection(self._connection, stream_was_orphaned=True)

errors = self._errors
if not errors:
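
The comment added to _on_timeout above describes the hazard in prose. As a concrete illustration, here is a small standalone sketch (illustrative only, not driver code; all names are invented) of why handing a timed-out stream ID straight back to the pool can misroute a late response, and how an orphan set avoids it:

# Toy model of stream-ID multiplexing on a single connection.
class ToyConnection:
    def __init__(self):
        self.request_ids = [1, 2, 3]       # free stream IDs
        self.handlers = {}                 # stream_id -> callback waiting for a reply
        self.orphaned_request_ids = set()  # timed out locally; the server may still reply

    def send(self, callback):
        stream_id = self.request_ids.pop()
        self.handlers[stream_id] = callback
        return stream_id

    def on_client_timeout(self, stream_id):
        # Old behaviour would do self.request_ids.append(stream_id) here, which is
        # unsafe: a new request could pick up the ID and then receive the previous
        # query's late reply.
        self.handlers.pop(stream_id, None)
        self.orphaned_request_ids.add(stream_id)

    def on_response(self, stream_id, response):
        if stream_id in self.orphaned_request_ids:
            # Late reply to a request nobody is waiting for: only now is the
            # stream ID safe to hand out again.
            self.orphaned_request_ids.discard(stream_id)
            self.request_ids.append(stream_id)
            return
        self.handlers.pop(stream_id)(response)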
32 changes: 31 additions & 1 deletion cassandra/connection.py
@@ -690,6 +690,7 @@ class Connection(object):

# The current number of operations that are in flight. More precisely,
# the number of request IDs that are currently in use.
# This includes orphaned requests.
in_flight = 0

# Max concurrent requests allowed per connection. This is set optimistically high, allowing
@@ -707,6 +708,20 @@ class Connection(object):
# request_ids set
highest_request_id = 0

# Tracks the request IDs which are no longer waited on (timed out), but
# cannot be reused yet because the node might still send a response
# on this stream
orphaned_request_ids = None

# Set to True if the orphaned stream ID count crosses the configured threshold
# and the connection will be replaced
orphaned_threshold_reached = False

# If the number of orphaned streams reaches this threshold, this connection
# will become marked and will be replaced with a new connection by the
# owning pool (currently, only HostConnection supports this)
orphaned_threshold = 3 * max_in_flight // 4

is_defunct = False
is_closed = False
lock = None
@@ -733,6 +748,8 @@ class Connection(object):

_is_checksumming_enabled = False

_on_orphaned_stream_released = None

@property
def _iobuf(self):
# backward compatibility, to avoid any change in the reactors
@@ -742,7 +759,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None):
ssl_context=None, on_orphaned_stream_released=None):

# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
@@ -764,6 +781,8 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
self._io_buffer = _ConnectionIOBuffer(self)
self._continuous_paging_sessions = {}
self._socket_writable = True
self.orphaned_request_ids = set()
self._on_orphaned_stream_released = on_orphaned_stream_released

if ssl_options:
self._check_hostname = bool(self.ssl_options.pop('check_hostname', False))
@@ -1188,11 +1207,22 @@ def process_msg(self, header, body):
decoder = paging_session.decoder
result_metadata = None
else:
need_notify_of_release = False
with self.lock:
if stream_id in self.orphaned_request_ids:
self.in_flight -= 1
self.orphaned_request_ids.remove(stream_id)
need_notify_of_release = True
if need_notify_of_release and self._on_orphaned_stream_released:
self._on_orphaned_stream_released()

try:
callback, decoder, result_metadata = self._requests.pop(stream_id)
# This can only happen if the stream_id was
# removed due to an OperationTimedOut
except KeyError:
with self.lock:
self.request_ids.append(stream_id)
Collaborator:

Just to be very clear... this change isn't explicitly connected to the problem you're trying to address in the PR, correct @haaawk ? The intent here is that if we get a stream ID that doesn't match up to a request we know about we should be safe to re-use that stream ID regardless of other conditions... do I have that right?

Contributor Author:

In 12c9c30:

  1. If a request R has timed out due to a client-side timeout, we stop returning its stream ID to the pool of available stream IDs.
  2. If the same request R receives a response after it has timed out on the client side, the stream ID of R is put back into the pool of available stream IDs.

This line performs (2).

if we get a stream ID that doesn't match up to a request we know about we should be safe to re-use that stream ID regardless of other conditions... do I have that right?

Yes. The assumption is that if we got a stream ID we don't know then it must be a stream ID of a request that has already timed out due to a client side timeout.

Collaborator:

Okay, thanks for the verification @haaawk !

return

try:
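
To make the purpose of the new callback concrete, the sketch below shows how a pool can wire itself to a connection through on_orphaned_stream_released. This is a simplification that assumes the pool guards its free-stream bookkeeping with a Condition, as HostConnection does in the pool.py diff below; it is not the driver's implementation.

import threading

class PoolWiringSketch:
    """Pool side of the new callback, heavily simplified."""

    def __init__(self, connection_factory):
        self._lock = threading.RLock()
        self._stream_available_condition = threading.Condition(self._lock)
        # The pool hands its own bound method to the connection factory so the
        # connection can poke the pool whenever process_msg() frees an orphaned
        # stream ID after a late response.
        self._connection = connection_factory(
            on_orphaned_stream_released=self.on_orphaned_stream_released)

    def on_orphaned_stream_released(self):
        # Wake a borrower that may be blocked waiting for a free stream ID.
        with self._stream_available_condition:
            self._stream_available_condition.notify()

Note that the replacement trigger is proportional rather than absolute: assuming the Connection class default of max_in_flight = 2 ** 15 (treat that figure as an assumption, it is not stated in this diff), orphaned_threshold = 3 * max_in_flight // 4 comes to 24576 orphaned stream IDs before orphaned_threshold_reached is set.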
97 changes: 80 additions & 17 deletions cassandra/pool.py
@@ -390,6 +390,10 @@ def __init__(self, host, host_distance, session):
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
self._stream_available_condition = Condition(self._lock)
self._is_replacing = False
# Contains connections which shouldn't be used anymore
# and are waiting until all requests time out or complete
# so that we can dispose of them.
self._trash = set()

if host_distance == HostDistance.IGNORED:
log.debug("Not opening connection to ignored host %s", self.host)
@@ -399,42 +403,59 @@ def __init__(self, host, host_distance, session):
return

log.debug("Initializing connection for host %s", self.host)
self._connection = session.cluster.connection_factory(host.endpoint)
self._connection = session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
self._keyspace = session.keyspace
if self._keyspace:
self._connection.set_keyspace_blocking(self._keyspace)
log.debug("Finished initializing connection for host %s", self.host)

def borrow_connection(self, timeout):
def _get_connection(self):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)

conn = self._connection
if not conn:
raise NoConnectionsAvailable()
return conn

def borrow_connection(self, timeout):
conn = self._get_connection()
if conn.orphaned_threshold_reached:
with self._lock:
if not self._is_replacing:
self._is_replacing = True
self._session.submit(self._replace, conn)
Collaborator:

Once this completes, self._connection will be reset, but conn will still be pointing to the prior value of self._connection which... has hit its threshold. Seems like we want to update conn here as well. Shouldn't be much more than just another "conn = self._get_connection()" after the log message below.

Contributor Author:

I believe that this code is ok.

The intention is to keep using the old connection until a new connection is ready to operate. Otherwise we would block the client until the new connection is ready, and we probably don't want to do that. self._get_connection() will start to return the new connection after _replace assigns self._connection. It's ok to use the old connection for a bit longer as the new connection should be established relatively quickly. Does that make sense, @absurdfarce?

Until the replacement happens, self._get_connection() will return the same connection that's already assigned to conn.

Collaborator:

Agreed, after reading this more carefully it became apparent I wasn't evaluating the full context here @haaawk . If anything I'd say the requirement is stronger than your argument above. "conn" is scoped to borrow_connection() and we're using it to validate a good working connection before returning from that fn in the "while True" loop below. There's no obvious point to setting conn at this point since the loop handles all of that.

I'm good with what you have here.

log.debug(
"Connection to host %s reached orphaned stream limit, replacing...",
self.host
)

start = time.time()
remaining = timeout
while True:
with conn.lock:
if conn.in_flight < conn.max_request_id:
if not (conn.orphaned_threshold_reached and conn.is_closed) and conn.in_flight < conn.max_request_id:
conn.in_flight += 1
return conn, conn.get_request_id()
if timeout is not None:
remaining = timeout - time.time() + start
if remaining < 0:
break
with self._stream_available_condition:
self._stream_available_condition.wait(remaining)
if conn.orphaned_threshold_reached and conn.is_closed:
conn = self._get_connection()
else:
self._stream_available_condition.wait(remaining)

raise NoConnectionsAvailable("All request IDs are currently in use")

def return_connection(self, connection):
with connection.lock:
connection.in_flight -= 1
with self._stream_available_condition:
self._stream_available_condition.notify()
def return_connection(self, connection, stream_was_orphaned=False):
if not stream_was_orphaned:
with connection.lock:
connection.in_flight -= 1
with self._stream_available_condition:
self._stream_available_condition.notify()

if connection.is_defunct or connection.is_closed:
if connection.signaled_error and not self.shutdown_on_error:
@@ -461,6 +482,24 @@ def return_connection(self, connection):
return
self._is_replacing = True
self._session.submit(self._replace, connection)
else:
if connection in self._trash:
with connection.lock:
if connection.in_flight == len(connection.orphaned_request_ids):
with self._lock:
if connection in self._trash:
self._trash.remove(connection)
log.debug("Closing trashed connection (%s) to %s", id(connection), self.host)
connection.close()
return

def on_orphaned_stream_released(self):
"""
Called when a response for an orphaned stream (timed out on the client
side) was received.
"""
with self._stream_available_condition:
self._stream_available_condition.notify()

def _replace(self, connection):
with self._lock:
@@ -469,17 +508,23 @@ def _replace(self, connection):

log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try:
conn = self._session.cluster.connection_factory(self.host.endpoint)
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
self._connection = conn
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
self._session.submit(self._replace, connection)
else:
with self._lock:
self._is_replacing = False
self._stream_available_condition.notify()
with connection.lock:
with self._lock:
if connection.orphaned_threshold_reached:
if connection.in_flight == len(connection.orphaned_request_ids):
connection.close()
else:
self._trash.add(connection)
self._is_replacing = False
self._stream_available_condition.notify()

def shutdown(self):
with self._lock:
@@ -493,6 +538,16 @@ def shutdown(self):
self._connection.close()
self._connection = None

trash_conns = None
with self._lock:
if self._trash:
trash_conns = self._trash
self._trash = set()

if trash_conns is not None:
for conn in trash_conns:
conn.close()

def _set_keyspace_for_all_conns(self, keyspace, callback):
if self.is_shutdown or not self._connection:
return
@@ -548,7 +603,7 @@ def __init__(self, host, host_distance, session):

log.debug("Initializing new connection pool for host %s", self.host)
core_conns = session.cluster.get_core_connections_per_host(host_distance)
self._connections = [session.cluster.connection_factory(host.endpoint)
self._connections = [session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
for i in range(core_conns)]

self._keyspace = session.keyspace
@@ -652,7 +707,7 @@ def _add_conn_if_under_max(self):

log.debug("Going to open new connection to host %s", self.host)
try:
conn = self._session.cluster.connection_factory(self.host.endpoint)
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
@@ -712,9 +767,10 @@ def _wait_for_conn(self, timeout):

raise NoConnectionsAvailable()

def return_connection(self, connection):
def return_connection(self, connection, stream_was_orphaned=False):
with connection.lock:
connection.in_flight -= 1
if not stream_was_orphaned:
connection.in_flight -= 1
in_flight = connection.in_flight

if connection.is_defunct or connection.is_closed:
@@ -750,6 +806,13 @@ def return_connection(self, connection):
else:
self._signal_available_conn()

def on_orphaned_stream_released(self):
"""
Called when a response for an orphaned stream (timed out on the client
side) was received.
"""
self._signal_available_conn()

def _maybe_trash_connection(self, connection):
core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
did_trash = False
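
A condensed view of the retirement logic that _replace and return_connection now share: a connection that crossed its orphaned threshold is closed immediately only when every remaining in-flight request is an orphan, and is otherwise parked in _trash until that becomes true. The sketch below follows the attribute names in the diff but elides locking order and error handling; it is an illustration, not the driver code.

def retire_old_connection(pool, old_conn):
    # Illustrative only: decide whether a replaced connection can be closed now.
    with old_conn.lock:
        if old_conn.in_flight == len(old_conn.orphaned_request_ids):
            # Every in-flight request is an orphan; nothing will ever read from
            # this connection again, so tear it down right away.
            old_conn.close()
        else:
            # Live requests remain; park the connection. return_connection() and
            # on_orphaned_stream_released() revisit this check and close it once
            # the last live request completes.
            pool._trash.add(old_conn)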
Binary file added tests/unit/.noseids
Binary file not shown.
20 changes: 10 additions & 10 deletions tests/unit/test_host_connection_pool.py
@@ -26,7 +26,6 @@
from cassandra.pool import Host, NoConnectionsAvailable
from cassandra.policies import HostDistance, SimpleConvictionPolicy


class _PoolTests(unittest.TestCase):
PoolImpl = None
uses_single_connection = None
@@ -45,7 +44,7 @@ def test_borrow_and_return(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

c, request_id = pool.borrow_connection(timeout=0.01)
self.assertIs(c, conn)
@@ -64,7 +63,7 @@ def test_failed_wait_for_connection(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
self.assertEqual(1, conn.in_flight)
@@ -82,7 +81,7 @@ def test_successful_wait_for_connection(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
self.assertEqual(1, conn.in_flight)
@@ -110,7 +109,7 @@ def test_spawn_when_at_max(self):
session.cluster.get_max_connections_per_host.return_value = 2

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
self.assertEqual(1, conn.in_flight)
@@ -133,7 +132,7 @@ def test_return_defunct_connection(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
conn.is_defunct = True
@@ -148,11 +147,12 @@ def test_return_defunct_connection_on_down_host(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
max_request_id=100, signaled_error=False)
max_request_id=100, signaled_error=False,
orphaned_threshold_reached=False)
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
conn.is_defunct = True
@@ -169,11 +169,11 @@ def test_return_closed_connection(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100,
signaled_error=False)
signaled_error=False, orphaned_threshold_reached=False)
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
conn.is_closed = True
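
The updated tests above only adapt the connection_factory call signature and the mock attributes. A hypothetical addition to _PoolTests (not part of this PR; it assumes make_session and the mock setup behave as in the tests above) could pin down the core invariant that an orphaned return keeps the stream slot occupied:

    def test_return_orphaned_stream_keeps_slot(self):
        host = Mock(spec=Host, address='ip1')
        session = self.make_session()
        conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False,
                                    is_closed=False, max_request_id=100,
                                    signaled_error=False,
                                    orphaned_threshold_reached=False)
        session.cluster.connection_factory.return_value = conn

        pool = self.PoolImpl(host, HostDistance.LOCAL, session)
        pool.borrow_connection(timeout=0.01)
        self.assertEqual(1, conn.in_flight)

        # Returning with stream_was_orphaned=True must not free the slot: the
        # node may still answer on that stream ID, so in_flight stays put.
        pool.return_connection(conn, stream_was_orphaned=True)
        self.assertEqual(1, conn.in_flight)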