diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index 28e4adcac..75fd4851f 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -42,11 +42,14 @@ from neo4j.meta import version from neo4j.packstream import Packer, Unpacker from neo4j.util import import_best as _import_best +from time import clock ChunkedInputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedInputBuffer ChunkedOutputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedOutputBuffer +INFINITE_CONNECTION_LIFETIME = -1 +DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE_CONNECTION_LIFETIME DEFAULT_CONNECTION_TIMEOUT = 5.0 DEFAULT_PORT = 7687 DEFAULT_USER_AGENT = "neo4j-python/%s" % version @@ -178,6 +181,8 @@ def __init__(self, address, sock, error_handler, **config): self.packer = Packer(self.output_buffer) self.unpacker = Unpacker() self.responses = deque() + self._max_connection_lifetime = config.get("max_connection_lifetime", DEFAULT_MAX_CONNECTION_LIFETIME) + self._creation_timestamp = clock() # Determine the user agent and ensure it is a Unicode value user_agent = config.get("user_agent", DEFAULT_USER_AGENT) @@ -201,6 +206,7 @@ def __init__(self, address, sock, error_handler, **config): # Pick up the server certificate, if any self.der_encoded_server_certificate = config.get("der_encoded_server_certificate") + def Init(self): response = InitResponse(self) self.append(INIT, (self.user_agent, self.auth_dict), response=response) self.sync() @@ -360,6 +366,9 @@ def _unpack(self): more = False return details, summary_signature, summary_metadata + def timedout(self): + return 0 <= self._max_connection_lifetime <= clock() - self._creation_timestamp + def sync(self): """ Send and fetch all outstanding messages. @@ -425,7 +434,7 @@ def acquire_direct(self, address): except KeyError: connections = self.connections[address] = deque() for connection in list(connections): - if connection.closed() or connection.defunct(): + if connection.closed() or connection.defunct() or connection.timedout(): connections.remove(connection) continue if not connection.in_use: @@ -600,8 +609,10 @@ def connect(address, ssl_context=None, error_handler=None, **config): s.shutdown(SHUT_RDWR) s.close() elif agreed_version == 1: - return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, + connection = Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, error_handler=error_handler, **config) + connection.Init() + return connection elif agreed_version == 0x48545450: log_error("S: [CLOSE]") s.close() diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py index 703f97df0..fb0d0ab7c 100644 --- a/test/integration/test_connection.py +++ b/test/integration/test_connection.py @@ -22,7 +22,6 @@ from socket import create_connection from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler - from test.integration.tools import IntegrationTestCase @@ -44,6 +43,9 @@ def closed(self): def defunct(self): return False + def timedout(self): + return False + def connector(address, _): return QuickConnection(create_connection(address)) @@ -119,4 +121,4 @@ def test_in_use_count(self): connection = self.pool.acquire_direct(address) self.assertEqual(self.pool.in_use_connection_count(address), 1) self.pool.release(connection) - self.assertEqual(self.pool.in_use_connection_count(address), 0) + self.assertEqual(self.pool.in_use_connection_count(address), 0) \ No newline at end of file diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py new file mode 100644 index 000000000..1d10a4998 --- /dev/null +++ b/test/unit/test_connection.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) 2002-2017 "Neo Technology," +# Network Engine for Objects in Lund AB [http://neotechnology.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import TestCase +from neo4j.v1 import DirectConnectionErrorHandler +from neo4j.bolt import Connection + + +class FakeSocket(object): + def __init__(self, address): + self.address = address + + def getpeername(self): + return self.address + + def sendall(self, data): + return + + def close(self): + return + + +class ConnectionTestCase(TestCase): + + def test_conn_timedout(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), max_connection_lifetime=0) + self.assertEqual(connection.timedout(), True) + + def test_conn_not_timedout_if_not_enabled(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), + max_connection_lifetime=-1) + self.assertEqual(connection.timedout(), False) + + def test_conn_not_timedout(self): + address = ("127.0.0.1", 7687) + connection = Connection(address, FakeSocket(address), DirectConnectionErrorHandler(), + max_connection_lifetime=999999999) + self.assertEqual(connection.timedout(), False) \ No newline at end of file