From d64c52b9a8065a919531fa9821955913be3ce1ac Mon Sep 17 00:00:00 2001 From: Denis Ignatenko Date: Thu, 6 Jun 2019 14:28:21 +0300 Subject: [PATCH] Support cluster discovery in MeshConnection This feature adds the new optional arguments to the MeshConnection contructor: * `cluster_discovery_function` -- a name of the function which will be periodically called on a currently connected tarantool instance to update a list of MeshConnection addresses. * `cluster_discovery_delay` -- minimal amount of seconds between address list updates (default is 60 seconds). The update of addresses is performed right after successful connecting and before performing a request (if a minimal time passes). This commits changes the round robin retry strategy. Before it performs two attempts to connect to each address reconnect_max_attempts times (3 by default), now it do that only once. The new type of error is added: ConfigurationError. It is risen when a user provides incorrect configuration: say, one of provided addresses is not correct. The new type of warning is added: ClusterDiscoveryWarning. This warning is shown when a something went wrong during cluster discovery: say, one of returned addresses is not correct. Note the difference: a user provided configuration verified strictly, while a cluster discovery function result is filtered (with warnings) and good addresses are applied (if the list is not empty). Aside of the new functionality this commit improves compatibility of MeshConnection API with Connection. The following arguments are added to the MeshConnection constructor: `host`, `port`, `call_16`, `connection_timeout`. An address from `host` / `port` arguments is added to `addrs` (if provided) as the first item. Fixes #134. --- doc/api/class-mesh-connection.rst | 8 + doc/index.rst | 1 + doc/index.ru.rst | 1 + tarantool/const.py | 2 + tarantool/error.py | 12 ++ tarantool/mesh_connection.py | 329 +++++++++++++++++++++++++++--- unit/suites/__init__.py | 3 +- unit/suites/test_mesh.py | 258 +++++++++++++++++++++++ 8 files changed, 587 insertions(+), 27 deletions(-) create mode 100644 doc/api/class-mesh-connection.rst create mode 100644 unit/suites/test_mesh.py diff --git a/doc/api/class-mesh-connection.rst b/doc/api/class-mesh-connection.rst new file mode 100644 index 00000000..d1048eba --- /dev/null +++ b/doc/api/class-mesh-connection.rst @@ -0,0 +1,8 @@ + +.. currentmodule:: tarantool.mesh_connection + +class :class:`MeshConnection` +----------------------------- + +.. autoclass:: MeshConnection + diff --git a/doc/index.rst b/doc/index.rst index c89da773..346c656c 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -40,6 +40,7 @@ API Reference api/module-tarantool.rst api/class-connection.rst + api/class-mesh-connection.rst api/class-space.rst api/class-response.rst diff --git a/doc/index.ru.rst b/doc/index.ru.rst index 2708d8ab..ddf1b528 100644 --- a/doc/index.ru.rst +++ b/doc/index.ru.rst @@ -40,6 +40,7 @@ api/module-tarantool.rst api/class-connection.rst + api/class-mesh-connection.rst api/class-space.rst api/class-response.rst diff --git a/tarantool/const.py b/tarantool/const.py index 1ebad494..9d175974 100644 --- a/tarantool/const.py +++ b/tarantool/const.py @@ -86,3 +86,5 @@ RECONNECT_MAX_ATTEMPTS = 10 # Default delay between attempts to reconnect (seconds) RECONNECT_DELAY = 0.1 +# Default cluster nodes list refresh interval (seconds) +CLUSTER_DISCOVERY_DELAY = 60 diff --git a/tarantool/error.py b/tarantool/error.py index f49ba60c..cc66e8c5 100644 --- a/tarantool/error.py +++ b/tarantool/error.py @@ -43,6 +43,12 @@ class InterfaceError(Error): ''' +class ConfigurationError(Error): + ''' + Error of initialization with a user-provided configuration. + ''' + + # Monkey patch os.strerror for win32 if sys.platform == "win32": # Windows Sockets Error Codes (not all, but related on network errors) @@ -152,6 +158,11 @@ class NetworkWarning(UserWarning): pass +class ClusterDiscoveryWarning(UserWarning): + '''Warning related to cluster discovery''' + pass + + # always print this warnings warnings.filterwarnings("always", category=NetworkWarning) @@ -166,6 +177,7 @@ def warn(message, warning_class): line_no = frame.f_lineno warnings.warn_explicit(message, warning_class, module_name, line_no) + _strerror = { 0: ("ER_UNKNOWN", "Unknown error"), 1: ("ER_ILLEGAL_PARAMS", "Illegal parameters, %s"), diff --git a/tarantool/mesh_connection.py b/tarantool/mesh_connection.py index a2d69c56..d1ed851b 100644 --- a/tarantool/mesh_connection.py +++ b/tarantool/mesh_connection.py @@ -4,29 +4,182 @@ between tarantool instances and basic Round-Robin strategy. ''' +import time + + from tarantool.connection import Connection -from tarantool.error import NetworkError +from tarantool.error import ( + warn, + NetworkError, + DatabaseError, + ConfigurationError, + ClusterDiscoveryWarning, +) from tarantool.utils import ENCODING_DEFAULT from tarantool.const import ( + CONNECTION_TIMEOUT, SOCKET_TIMEOUT, RECONNECT_MAX_ATTEMPTS, - RECONNECT_DELAY + RECONNECT_DELAY, + CLUSTER_DISCOVERY_DELAY, +) + +from tarantool.request import ( + RequestCall ) +try: + string_types = basestring +except NameError: + string_types = str + + +def parse_uri(uri): + def parse_error(uri, msg): + msg = 'URI "%s": %s' % (uri, msg) + return None, msg + + if not uri: + return parse_error(uri, 'should not be None or empty string') + if not isinstance(uri, string_types): + return parse_error(uri, 'should be of a string type') + if uri.count(':') != 1: + return parse_error(uri, 'does not match host:port scheme') + + host, port_str = uri.split(':', 1) + if not host: + return parse_error(uri, 'host value is empty') + + try: + port = int(port_str) + except ValueError: + return parse_error(uri, 'port should be a number') + + return {'host': host, 'port': port}, None + + +def validate_address(address): + messages = [] + + if isinstance(address, dict): + if "host" not in address: + messages.append("host key must be set") + elif not isinstance(address["host"], string_types): + messages.append("host value must be string type") + + if "port" not in address: + messages.append("port is not set") + elif not isinstance(address["port"], int): + messages.append("port value must be int type") + elif address["port"] == 0: + messages.append("port value must not be zero") + elif address["port"] > 65535: + messages.append("port value must not be above 65535") + else: + messages.append("address must be a dict") + + if messages: + messages_str = ', '.join(messages) + msg = 'Address %s: %s' % (str(address), messages_str) + return None, msg + + return True, None + class RoundRobinStrategy(object): + """ + Simple round-robin address rotation + """ def __init__(self, addrs): - self.addrs = addrs - self.pos = 0 + self.update(addrs) + + def update(self, new_addrs): + # Verify new_addrs is a non-empty list. + assert new_addrs and isinstance(new_addrs, list) + + # Remove duplicates. + new_addrs_unique = [] + for addr in new_addrs: + if addr not in new_addrs_unique: + new_addrs_unique.append(addr) + new_addrs = new_addrs_unique + + # Save a current address if any. + if 'pos' in self.__dict__ and 'addrs' in self.__dict__: + current_addr = self.addrs[self.pos] + else: + current_addr = None + + # Determine a position of a current address (if any) in + # the new addresses list. + if current_addr and current_addr in new_addrs: + new_pos = new_addrs.index(current_addr) + else: + new_pos = -1 + + self.addrs = new_addrs + self.pos = new_pos def getnext(self): - tmp = self.pos self.pos = (self.pos + 1) % len(self.addrs) - return self.addrs[tmp] + return self.addrs[self.pos] class MeshConnection(Connection): - def __init__(self, addrs, + ''' + Represents a connection to a cluster of Tarantool servers. + + This class uses Connection to connect to one of the nodes of the cluster. + The initial list of nodes is passed to the constructor in 'addrs' parameter. + The class set in 'strategy_class' parameter is used to select a node from + the list and switch nodes in case of unavailability of the current node. + + 'cluster_discovery_function' param of the constructor sets the name of a + stored Lua function used to refresh the list of available nodes. The + function takes no parameters and returns a list of strings in format + 'host:port'. A generic function for getting the list of nodes looks like + this: + + .. code-block:: lua + + function get_cluster_nodes() + return { + '192.168.0.1:3301', + '192.168.0.2:3302', + -- ... + } + end + + You may put in this list whatever you need depending on your cluster + topology. Chances are you'll want to make the list of nodes from nodes' + replication config. Here is an example for it: + + .. code-block:: lua + + local uri_lib = require('uri') + + function get_cluster_nodes() + local nodes = {} + + local replicas = box.cfg.replication + + for i = 1, #replicas do + local uri = uri_lib.parse(replicas[i]) + + if uri.host and uri.service then + table.insert(nodes, uri.host .. ':' .. uri.service) + end + end + + -- if your replication config doesn't contain the current node + -- you have to add it manually like this: + table.insert(nodes, '192.168.0.1:3301') + + return nodes + end + ''' + + def __init__(self, host=None, port=None, user=None, password=None, socket_timeout=SOCKET_TIMEOUT, @@ -34,32 +187,156 @@ def __init__(self, addrs, reconnect_delay=RECONNECT_DELAY, connect_now=True, encoding=ENCODING_DEFAULT, - strategy_class=RoundRobinStrategy): - self.nattempts = 2 * len(addrs) + 1 + call_16=False, + connection_timeout=CONNECTION_TIMEOUT, + addrs=None, + strategy_class=RoundRobinStrategy, + cluster_discovery_function=None, + cluster_discovery_delay=CLUSTER_DISCOVERY_DELAY): + if addrs is None: + addrs = [] + else: + # Don't change user provided arguments. + addrs = addrs[:] + + if host and port: + addrs.insert(0, {'host': host, 'port': port}) + + # Verify that at least one address is provided. + if not addrs: + raise ConfigurationError( + 'Neither "host" and "port", nor "addrs" arguments are set') + + # Verify addresses. + for addr in addrs: + ok, msg = validate_address(addr) + if not ok: + raise ConfigurationError(msg) + + self.strategy_class = strategy_class self.strategy = strategy_class(addrs) + addr = self.strategy.getnext() host = addr['host'] port = addr['port'] - super(MeshConnection, self).__init__(host=host, - port=port, - user=user, - password=password, - socket_timeout=socket_timeout, - reconnect_max_attempts=reconnect_max_attempts, - reconnect_delay=reconnect_delay, - connect_now=connect_now, - encoding=encoding) + + self.cluster_discovery_function = cluster_discovery_function + self.cluster_discovery_delay = cluster_discovery_delay + self.last_nodes_refresh = 0 + + super(MeshConnection, self).__init__( + host=host, + port=port, + user=user, + password=password, + socket_timeout=socket_timeout, + reconnect_max_attempts=reconnect_max_attempts, + reconnect_delay=reconnect_delay, + connect_now=connect_now, + encoding=encoding, + call_16=call_16, + connection_timeout=connection_timeout) + + def connect(self): + super(MeshConnection, self).connect() + if self.connected and self.cluster_discovery_function: + self._opt_refresh_instances() def _opt_reconnect(self): - nattempts = self.nattempts - while nattempts > 0: + ''' + Attempt to connect "reconnect_max_attempts" times to each + available address. + ''' + + last_error = None + for _ in range(len(self.strategy.addrs)): try: super(MeshConnection, self)._opt_reconnect() + last_error = None break - except NetworkError: - nattempts -= 1 + except NetworkError as e: + last_error = e addr = self.strategy.getnext() - self.host = addr['host'] - self.port = addr['port'] - else: - raise NetworkError + self.host = addr["host"] + self.port = addr["port"] + + if last_error: + raise last_error + + def _opt_refresh_instances(self): + ''' + Refresh list of tarantool instances in a cluster. + Reconnect if a current instance was gone from the list. + ''' + now = time.time() + + if not self.connected or not self.cluster_discovery_function or \ + now - self.last_nodes_refresh < self.cluster_discovery_delay: + return + + # Call a cluster discovery function w/o reconnection. If + # something going wrong: warn about that and ignore. + request = RequestCall(self, self.cluster_discovery_function, (), + self.call_16) + try: + resp = self._send_request_wo_reconnect(request) + except DatabaseError as e: + msg = 'got "%s" error, skipped addresses updating' % str(e) + warn(msg, ClusterDiscoveryWarning) + return + + if not resp.data or not resp.data[0] or \ + not isinstance(resp.data[0], list): + msg = "got incorrect response instead of URI list, " + \ + "skipped addresses updating" + warn(msg, ClusterDiscoveryWarning) + return + + # Validate received address list. + new_addrs = [] + for uri in resp.data[0]: + addr, msg = parse_uri(uri) + if not addr: + warn(msg, ClusterDiscoveryWarning) + continue + + ok, msg = validate_address(addr) + if not ok: + warn(msg, ClusterDiscoveryWarning) + continue + + new_addrs.append(addr) + + if not new_addrs: + msg = "got no correct URIs, skipped addresses updating" + warn(msg, ClusterDiscoveryWarning) + return + + self.strategy.update(new_addrs) + self.last_nodes_refresh = now + + # Disconnect from a current instance if it was gone from + # an instance list and connect to one of new instances. + current_addr = {'host': self.host, 'port': self.port} + if current_addr not in self.strategy.addrs: + self.close() + addr = self.strategy.getnext() + self.host = addr['host'] + self.port = addr['port'] + self._opt_reconnect() + + def _send_request(self, request): + ''' + Update instances list if "cluster_discovery_function" is provided and a + last update was more then "cluster_discovery_delay" seconds ago. + + After that perform a request as usual and return an instance of + `Response` class. + + :param request: object representing a request + :type request: `Request` instance + + :rtype: `Response` instance + ''' + self._opt_refresh_instances() + return super(MeshConnection, self)._send_request(request) diff --git a/unit/suites/__init__.py b/unit/suites/__init__.py index 3f59862e..ead75297 100644 --- a/unit/suites/__init__.py +++ b/unit/suites/__init__.py @@ -8,9 +8,10 @@ from .test_dml import TestSuite_Request from .test_protocol import TestSuite_Protocol from .test_reconnect import TestSuite_Reconnect +from .test_mesh import TestSuite_Mesh test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol, - TestSuite_Reconnect) + TestSuite_Reconnect, TestSuite_Mesh) def load_tests(loader, tests, pattern): suite = unittest.TestSuite() diff --git a/unit/suites/test_mesh.py b/unit/suites/test_mesh.py new file mode 100644 index 00000000..dda59a89 --- /dev/null +++ b/unit/suites/test_mesh.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- + +from __future__ import print_function + +import sys +import unittest +import warnings +from time import sleep +import tarantool +from tarantool.error import ( + ConfigurationError, + ClusterDiscoveryWarning, +) +from .lib.tarantool_server import TarantoolServer + + +def create_server(_id): + srv = TarantoolServer() + srv.script = 'unit/suites/box.lua' + srv.start() + srv.admin("box.schema.user.create('test', {password = 'test', " + + "if_not_exists = true})") + srv.admin("box.schema.user.grant('test', 'execute', 'universe')") + + # Create srv_id function (for testing purposes). + srv.admin("function srv_id() return %s end" % _id) + return srv + + +@unittest.skipIf(sys.platform.startswith("win"), + 'Mesh tests on windows platform are not supported') +class TestSuite_Mesh(unittest.TestCase): + def define_cluster_function(self, func_name, servers): + addresses = [(srv.host, srv.args['primary']) for srv in servers] + addresses_lua = ",".join("'%s:%d'" % address for address in addresses) + func_body = """ + function %s() + return {%s} + end + """ % (func_name, addresses_lua) + for srv in self.servers: + srv.admin(func_body) + + def define_custom_cluster_function(self, func_name, retval): + func_body = """ + function %s() + return %s + end + """ % (func_name, retval) + for srv in self.servers: + srv.admin(func_body) + + @classmethod + def setUpClass(self): + print(' MESH '.center(70, '='), file=sys.stderr) + print('-' * 70, file=sys.stderr) + + def setUp(self): + # Create two servers and extract helpful fields for tests. + self.srv = create_server(1) + self.srv2 = create_server(2) + self.servers = [self.srv, self.srv2] + self.host_1 = self.srv.host + self.port_1 = self.srv.args['primary'] + self.host_2 = self.srv2.host + self.port_2 = self.srv2.args['primary'] + + # Create get_all_nodes() function on servers. + self.get_all_nodes_func_name = 'get_all_nodes' + self.define_cluster_function(self.get_all_nodes_func_name, + self.servers) + + def test_01_contructor(self): + # Verify that an error is risen when no addresses are + # configured (neither with host/port, nor with addrs). + with self.assertRaises(ConfigurationError): + tarantool.MeshConnection() + + # Verify that a bad address given at initialization leads + # to an error. + bad_addrs = [ + {"port": 1234}, # no host + {"host": "localhost"}, # no port + {"host": "localhost", "port": "1234"}, # port is str + ] + for bad_addr in bad_addrs: + with self.assertRaises(ConfigurationError): + con = tarantool.MeshConnection(bad_addr.get('host'), + bad_addr.get('port')) + with self.assertRaises(ConfigurationError): + con = tarantool.MeshConnection(addrs=[bad_addr]) + + # Verify that identical addresses are squashed. + addrs = [{"host": "localhost", "port": 1234}] + con = tarantool.MeshConnection("localhost", 1234, addrs=addrs, + connect_now=False) + self.assertEqual(len(con.strategy.addrs), 1) + + def test_02_discovery_bad_address(self): + retvals = [ + "", + "1", + "'localhost:1234'", + "{}", + "error('raise an error')", + "{'localhost:foo'}", + "{'localhost:0'}", + "{'localhost:65536'}", + "{'localhost:1234:5678'}", + "{':1234'}", + "{'localhost:'}", + ] + for retval in retvals: + func_name = 'bad_cluster_discovery' + self.define_custom_cluster_function(func_name, retval) + con = tarantool.MeshConnection(self.host_1, self.port_1, + user='test', password='test') + con.cluster_discovery_function = func_name + + # Verify that a cluster discovery (that is triggered + # by ping) give one or two warnings. + with warnings.catch_warnings(record=True) as ws: + con.ping() + self.assertTrue(len(ws) in (1, 2)) + for w in ws: + self.assertIs(w.category, ClusterDiscoveryWarning) + + # Verify that incorrect or empty result was discarded. + self.assertEqual(len(con.strategy.addrs), 1) + self.assertEqual(con.strategy.addrs[0]['host'], self.host_1) + self.assertEqual(con.strategy.addrs[0]['port'], self.port_1) + + con.close() + + def test_03_discovery_bad_good_addresses(self): + func_name = 'bad_and_good_addresses' + retval = "{'localhost:', '%s:%d'}" % (self.host_2, self.port_2) + self.define_custom_cluster_function(func_name, retval) + con = tarantool.MeshConnection(self.host_1, self.port_1, + user='test', password='test') + con.cluster_discovery_function = func_name + + # Verify that a cluster discovery (that is triggered + # by ping) give one warning. + with warnings.catch_warnings(record=True) as ws: + con.ping() + self.assertEqual(len(ws), 1) + self.assertIs(ws[0].category, ClusterDiscoveryWarning) + + # Verify that only second address was accepted. + self.assertEqual(len(con.strategy.addrs), 1) + self.assertEqual(con.strategy.addrs[0]['host'], self.host_2) + self.assertEqual(con.strategy.addrs[0]['port'], self.port_2) + + con.close() + + def test_04_discovery_add_address(self): + # Create a mesh connection; pass only the first server + # address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=self.get_all_nodes_func_name, + connect_now=False) + + # Verify that the strategy has one address that comes from + # the constructor arguments. + self.assertEqual(len(con.strategy.addrs), 1) + con.connect() + + # Verify that we work with the first server. + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 1) + + # Verify that the refresh was successful and the strategy + # has 2 addresses. + self.assertEqual(len(con.strategy.addrs), 2) + + con.close() + + def test_05_discovery_delay(self): + # Create a mesh connection, pass only the first server address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=self.get_all_nodes_func_name, + cluster_discovery_delay=1) + + # Verify that the strategy has two addresses come from + # the function right after connecting. + self.assertEqual(len(con.strategy.addrs), 2) + + # Drop addresses list to the initial state. + con.strategy.update([con.strategy.addrs[0], ]) + + # Verify that the discovery will not be performed until + # 'cluster_discovery_delay' seconds will be passed. + con.ping() + self.assertEqual(len(con.strategy.addrs), 1) + + sleep(1.1) + + # Refresh after cluster_discovery_delay. + con.ping() + self.assertEqual(len(con.strategy.addrs), 2) + + con.close() + + def test_06_reconnection(self): + # Create a mesh connection; pass only the first server + # address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=self.get_all_nodes_func_name) + + con.last_nodes_refresh = 0 + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 1) + + # Verify that the last discovery was successful and the + # strategy has 2 addresses. + self.assertEqual(len(con.strategy.addrs), 2) + + self.srv.stop() + + # Verify that we switched to the second server. + with warnings.catch_warnings(): + # Suppress reconnection warnings. + warnings.simplefilter("ignore") + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 2) + + con.close() + + def test_07_discovery_exclude_address(self): + # Define function to get back only second server. + func_name = 'get_second_node' + self.define_cluster_function(func_name, [self.srv2]) + + # Create a mesh connection, pass only the first server address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=func_name) + + # Verify that discovery was successful and the strategy + # has 1 address. + self.assertEqual(len(con.strategy.addrs), 1) + + # Verify that the current server is second one. + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 2) + + con.close() + + def tearDown(self): + self.srv.stop() + self.srv.clean() + + self.srv2.stop() + self.srv2.clean()