diff --git a/neo4j/_async_compat/network/_bolt_socket.py b/neo4j/_async_compat/network/_bolt_socket.py index 2891e5ce..27004188 100644 --- a/neo4j/_async_compat/network/_bolt_socket.py +++ b/neo4j/_async_compat/network/_bolt_socket.py @@ -274,20 +274,20 @@ async def _handshake(self, resolved_address): # If no data is returned after a successful select # response, the server has closed the connection log.debug("[#%04X] S: ", local_port) - self.close() + await self.close() raise ServiceUnavailable( "Connection to {address} closed without handshake response".format( address=resolved_address)) if data_size != 4: # Some garbled data has been received log.debug("[#%04X] S: @*#!", local_port) - self.close() + await self.close() raise BoltProtocolError( "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % ( resolved_address, data), address=resolved_address) elif data == b"HTTP": log.debug("[#%04X] S: ", local_port) - self.close() + await self.close() raise ServiceUnavailable( "Cannot to connect to Bolt service on {!r} " "(looks like HTTP)".format(resolved_address)) @@ -298,14 +298,14 @@ async def _handshake(self, resolved_address): @classmethod async def close_socket(cls, socket_): - if isinstance(socket_, socket): - try: + try: + if isinstance(socket_, AsyncBoltSocket): + await socket_.close() + else: socket_.shutdown(SHUT_RDWR) socket_.close() - except OSError: - pass - else: - await socket_.close() + except OSError: + pass @classmethod async def connect(cls, address, *, timeout, custom_resolver, ssl_context, @@ -417,6 +417,10 @@ def recv_into(self, buffer, nbytes): def sendall(self, data): return self._wait_for_io(self._socket.sendall, data) + def close(self): + self._socket.shutdown(SHUT_RDWR) + self._socket.close() + @classmethod def _connect(cls, resolved_address, timeout, keep_alive): """ @@ -555,8 +559,11 @@ def _handshake(cls, s, resolved_address): @classmethod def close_socket(cls, socket_): try: - socket_.shutdown(SHUT_RDWR) - socket_.close() + if isinstance(socket_, BoltSocket): + socket.close() + else: + socket_.shutdown(SHUT_RDWR) + socket_.close() except OSError: pass