From 1c9e49a27057023ab5a1fd7e7f107d73938b34d9 Mon Sep 17 00:00:00 2001 From: starcat37 Date: Thu, 31 Aug 2023 20:56:44 +0900 Subject: [PATCH 1/9] Refactor: add typing to binlogstream.py --- pymysqlreplication/binlogstream.py | 251 +++++++++++++++-------------- 1 file changed, 126 insertions(+), 125 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index c153fcda..de3c2107 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -5,7 +5,7 @@ import pymysql from pymysql.constants.COMMAND import COM_BINLOG_DUMP, COM_REGISTER_SLAVE -from pymysql.cursors import DictCursor +from pymysql.cursors import DictCursor, Cursor from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( @@ -20,6 +20,8 @@ from .packet import BinLogPacketWrapper from .row_event import ( UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) +from typing import ByteString, Union, Optional, List, Tuple, Dict, Any, Iterator, FrozenSet, Type +from pymysql.connections import Connection try: from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID @@ -37,25 +39,24 @@ class ReportSlave(object): """Represent the values that you may report when connecting as a slave to a master. SHOW SLAVE HOSTS related""" - hostname = '' - username = '' - password = '' - port = 0 - - def __init__(self, value): + def __init__(self, value: Union[str, tuple[str, str, str, int]]) -> None: """ Attributes: value: string or tuple if string, then it will be used hostname if tuple it will be used as (hostname, user, password, port) """ + self.hostname: str = '' + self.username: str = '' + self.password: str = '' + self.port: int = 0 if isinstance(value, (tuple, list)): try: - self.hostname = value[0] - self.username = value[1] - self.password = value[2] - self.port = int(value[3]) + self.hostname: str = value[0] + self.username: str = value[1] + self.password: str = value[2] + self.port: int = int(value[3]) except IndexError: pass elif isinstance(value, dict): @@ -65,13 +66,13 @@ def __init__(self, value): except KeyError: pass else: - self.hostname = value + self.hostname: Union[str, tuple] = value - def __repr__(self): + def __repr__(self) -> str: return '' % \ (self.hostname, self.username, self.password, self.port) - def encoded(self, server_id, master_id=0): + def encoded(self, server_id: int, master_id: int = 0) -> ByteString: """ server_id: the slave server-id master_id: usually 0. Appears as "master id" in SHOW SLAVE HOSTS @@ -90,23 +91,23 @@ def encoded(self, server_id, master_id=0): # 4 replication rank # 4 master-id - lhostname = len(self.hostname.encode()) - lusername = len(self.username.encode()) - lpassword = len(self.password.encode()) + lhostname: int = len(self.hostname.encode()) + lusername: int = len(self.username.encode()) + lpassword: int = len(self.password.encode()) - packet_len = (1 + # command - 4 + # server-id - 1 + # hostname length - lhostname + - 1 + # username length - lusername + - 1 + # password length - lpassword + - 2 + # slave mysql port - 4 + # replication rank - 4) # master-id + packet_len: int = (1 + # command + 4 + # server-id + 1 + # hostname length + lhostname + + 1 + # username length + lusername + + 1 + # password length + lpassword + + 2 + # slave mysql port + 4 + # replication rank + 4) # master-id - MAX_STRING_LEN = 257 # one byte for length + 256 chars + MAX_STRING_LEN: int = 257 # one byte for length + 256 chars return (struct.pack(' None: """ Attributes: ctl_connection_settings: Connection settings for cluster holding @@ -182,89 +183,89 @@ def __init__(self, connection_settings, server_id, to point to Mariadb specific GTID. annotate_rows_event: Parameter value to enable annotate rows event in mariadb, used with 'is_mariadb' - ignore_decode_errors: If true, any decode errors encountered + ignore_decode_errors: If true, any decode errors encountered when reading column data will be ignored. """ - self.__connection_settings = connection_settings + self.__connection_settings: Dict = connection_settings self.__connection_settings.setdefault("charset", "utf8") - self.__connected_stream = False - self.__connected_ctl = False - self.__resume_stream = resume_stream - self.__blocking = blocking - self._ctl_connection_settings = ctl_connection_settings + self.__connected_stream: bool = False + self.__connected_ctl: bool = False + self.__resume_stream: bool = resume_stream + self.__blocking: bool = blocking + self._ctl_connection_settings: Dict = ctl_connection_settings if ctl_connection_settings: self._ctl_connection_settings.setdefault("charset", "utf8") - self.__only_tables = only_tables - self.__ignored_tables = ignored_tables - self.__only_schemas = only_schemas - self.__ignored_schemas = ignored_schemas - self.__freeze_schema = freeze_schema - self.__allowed_events = self._allowed_event_list( + self.__only_tables: Optional[List[str]] = only_tables + self.__ignored_tables: Optional[List[str]] = ignored_tables + self.__only_schemas: Optional[List[str]] = only_schemas + self.__ignored_schemas: Optional[List[str]] = ignored_schemas + self.__freeze_schema: bool = freeze_schema + self.__allowed_events: FrozenSet[str] = self._allowed_event_list( only_events, ignored_events, filter_non_implemented_events) - self.__fail_on_table_metadata_unavailable = fail_on_table_metadata_unavailable - self.__ignore_decode_errors = ignore_decode_errors + self.__fail_on_table_metadata_unavailable: bool = fail_on_table_metadata_unavailable + self.__ignore_decode_errors: bool = ignore_decode_errors # We can't filter on packet level TABLE_MAP and rotate event because # we need them for handling other operations - self.__allowed_events_in_packet = frozenset( + self.__allowed_events_in_packet: FrozenSet[str] = frozenset( [TableMapEvent, RotateEvent]).union(self.__allowed_events) - self.__server_id = server_id - self.__use_checksum = False + self.__server_id: int = server_id + self.__use_checksum: bool = False # Store table meta information - self.table_map = {} - self.log_pos = log_pos - self.end_log_pos = end_log_pos - self.log_file = log_file - self.auto_position = auto_position - self.skip_to_timestamp = skip_to_timestamp - self.is_mariadb = is_mariadb - self.__annotate_rows_event = annotate_rows_event + self.table_map: Dict = {} + self.log_pos: Optional[int] = log_pos + self.end_log_pos: Optional[int] = end_log_pos + self.log_file: Optional[str] = log_file + self.auto_position: Optional[str] = auto_position + self.skip_to_timestamp: Optional[float] = skip_to_timestamp + self.is_mariadb: bool = is_mariadb + self.__annotate_rows_event: bool = annotate_rows_event if end_log_pos: - self.is_past_end_log_pos = False + self.is_past_end_log_pos: bool = False if report_slave: - self.report_slave = ReportSlave(report_slave) - self.slave_uuid = slave_uuid - self.slave_heartbeat = slave_heartbeat + self.report_slave: Optional[ReportSlave] = ReportSlave(report_slave) + self.slave_uuid: Optional[str] = slave_uuid + self.slave_heartbeat: Optional[float] = slave_heartbeat if pymysql_wrapper: - self.pymysql_wrapper = pymysql_wrapper + self.pymysql_wrapper: Optional[Connection] = pymysql_wrapper else: - self.pymysql_wrapper = pymysql.connect - self.mysql_version = (0, 0, 0) + self.pymysql_wrapper: Optional[Union[Connection, Type[Connection]]] = pymysql.connect + self.mysql_version: Tuple = (0, 0, 0) - def close(self): + def close(self) -> None: if self.__connected_stream: self._stream_connection.close() - self.__connected_stream = False + self.__connected_stream: bool = False if self.__connected_ctl: # break reference cycle between stream reader and underlying # mysql connection object self._ctl_connection._get_table_information = None self._ctl_connection.close() - self.__connected_ctl = False + self.__connected_ctl: bool = False - def __connect_to_ctl(self): + def __connect_to_ctl(self) -> None: if not self._ctl_connection_settings: - self._ctl_connection_settings = dict(self.__connection_settings) + self._ctl_connection_settings: Dict[str, Any] = dict(self.__connection_settings) self._ctl_connection_settings["db"] = "information_schema" self._ctl_connection_settings["cursorclass"] = DictCursor self._ctl_connection_settings["autocommit"] = True - self._ctl_connection = self.pymysql_wrapper(**self._ctl_connection_settings) + self._ctl_connection: Connection = self.pymysql_wrapper(**self._ctl_connection_settings) self._ctl_connection._get_table_information = self.__get_table_information - self.__connected_ctl = True + self.__connected_ctl: bool = True - def __checksum_enabled(self): + def __checksum_enabled(self) -> bool: """Return True if binlog-checksum = CRC32. Only for MySQL > 5.6""" - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'") - result = cur.fetchone() + result: Optional[Tuple[str, str]] = cur.fetchone() cur.close() if result is None: @@ -274,11 +275,11 @@ def __checksum_enabled(self): return False return True - def _register_slave(self): + def _register_slave(self) -> None: if not self.report_slave: return - packet = self.report_slave.encoded(self.__server_id) + packet: ByteString = self.report_slave.encoded(self.__server_id) if pymysql.__version__ < LooseVersion("0.6"): self._stream_connection.wfile.write(packet) @@ -289,14 +290,14 @@ def _register_slave(self): self._stream_connection._next_seq_id = 1 self._stream_connection._read_packet() - def __connect_to_stream(self): + def __connect_to_stream(self) -> None: # log_pos (4) -- position in the binlog-file to start the stream with # flags (2) BINLOG_DUMP_NON_BLOCK (0 or 1) # server_id (4) -- server id of this slave # log_file (string.EOF) -- filename of the binlog on the master self._stream_connection = self.pymysql_wrapper(**self.__connection_settings) - self.__use_checksum = self.__checksum_enabled() + self.__use_checksum: bool = self.__checksum_enabled() # If checksum is enabled we need to inform the server about the that # we support it @@ -312,17 +313,17 @@ def __connect_to_stream(self): if self.slave_heartbeat: # 4294967 is documented as the max value for heartbeats - net_timeout = float(self.__connection_settings.get('read_timeout', + net_timeout: float = float(self.__connection_settings.get('read_timeout', 4294967)) # If heartbeat is too low, the connection will disconnect before, # this is also the behavior in mysql - heartbeat = float(min(net_timeout / 2., self.slave_heartbeat)) + heartbeat: float = float(min(net_timeout / 2., self.slave_heartbeat)) if heartbeat > 4294967: heartbeat = 4294967 # master_heartbeat_period is nanoseconds - heartbeat = int(heartbeat * 1000000000) - cur = self._stream_connection.cursor() + heartbeat: int = int(heartbeat * 1000000000) + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_heartbeat_period= %d" % heartbeat) cur.close() @@ -330,7 +331,7 @@ def __connect_to_stream(self): # Mariadb, when it tries to replace GTID events with dummy ones. Given that this library understands GTID # events, setting the capability to 4 circumvents this error. # If the DB is mysql, this won't have any effect so no need to run this in a condition - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @mariadb_slave_capability=4") cur.close() @@ -343,15 +344,15 @@ def __connect_to_stream(self): # only when log_file and log_pos both provided, the position info is # valid, if not, get the current position from master if self.log_file is None or self.log_pos is None: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW MASTER STATUS") - master_status = cur.fetchone() + master_status: Optional[Tuple[str, int, ...]] = cur.fetchone() if master_status is None: raise BinLogNotEnabled() self.log_file, self.log_pos = master_status[:2] cur.close() - prelude = struct.pack(' ByteString: # https://mariadb.com/kb/en/5-slave-registration/ - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() if self.auto_position != None: cur.execute("SET @slave_connect_state='%s'" % self.auto_position) cur.execute("SET @slave_gtid_strict_mode=1") @@ -461,21 +462,21 @@ def __set_mariadb_settings(self): cur.close() # https://mariadb.com/kb/en/com_binlog_dump/ - header_size = ( + header_size: int = ( 4 + # binlog pos 2 + # binlog flags 4 + # slave server_id, 4 # requested binlog file name , set it to empty ) - prelude = struct.pack(' Union[BinLogPacketWrapper, None]: while True: if self.end_log_pos and self.is_past_end_log_pos: return None @@ -524,7 +525,7 @@ def fetchone(self): if not pkt.is_ok_packet(): continue - binlog_event = BinLogPacketWrapper(pkt, self.table_map, + binlog_event: BinLogPacketWrapper = BinLogPacketWrapper(pkt, self.table_map, self._ctl_connection, self.mysql_version, self.__use_checksum, @@ -549,7 +550,7 @@ def fetchone(self): # invalidates all our cached table id to schema mappings. This means we have to load them all # again for each logfile which is potentially wasted effort but we can't really do much better # without being broken in restart case - self.table_map = {} + self.table_map: Dict = {} elif binlog_event.log_pos: self.log_pos = binlog_event.log_pos @@ -599,8 +600,8 @@ def fetchone(self): return binlog_event.event - def _allowed_event_list(self, only_events, ignored_events, - filter_non_implemented_events): + def _allowed_event_list(self, only_events: Optional[List[str]], ignored_events: Optional[List[str]], + filter_non_implemented_events: bool) -> FrozenSet[str]: if only_events is not None: events = set(only_events) else: @@ -638,13 +639,13 @@ def _allowed_event_list(self, only_events, ignored_events, pass return frozenset(events) - def __get_table_information(self, schema, table): + def __get_table_information(self, schema: str, table: str) -> List[Dict[str, Any]]: for i in range(1, 3): try: if not self.__connected_ctl: self.__connect_to_ctl() - cur = self._ctl_connection.cursor() + cur: Cursor = self._ctl_connection.cursor() cur.execute(""" SELECT COLUMN_NAME, COLLATION_NAME, CHARACTER_SET_NAME, @@ -655,7 +656,7 @@ def __get_table_information(self, schema, table): WHERE table_schema = %s AND table_name = %s """, (schema, table)) - result = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) + result: List = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) cur.close() return result @@ -667,5 +668,5 @@ def __get_table_information(self, schema, table): else: raise error - def __iter__(self): + def __iter__(self) -> Iterator[Union[BinLogPacketWrapper, None]]: return iter(self.fetchone, None) From 3b8255763edf2b4d4d410ae105b0b435c1710041 Mon Sep 17 00:00:00 2001 From: starcat37 Date: Thu, 31 Aug 2023 21:07:35 +0900 Subject: [PATCH 2/9] Fix: modify subscript 'tuple' to 'Tuple' --- pymysqlreplication/binlogstream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index de3c2107..19bcdd06 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -39,7 +39,7 @@ class ReportSlave(object): """Represent the values that you may report when connecting as a slave to a master. SHOW SLAVE HOSTS related""" - def __init__(self, value: Union[str, tuple[str, str, str, int]]) -> None: + def __init__(self, value: Union[str, Tuple[str, str, str, int]]) -> None: """ Attributes: value: string or tuple From 12ba2faf6576ec9413e2201e7e4a6269a0a8d4ab Mon Sep 17 00:00:00 2001 From: JeongSeung Mun Date: Thu, 31 Aug 2023 21:27:00 +0900 Subject: [PATCH 3/9] docs: add typing gtid.py * docs: add typing gtid.py & refactor Gtid.__init__ * docs: add typing gtid.py * refactor: update docstring format --------- Co-authored-by: mikaniz --- pymysqlreplication/gtid.py | 124 +++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 55 deletions(-) diff --git a/pymysqlreplication/gtid.py b/pymysqlreplication/gtid.py index 3b2554da..df80aac2 100644 --- a/pymysqlreplication/gtid.py +++ b/pymysqlreplication/gtid.py @@ -5,15 +5,18 @@ import binascii from copy import deepcopy from io import BytesIO +from typing import List, Optional, Tuple, Union, Set -def overlap(i1, i2): + +def overlap(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i1[0] < i2[1] and i1[1] > i2[0] -def contains(i1, i2): +def contains(i1: Tuple[int, int], i2: Tuple[int, int]) -> bool: return i2[0] >= i1[0] and i2[1] <= i1[1] class Gtid(object): - """A mysql GTID is composed of a server-id and a set of right-open + """ + A mysql GTID is composed of a server-id and a set of right-open intervals [a,b), and represent all transactions x that happened on server SID such as @@ -49,7 +52,7 @@ class Gtid(object): Exception: Adding a Gtid with a different SID. """ @staticmethod - def parse_interval(interval): + def parse_interval(interval: str) -> Tuple[int, int]: """ We parse a human-generated string here. So our end value b is incremented to conform to the internal representation format. @@ -65,8 +68,9 @@ def parse_interval(interval): return (a, b+1) @staticmethod - def parse(gtid): - """Parse a GTID from mysql textual format. + def parse(gtid: str) -> Tuple[str, List[Tuple[int, int]]]: + """ + Parse a GTID from mysql textual format. Raises: - ValueError: if GTID format is incorrect. @@ -84,7 +88,7 @@ def parse(gtid): return (sid, intervals_parsed) - def __add_interval(self, itvl): + def __add_interval(self, itvl: Tuple[int, int]) -> None: """ Use the internal representation format and add it to our intervals, merging if required. @@ -92,7 +96,7 @@ def __add_interval(self, itvl): Raises: Exception: if Malformated interval or Overlapping interval """ - new = [] + new: List[Tuple[int, int]] = [] if itvl[0] > itvl[1]: raise Exception('Malformed interval %s' % (itvl,)) @@ -114,11 +118,13 @@ def __add_interval(self, itvl): self.intervals = sorted(new + [itvl]) - def __sub_interval(self, itvl): - """Using the internal representation, remove an interval + def __sub_interval(self, itvl: Tuple[int, int]) -> None: + """ + Using the internal representation, remove an interval - Raises: Exception if itvl malformated""" - new = [] + Raises: Exception if itvl malformated + """ + new: List[Tuple[int, int]] = [] if itvl[0] > itvl[1]: raise Exception('Malformed interval %s' % (itvl,)) @@ -139,8 +145,9 @@ def __sub_interval(self, itvl): self.intervals = new - def __contains__(self, other): - """Test if other is contained within self. + def __contains__(self, other: 'Gtid') -> bool: + """ + Test if other is contained within self. First we compare sid they must be equals. Then we search if intervals from other are contained within @@ -152,10 +159,8 @@ def __contains__(self, other): return all(any(contains(me, them) for me in self.intervals) for them in other.intervals) - def __init__(self, gtid, sid=None, intervals=[]): - if sid: - intervals = intervals - else: + def __init__(self, gtid: str, sid: Optional[str] = None, intervals: Optional[List[Tuple[int, int]]] = None) -> None: + if sid is None: sid, intervals = Gtid.parse(gtid) self.sid = sid @@ -163,11 +168,13 @@ def __init__(self, gtid, sid=None, intervals=[]): for itvl in intervals: self.__add_interval(itvl) - def __add__(self, other): - """Include the transactions of this gtid. + def __add__(self, other: 'Gtid') -> 'Gtid': + """ + Include the transactions of this gtid. Raises: - Exception: if the attempted merge has different SID""" + Exception: if the attempted merge has different SID + """ if self.sid != other.sid: raise Exception('Attempt to merge different SID' '%s != %s' % (self.sid, other.sid)) @@ -179,9 +186,10 @@ def __add__(self, other): return result - def __sub__(self, other): - """Remove intervals. Do not raise, if different SID simply - ignore""" + def __sub__(self, other: 'Gtid') -> 'Gtid': + """ + Remove intervals. Do not raise, if different SID simply ignore + """ result = deepcopy(self) if self.sid != other.sid: return result @@ -191,27 +199,30 @@ def __sub__(self, other): return result - def __str__(self): - """We represent the human value here - a single number - for one transaction, or a closed interval (decrementing b)""" + def __str__(self) -> str: + """ + We represent the human value here - a single number + for one transaction, or a closed interval (decrementing b) + """ return '%s:%s' % (self.sid, ':'.join(('%d-%d' % (x[0], x[1]-1)) if x[0] +1 != x[1] else str(x[0]) for x in self.intervals)) - def __repr__(self): + def __repr__(self) -> str: return '' % self @property - def encoded_length(self): + def encoded_length(self) -> int: return (16 + # sid 8 + # n_intervals 2 * # stop/start 8 * # stop/start mark encoded as int64 len(self.intervals)) - def encode(self): - """Encode a Gtid in binary + def encode(self) -> bytes: + """ + Encode a Gtid in binary Bytes are in **little endian**. Format: @@ -251,8 +262,9 @@ def encode(self): return buffer @classmethod - def decode(cls, payload): - """Decode from binary a Gtid + def decode(cls, payload: BytesIO) -> 'Gtid': + """ + Decode from binary a Gtid :param BytesIO payload to decode """ @@ -281,27 +293,27 @@ def decode(cls, payload): else '%d' % x for x in intervals]))) - def __eq__(self, other): + def __eq__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return False return self.intervals == other.intervals - def __lt__(self, other): + def __lt__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid < other.sid return self.intervals < other.intervals - def __le__(self, other): + def __le__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid <= other.sid return self.intervals <= other.intervals - def __gt__(self, other): + def __gt__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid > other.sid return self.intervals > other.intervals - def __ge__(self, other): + def __ge__(self, other: 'Gtid') -> bool: if other.sid != self.sid: return self.sid >= other.sid return self.intervals >= other.intervals @@ -309,7 +321,7 @@ def __ge__(self, other): class GtidSet(object): """Represents a set of Gtid""" - def __init__(self, gtid_set): + def __init__(self, gtid_set: Optional[Union[None, str, Set[Gtid], List[Gtid], Gtid]] = None) -> None: """ Construct a GtidSet initial state depends of the nature of `gtid_set` param. @@ -325,21 +337,21 @@ def __init__(self, gtid_set): - Exception: if Gtid interval are either malformated or overlapping """ - def _to_gtid(element): + def _to_gtid(element: str) -> Gtid: if isinstance(element, Gtid): return element return Gtid(element.strip(' \n')) if not gtid_set: - self.gtids = [] + self.gtids: List[Gtid] = [] elif isinstance(gtid_set, (list, set)): - self.gtids = [_to_gtid(x) for x in gtid_set] + self.gtids: List[Gtid] = [_to_gtid(x) for x in gtid_set] else: - self.gtids = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')] + self.gtids: List[Gtid] = [Gtid(x.strip(' \n')) for x in gtid_set.split(',')] - def merge_gtid(self, gtid): + def merge_gtid(self, gtid: Gtid) -> None: """Insert a Gtid in current GtidSet.""" - new_gtids = [] + new_gtids: List[Gtid] = [] for existing in self.gtids: if existing.sid == gtid.sid: new_gtids.append(existing + gtid) @@ -349,7 +361,7 @@ def merge_gtid(self, gtid): new_gtids.append(gtid) self.gtids = new_gtids - def __contains__(self, other): + def __contains__(self, other: Union['GtidSet', Gtid]) -> bool: """ Test if self contains other, could be a GtidSet or a Gtid. @@ -363,7 +375,7 @@ def __contains__(self, other): return any(other in x for x in self.gtids) raise NotImplementedError - def __add__(self, other): + def __add__(self, other: Union['GtidSet', Gtid]) -> 'GtidSet': """ Merge current instance with an other GtidSet or with a Gtid alone. @@ -384,22 +396,23 @@ def __add__(self, other): raise NotImplementedError - def __str__(self): + def __str__(self) -> str: """ Returns a comma separated string of gtids. """ return ','.join(str(x) for x in self.gtids) - def __repr__(self): + def __repr__(self) -> str: return '' % self.gtids @property - def encoded_length(self): + def encoded_length(self) -> int: return (8 + # n_sids sum(x.encoded_length for x in self.gtids)) - def encoded(self): - """Encode a GtidSet in binary + def encoded(self) -> bytes: + """ + Encode a GtidSet in binary Bytes are in **little endian**. - `n_sid`: u64 is the number of Gtid to read @@ -421,8 +434,9 @@ def encoded(self): encode = encoded @classmethod - def decode(cls, payload): - """Decode a GtidSet from binary. + def decode(cls, payload: BytesIO) -> 'GtidSet': + """ + Decode a GtidSet from binary. :param BytesIO payload to decode """ @@ -432,5 +446,5 @@ def decode(cls, payload): return cls([Gtid.decode(payload) for _ in range(0, n_sid)]) - def __eq__(self, other): + def __eq__(self, other: 'GtidSet') -> bool: return self.gtids == other.gtids From cb129a5bcc654c83fb73c07ec3cd5315c7fd01e4 Mon Sep 17 00:00:00 2001 From: Suin Kim Date: Thu, 31 Aug 2023 22:04:29 +0900 Subject: [PATCH 4/9] Refactor: add typing in exceptions.py --- pymysqlreplication/exceptions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymysqlreplication/exceptions.py b/pymysqlreplication/exceptions.py index 434d8d76..d233a6a6 100644 --- a/pymysqlreplication/exceptions.py +++ b/pymysqlreplication/exceptions.py @@ -1,19 +1,19 @@ class TableMetadataUnavailableError(Exception): - def __init__(self, table): + def __init__(self, table: str) -> None: Exception.__init__(self,"Unable to find metadata for table {0}".format(table)) class BinLogNotEnabled(Exception): - def __init__(self): + def __init__(self) -> None: Exception.__init__(self, "MySQL binary logging is not enabled.") class StatusVariableMismatch(Exception): - def __init__(self): - Exception.__init__(self, " ".join( + def __init__(self) -> None: + Exception.__init__(self, " ".join([ "Unknown status variable in query event." , "Possible parse failure in preceding fields" , "or outdated constants.STATUS_VAR_KEY" , "Refer to MySQL documentation/source code" , "or create an issue on GitHub" - )) + ])) From 9bcf867dc9518ee5e048a6518f6ff33d02cc8aae Mon Sep 17 00:00:00 2001 From: starcat37 Date: Fri, 1 Sep 2023 16:13:16 +0900 Subject: [PATCH 5/9] docs: improve the docstrings and add missing typing --- pymysqlreplication/binlogstream.py | 93 +++++++++++++++--------------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index 19bcdd06..4cefcf74 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -5,7 +5,8 @@ import pymysql from pymysql.constants.COMMAND import COM_BINLOG_DUMP, COM_REGISTER_SLAVE -from pymysql.cursors import DictCursor, Cursor +from pymysql.cursors import Cursor, DictCursor +from pymysql.connections import Connection, MysqlPacket from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( @@ -21,7 +22,6 @@ from .row_event import ( UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) from typing import ByteString, Union, Optional, List, Tuple, Dict, Any, Iterator, FrozenSet, Type -from pymysql.connections import Connection try: from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID @@ -36,8 +36,10 @@ class ReportSlave(object): - """Represent the values that you may report when connecting as a slave - to a master. SHOW SLAVE HOSTS related""" + """ + Represent the values that you may report + when connecting as a slave to a master. SHOW SLAVE HOSTS related. + """ def __init__(self, value: Union[str, Tuple[str, str, str, int]]) -> None: """ @@ -74,9 +76,9 @@ def __repr__(self) -> str: def encoded(self, server_id: int, master_id: int = 0) -> ByteString: """ - server_id: the slave server-id - master_id: usually 0. Appears as "master id" in SHOW SLAVE HOSTS - on the master. Unknown what else it impacts. + :ivar server_id: int - the slave server-id + :ivar master_id: int - usually 0. Appears as "master id" in SHOW SLAVE HOSTS on the master. + Unknown what else it impacts. """ # 1 [15] COM_REGISTER_SLAVE @@ -124,9 +126,10 @@ def encoded(self, server_id: int, master_id: int = 0) -> ByteString: class BinLogStreamReader(object): - """Connect to replication stream and read event """ - report_slave: Optional[ReportSlave] = None + Connect to replication stream and read event + """ + report_slave: Optional[Union[str, Tuple[str, str, str, int]]] = None def __init__(self, connection_settings: Dict, server_id: int, ctl_connection_settings: Optional[Dict] = None, resume_stream: bool = False, @@ -137,7 +140,7 @@ def __init__(self, connection_settings: Dict, server_id: int, only_tables: Optional[List[str]] = None, ignored_tables: Optional[List[str]] = None, only_schemas: Optional[List[str]] = None, ignored_schemas: Optional[List[str]] = None, freeze_schema: bool = False, skip_to_timestamp: Optional[float] = None, - report_slave: Optional[ReportSlave] = None, slave_uuid: Optional[str] = None, + report_slave: Optional[Union[str, Tuple[str, str, str, int]]] = None, slave_uuid: Optional[str] = None, pymysql_wrapper: Optional[Connection] = None, fail_on_table_metadata_unavailable: bool = False, slave_heartbeat: Optional[float] = None, @@ -146,44 +149,42 @@ def __init__(self, connection_settings: Dict, server_id: int, ignore_decode_errors: bool = False) -> None: """ Attributes: - ctl_connection_settings: Connection settings for cluster holding + ctl_connection_settings[Dict]: Connection settings for cluster holding schema information - resume_stream: Start for event from position or the latest event of + resume_stream[bool]: Start for event from position or the latest event of binlog or from older available event - blocking: When master has finished reading/sending binlog it will + blocking[bool]: When master has finished reading/sending binlog it will send EOF instead of blocking connection. - only_events: Array of allowed events - ignored_events: Array of ignored events - log_file: Set replication start log file - log_pos: Set replication start log pos (resume_stream should be + only_events[List[str]]: Array of allowed events + ignored_events[List[str]]: Array of ignored events + log_file[str]: Set replication start log file + log_pos[int]: Set replication start log pos (resume_stream should be true) - end_log_pos: Set replication end log pos - auto_position: Use master_auto_position gtid to set position - only_tables: An array with the tables you want to watch (only works + end_log_pos[int]: Set replication end log pos + auto_position[str]: Use master_auto_position gtid to set position + only_tables[List[str]]: An array with the tables you want to watch (only works in binlog_format ROW) - ignored_tables: An array with the tables you want to skip - only_schemas: An array with the schemas you want to watch - ignored_schemas: An array with the schemas you want to skip - freeze_schema: If true do not support ALTER TABLE. It's faster. - skip_to_timestamp: Ignore all events until reaching specified - timestamp. - report_slave: Report slave in SHOW SLAVE HOSTS. - slave_uuid: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or + ignored_tables[List[str]]: An array with the tables you want to skip + only_schemas[List[str]]: An array with the schemas you want to watch + ignored_schemas[List[str]]: An array with the schemas you want to skip + freeze_schema[bool]: If true do not support ALTER TABLE. It's faster. + skip_to_timestamp[float]: Ignore all events until reaching specified timestamp. + report_slave[ReportSlave]: Report slave in SHOW SLAVE HOSTS. + slave_uuid[str]: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or SHOW REPLICAS(MySQL 8.0.22+) depends on your MySQL version. - fail_on_table_metadata_unavailable: Should raise exception if we - can't get table information on - row_events - slave_heartbeat: (seconds) Should master actively send heartbeat on + fail_on_table_metadata_unavailable[bool]: Should raise exception if we + can't get table information on row_events + slave_heartbeat[float]: (seconds) Should master actively send heartbeat on connection. This also reduces traffic in GTID replication on replication resumption (in case many event to skip in binlog). See MASTER_HEARTBEAT_PERIOD in mysql documentation for semantics - is_mariadb: Flag to indicate it's a MariaDB server, used with auto_position + is_mariadb[bool]: Flag to indicate it's a MariaDB server, used with auto_position to point to Mariadb specific GTID. - annotate_rows_event: Parameter value to enable annotate rows event in mariadb, + annotate_rows_event[bool]: Parameter value to enable annotate rows event in mariadb, used with 'is_mariadb' - ignore_decode_errors: If true, any decode errors encountered + ignore_decode_errors[bool]: If true, any decode errors encountered when reading column data will be ignored. """ @@ -230,12 +231,12 @@ def __init__(self, connection_settings: Dict, server_id: int, self.is_past_end_log_pos: bool = False if report_slave: - self.report_slave: Optional[ReportSlave] = ReportSlave(report_slave) + self.report_slave: ReportSlave = ReportSlave(report_slave) self.slave_uuid: Optional[str] = slave_uuid self.slave_heartbeat: Optional[float] = slave_heartbeat if pymysql_wrapper: - self.pymysql_wrapper: Optional[Connection] = pymysql_wrapper + self.pymysql_wrapper: Connection = pymysql_wrapper else: self.pymysql_wrapper: Optional[Union[Connection, Type[Connection]]] = pymysql.connect self.mysql_version: Tuple = (0, 0, 0) @@ -262,7 +263,9 @@ def __connect_to_ctl(self) -> None: self.__connected_ctl: bool = True def __checksum_enabled(self) -> bool: - """Return True if binlog-checksum = CRC32. Only for MySQL > 5.6""" + """ + Return True if binlog-checksum = CRC32. Only for MySQL > 5.6 + """ cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'") result: Optional[Tuple[str, str]] = cur.fetchone() @@ -295,19 +298,19 @@ def __connect_to_stream(self) -> None: # flags (2) BINLOG_DUMP_NON_BLOCK (0 or 1) # server_id (4) -- server id of this slave # log_file (string.EOF) -- filename of the binlog on the master - self._stream_connection = self.pymysql_wrapper(**self.__connection_settings) + self._stream_connection: Connection = self.pymysql_wrapper(**self.__connection_settings) self.__use_checksum: bool = self.__checksum_enabled() # If checksum is enabled we need to inform the server about the that # we support it if self.__use_checksum: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_binlog_checksum= @@global.binlog_checksum") cur.close() if self.slave_uuid: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @slave_uuid = %s, @replica_uuid = %s", (self.slave_uuid, self.slave_uuid)) cur.close() @@ -339,14 +342,14 @@ def __connect_to_stream(self) -> None: if not self.auto_position: if self.is_mariadb: - prelude = self.__set_mariadb_settings() + prelude: ByteString = self.__set_mariadb_settings() else: # only when log_file and log_pos both provided, the position info is # valid, if not, get the current position from master if self.log_file is None or self.log_pos is None: cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW MASTER STATUS") - master_status: Optional[Tuple[str, int, ...]] = cur.fetchone() + master_status: Optional[Tuple[str, int, Any]] = cur.fetchone() if master_status is None: raise BinLogNotEnabled() self.log_file, self.log_pos = master_status[:2] @@ -507,9 +510,9 @@ def fetchone(self) -> Union[BinLogPacketWrapper, None]: try: if pymysql.__version__ < LooseVersion("0.6"): - pkt = self._stream_connection.read_packet() + pkt: MysqlPacket = self._stream_connection.read_packet() else: - pkt = self._stream_connection._read_packet() + pkt: MysqlPacket = self._stream_connection._read_packet() except pymysql.OperationalError as error: code, message = error.args if code in MYSQL_EXPECTED_ERROR_CODES: From 7cc0e713b05fd82def255cc1b0220228ee2f3d27 Mon Sep 17 00:00:00 2001 From: starcat37 Date: Fri, 1 Sep 2023 18:16:45 +0900 Subject: [PATCH 6/9] Fix: add typing about dict and improve docstring --- pymysqlreplication/binlogstream.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index 4cefcf74..f434d70b 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -41,12 +41,13 @@ class ReportSlave(object): when connecting as a slave to a master. SHOW SLAVE HOSTS related. """ - def __init__(self, value: Union[str, Tuple[str, str, str, int]]) -> None: + def __init__(self, value: Union[str, Tuple[str, str, str, int], Dict[str, Union[str, int]]]) -> None: """ Attributes: - value: string or tuple + value: string, tuple or dict if string, then it will be used hostname if tuple it will be used as (hostname, user, password, port) + if dict, keys 'hostname', 'username', 'password', 'port' will be used. """ self.hostname: str = '' self.username: str = '' @@ -342,7 +343,7 @@ def __connect_to_stream(self) -> None: if not self.auto_position: if self.is_mariadb: - prelude: ByteString = self.__set_mariadb_settings() + prelude = self.__set_mariadb_settings() else: # only when log_file and log_pos both provided, the position info is # valid, if not, get the current position from master @@ -355,7 +356,7 @@ def __connect_to_stream(self) -> None: self.log_file, self.log_pos = master_status[:2] cur.close() - prelude: ByteString = struct.pack(' None: self._stream_connection._next_seq_id = 1 self.__connected_stream: bool = True - def __set_mariadb_settings(self) -> ByteString: + def __set_mariadb_settings(self) -> bytes: # https://mariadb.com/kb/en/5-slave-registration/ cur: Cursor = self._stream_connection.cursor() if self.auto_position != None: @@ -472,7 +473,7 @@ def __set_mariadb_settings(self) -> ByteString: 4 # requested binlog file name , set it to empty ) - prelude: ByteString = struct.pack(' Date: Fri, 1 Sep 2023 18:21:19 +0900 Subject: [PATCH 7/9] Fix: modify typing from ByteString to bytes --- pymysqlreplication/binlogstream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index f434d70b..a2b7cf92 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -283,7 +283,7 @@ def _register_slave(self) -> None: if not self.report_slave: return - packet: ByteString = self.report_slave.encoded(self.__server_id) + packet: bytes = self.report_slave.encoded(self.__server_id) if pymysql.__version__ < LooseVersion("0.6"): self._stream_connection.wfile.write(packet) @@ -374,7 +374,7 @@ def __connect_to_stream(self) -> None: prelude += self.log_file.encode() else: if self.is_mariadb: - prelude: ByteString = self.__set_mariadb_settings() + prelude = self.__set_mariadb_settings() else: # Format for mysql packet master_auto_position # From 028aba17e6430771ea5b2fb633b7d306015da059 Mon Sep 17 00:00:00 2001 From: Suin Kim Date: Sat, 2 Sep 2023 13:29:30 +0900 Subject: [PATCH 8/9] Refactor: add typing to binlogstream.py --- pymysqlreplication/binlogstream.py | 327 +++++++++++++++-------------- 1 file changed, 166 insertions(+), 161 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index c153fcda..a2b7cf92 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -5,7 +5,8 @@ import pymysql from pymysql.constants.COMMAND import COM_BINLOG_DUMP, COM_REGISTER_SLAVE -from pymysql.cursors import DictCursor +from pymysql.cursors import Cursor, DictCursor +from pymysql.connections import Connection, MysqlPacket from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( @@ -20,6 +21,7 @@ from .packet import BinLogPacketWrapper from .row_event import ( UpdateRowsEvent, WriteRowsEvent, DeleteRowsEvent, TableMapEvent) +from typing import ByteString, Union, Optional, List, Tuple, Dict, Any, Iterator, FrozenSet, Type try: from pymysql.constants.COMMAND import COM_BINLOG_DUMP_GTID @@ -34,28 +36,30 @@ class ReportSlave(object): - """Represent the values that you may report when connecting as a slave - to a master. SHOW SLAVE HOSTS related""" - - hostname = '' - username = '' - password = '' - port = 0 + """ + Represent the values that you may report + when connecting as a slave to a master. SHOW SLAVE HOSTS related. + """ - def __init__(self, value): + def __init__(self, value: Union[str, Tuple[str, str, str, int], Dict[str, Union[str, int]]]) -> None: """ Attributes: - value: string or tuple + value: string, tuple or dict if string, then it will be used hostname if tuple it will be used as (hostname, user, password, port) + if dict, keys 'hostname', 'username', 'password', 'port' will be used. """ + self.hostname: str = '' + self.username: str = '' + self.password: str = '' + self.port: int = 0 if isinstance(value, (tuple, list)): try: - self.hostname = value[0] - self.username = value[1] - self.password = value[2] - self.port = int(value[3]) + self.hostname: str = value[0] + self.username: str = value[1] + self.password: str = value[2] + self.port: int = int(value[3]) except IndexError: pass elif isinstance(value, dict): @@ -65,17 +69,17 @@ def __init__(self, value): except KeyError: pass else: - self.hostname = value + self.hostname: Union[str, tuple] = value - def __repr__(self): + def __repr__(self) -> str: return '' % \ (self.hostname, self.username, self.password, self.port) - def encoded(self, server_id, master_id=0): + def encoded(self, server_id: int, master_id: int = 0) -> ByteString: """ - server_id: the slave server-id - master_id: usually 0. Appears as "master id" in SHOW SLAVE HOSTS - on the master. Unknown what else it impacts. + :ivar server_id: int - the slave server-id + :ivar master_id: int - usually 0. Appears as "master id" in SHOW SLAVE HOSTS on the master. + Unknown what else it impacts. """ # 1 [15] COM_REGISTER_SLAVE @@ -90,23 +94,23 @@ def encoded(self, server_id, master_id=0): # 4 replication rank # 4 master-id - lhostname = len(self.hostname.encode()) - lusername = len(self.username.encode()) - lpassword = len(self.password.encode()) + lhostname: int = len(self.hostname.encode()) + lusername: int = len(self.username.encode()) + lpassword: int = len(self.password.encode()) - packet_len = (1 + # command - 4 + # server-id - 1 + # hostname length - lhostname + - 1 + # username length - lusername + - 1 + # password length - lpassword + - 2 + # slave mysql port - 4 + # replication rank - 4) # master-id + packet_len: int = (1 + # command + 4 + # server-id + 1 + # hostname length + lhostname + + 1 + # username length + lusername + + 1 + # password length + lpassword + + 2 + # slave mysql port + 4 + # replication rank + 4) # master-id - MAX_STRING_LEN = 257 # one byte for length + 256 chars + MAX_STRING_LEN: int = 257 # one byte for length + 256 chars return (struct.pack(' None: """ Attributes: - ctl_connection_settings: Connection settings for cluster holding + ctl_connection_settings[Dict]: Connection settings for cluster holding schema information - resume_stream: Start for event from position or the latest event of + resume_stream[bool]: Start for event from position or the latest event of binlog or from older available event - blocking: When master has finished reading/sending binlog it will + blocking[bool]: When master has finished reading/sending binlog it will send EOF instead of blocking connection. - only_events: Array of allowed events - ignored_events: Array of ignored events - log_file: Set replication start log file - log_pos: Set replication start log pos (resume_stream should be + only_events[List[str]]: Array of allowed events + ignored_events[List[str]]: Array of ignored events + log_file[str]: Set replication start log file + log_pos[int]: Set replication start log pos (resume_stream should be true) - end_log_pos: Set replication end log pos - auto_position: Use master_auto_position gtid to set position - only_tables: An array with the tables you want to watch (only works + end_log_pos[int]: Set replication end log pos + auto_position[str]: Use master_auto_position gtid to set position + only_tables[List[str]]: An array with the tables you want to watch (only works in binlog_format ROW) - ignored_tables: An array with the tables you want to skip - only_schemas: An array with the schemas you want to watch - ignored_schemas: An array with the schemas you want to skip - freeze_schema: If true do not support ALTER TABLE. It's faster. - skip_to_timestamp: Ignore all events until reaching specified - timestamp. - report_slave: Report slave in SHOW SLAVE HOSTS. - slave_uuid: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or + ignored_tables[List[str]]: An array with the tables you want to skip + only_schemas[List[str]]: An array with the schemas you want to watch + ignored_schemas[List[str]]: An array with the schemas you want to skip + freeze_schema[bool]: If true do not support ALTER TABLE. It's faster. + skip_to_timestamp[float]: Ignore all events until reaching specified timestamp. + report_slave[ReportSlave]: Report slave in SHOW SLAVE HOSTS. + slave_uuid[str]: Report slave_uuid or replica_uuid in SHOW SLAVE HOSTS(MySQL 8.0.21-) or SHOW REPLICAS(MySQL 8.0.22+) depends on your MySQL version. - fail_on_table_metadata_unavailable: Should raise exception if we - can't get table information on - row_events - slave_heartbeat: (seconds) Should master actively send heartbeat on + fail_on_table_metadata_unavailable[bool]: Should raise exception if we + can't get table information on row_events + slave_heartbeat[float]: (seconds) Should master actively send heartbeat on connection. This also reduces traffic in GTID replication on replication resumption (in case many event to skip in binlog). See MASTER_HEARTBEAT_PERIOD in mysql documentation for semantics - is_mariadb: Flag to indicate it's a MariaDB server, used with auto_position + is_mariadb[bool]: Flag to indicate it's a MariaDB server, used with auto_position to point to Mariadb specific GTID. - annotate_rows_event: Parameter value to enable annotate rows event in mariadb, + annotate_rows_event[bool]: Parameter value to enable annotate rows event in mariadb, used with 'is_mariadb' - ignore_decode_errors: If true, any decode errors encountered + ignore_decode_errors[bool]: If true, any decode errors encountered when reading column data will be ignored. """ - self.__connection_settings = connection_settings + self.__connection_settings: Dict = connection_settings self.__connection_settings.setdefault("charset", "utf8") - self.__connected_stream = False - self.__connected_ctl = False - self.__resume_stream = resume_stream - self.__blocking = blocking - self._ctl_connection_settings = ctl_connection_settings + self.__connected_stream: bool = False + self.__connected_ctl: bool = False + self.__resume_stream: bool = resume_stream + self.__blocking: bool = blocking + self._ctl_connection_settings: Dict = ctl_connection_settings if ctl_connection_settings: self._ctl_connection_settings.setdefault("charset", "utf8") - self.__only_tables = only_tables - self.__ignored_tables = ignored_tables - self.__only_schemas = only_schemas - self.__ignored_schemas = ignored_schemas - self.__freeze_schema = freeze_schema - self.__allowed_events = self._allowed_event_list( + self.__only_tables: Optional[List[str]] = only_tables + self.__ignored_tables: Optional[List[str]] = ignored_tables + self.__only_schemas: Optional[List[str]] = only_schemas + self.__ignored_schemas: Optional[List[str]] = ignored_schemas + self.__freeze_schema: bool = freeze_schema + self.__allowed_events: FrozenSet[str] = self._allowed_event_list( only_events, ignored_events, filter_non_implemented_events) - self.__fail_on_table_metadata_unavailable = fail_on_table_metadata_unavailable - self.__ignore_decode_errors = ignore_decode_errors + self.__fail_on_table_metadata_unavailable: bool = fail_on_table_metadata_unavailable + self.__ignore_decode_errors: bool = ignore_decode_errors # We can't filter on packet level TABLE_MAP and rotate event because # we need them for handling other operations - self.__allowed_events_in_packet = frozenset( + self.__allowed_events_in_packet: FrozenSet[str] = frozenset( [TableMapEvent, RotateEvent]).union(self.__allowed_events) - self.__server_id = server_id - self.__use_checksum = False + self.__server_id: int = server_id + self.__use_checksum: bool = False # Store table meta information - self.table_map = {} - self.log_pos = log_pos - self.end_log_pos = end_log_pos - self.log_file = log_file - self.auto_position = auto_position - self.skip_to_timestamp = skip_to_timestamp - self.is_mariadb = is_mariadb - self.__annotate_rows_event = annotate_rows_event + self.table_map: Dict = {} + self.log_pos: Optional[int] = log_pos + self.end_log_pos: Optional[int] = end_log_pos + self.log_file: Optional[str] = log_file + self.auto_position: Optional[str] = auto_position + self.skip_to_timestamp: Optional[float] = skip_to_timestamp + self.is_mariadb: bool = is_mariadb + self.__annotate_rows_event: bool = annotate_rows_event if end_log_pos: - self.is_past_end_log_pos = False + self.is_past_end_log_pos: bool = False if report_slave: - self.report_slave = ReportSlave(report_slave) - self.slave_uuid = slave_uuid - self.slave_heartbeat = slave_heartbeat + self.report_slave: ReportSlave = ReportSlave(report_slave) + self.slave_uuid: Optional[str] = slave_uuid + self.slave_heartbeat: Optional[float] = slave_heartbeat if pymysql_wrapper: - self.pymysql_wrapper = pymysql_wrapper + self.pymysql_wrapper: Connection = pymysql_wrapper else: - self.pymysql_wrapper = pymysql.connect - self.mysql_version = (0, 0, 0) + self.pymysql_wrapper: Optional[Union[Connection, Type[Connection]]] = pymysql.connect + self.mysql_version: Tuple = (0, 0, 0) - def close(self): + def close(self) -> None: if self.__connected_stream: self._stream_connection.close() - self.__connected_stream = False + self.__connected_stream: bool = False if self.__connected_ctl: # break reference cycle between stream reader and underlying # mysql connection object self._ctl_connection._get_table_information = None self._ctl_connection.close() - self.__connected_ctl = False + self.__connected_ctl: bool = False - def __connect_to_ctl(self): + def __connect_to_ctl(self) -> None: if not self._ctl_connection_settings: - self._ctl_connection_settings = dict(self.__connection_settings) + self._ctl_connection_settings: Dict[str, Any] = dict(self.__connection_settings) self._ctl_connection_settings["db"] = "information_schema" self._ctl_connection_settings["cursorclass"] = DictCursor self._ctl_connection_settings["autocommit"] = True - self._ctl_connection = self.pymysql_wrapper(**self._ctl_connection_settings) + self._ctl_connection: Connection = self.pymysql_wrapper(**self._ctl_connection_settings) self._ctl_connection._get_table_information = self.__get_table_information - self.__connected_ctl = True + self.__connected_ctl: bool = True - def __checksum_enabled(self): - """Return True if binlog-checksum = CRC32. Only for MySQL > 5.6""" - cur = self._stream_connection.cursor() + def __checksum_enabled(self) -> bool: + """ + Return True if binlog-checksum = CRC32. Only for MySQL > 5.6 + """ + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW GLOBAL VARIABLES LIKE 'BINLOG_CHECKSUM'") - result = cur.fetchone() + result: Optional[Tuple[str, str]] = cur.fetchone() cur.close() if result is None: @@ -274,11 +279,11 @@ def __checksum_enabled(self): return False return True - def _register_slave(self): + def _register_slave(self) -> None: if not self.report_slave: return - packet = self.report_slave.encoded(self.__server_id) + packet: bytes = self.report_slave.encoded(self.__server_id) if pymysql.__version__ < LooseVersion("0.6"): self._stream_connection.wfile.write(packet) @@ -289,40 +294,40 @@ def _register_slave(self): self._stream_connection._next_seq_id = 1 self._stream_connection._read_packet() - def __connect_to_stream(self): + def __connect_to_stream(self) -> None: # log_pos (4) -- position in the binlog-file to start the stream with # flags (2) BINLOG_DUMP_NON_BLOCK (0 or 1) # server_id (4) -- server id of this slave # log_file (string.EOF) -- filename of the binlog on the master - self._stream_connection = self.pymysql_wrapper(**self.__connection_settings) + self._stream_connection: Connection = self.pymysql_wrapper(**self.__connection_settings) - self.__use_checksum = self.__checksum_enabled() + self.__use_checksum: bool = self.__checksum_enabled() # If checksum is enabled we need to inform the server about the that # we support it if self.__use_checksum: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_binlog_checksum= @@global.binlog_checksum") cur.close() if self.slave_uuid: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @slave_uuid = %s, @replica_uuid = %s", (self.slave_uuid, self.slave_uuid)) cur.close() if self.slave_heartbeat: # 4294967 is documented as the max value for heartbeats - net_timeout = float(self.__connection_settings.get('read_timeout', + net_timeout: float = float(self.__connection_settings.get('read_timeout', 4294967)) # If heartbeat is too low, the connection will disconnect before, # this is also the behavior in mysql - heartbeat = float(min(net_timeout / 2., self.slave_heartbeat)) + heartbeat: float = float(min(net_timeout / 2., self.slave_heartbeat)) if heartbeat > 4294967: heartbeat = 4294967 # master_heartbeat_period is nanoseconds - heartbeat = int(heartbeat * 1000000000) - cur = self._stream_connection.cursor() + heartbeat: int = int(heartbeat * 1000000000) + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @master_heartbeat_period= %d" % heartbeat) cur.close() @@ -330,7 +335,7 @@ def __connect_to_stream(self): # Mariadb, when it tries to replace GTID events with dummy ones. Given that this library understands GTID # events, setting the capability to 4 circumvents this error. # If the DB is mysql, this won't have any effect so no need to run this in a condition - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SET @mariadb_slave_capability=4") cur.close() @@ -343,15 +348,15 @@ def __connect_to_stream(self): # only when log_file and log_pos both provided, the position info is # valid, if not, get the current position from master if self.log_file is None or self.log_pos is None: - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() cur.execute("SHOW MASTER STATUS") - master_status = cur.fetchone() + master_status: Optional[Tuple[str, int, Any]] = cur.fetchone() if master_status is None: raise BinLogNotEnabled() self.log_file, self.log_pos = master_status[:2] cur.close() - prelude = struct.pack(' bytes: # https://mariadb.com/kb/en/5-slave-registration/ - cur = self._stream_connection.cursor() + cur: Cursor = self._stream_connection.cursor() if self.auto_position != None: cur.execute("SET @slave_connect_state='%s'" % self.auto_position) cur.execute("SET @slave_gtid_strict_mode=1") @@ -461,21 +466,21 @@ def __set_mariadb_settings(self): cur.close() # https://mariadb.com/kb/en/com_binlog_dump/ - header_size = ( + header_size: int = ( 4 + # binlog pos 2 + # binlog flags 4 + # slave server_id, 4 # requested binlog file name , set it to empty ) - prelude = struct.pack(' Union[BinLogPacketWrapper, None]: while True: if self.end_log_pos and self.is_past_end_log_pos: return None @@ -506,9 +511,9 @@ def fetchone(self): try: if pymysql.__version__ < LooseVersion("0.6"): - pkt = self._stream_connection.read_packet() + pkt: MysqlPacket = self._stream_connection.read_packet() else: - pkt = self._stream_connection._read_packet() + pkt: MysqlPacket = self._stream_connection._read_packet() except pymysql.OperationalError as error: code, message = error.args if code in MYSQL_EXPECTED_ERROR_CODES: @@ -524,7 +529,7 @@ def fetchone(self): if not pkt.is_ok_packet(): continue - binlog_event = BinLogPacketWrapper(pkt, self.table_map, + binlog_event: BinLogPacketWrapper = BinLogPacketWrapper(pkt, self.table_map, self._ctl_connection, self.mysql_version, self.__use_checksum, @@ -549,7 +554,7 @@ def fetchone(self): # invalidates all our cached table id to schema mappings. This means we have to load them all # again for each logfile which is potentially wasted effort but we can't really do much better # without being broken in restart case - self.table_map = {} + self.table_map: Dict = {} elif binlog_event.log_pos: self.log_pos = binlog_event.log_pos @@ -599,8 +604,8 @@ def fetchone(self): return binlog_event.event - def _allowed_event_list(self, only_events, ignored_events, - filter_non_implemented_events): + def _allowed_event_list(self, only_events: Optional[List[str]], ignored_events: Optional[List[str]], + filter_non_implemented_events: bool) -> FrozenSet[str]: if only_events is not None: events = set(only_events) else: @@ -638,13 +643,13 @@ def _allowed_event_list(self, only_events, ignored_events, pass return frozenset(events) - def __get_table_information(self, schema, table): + def __get_table_information(self, schema: str, table: str) -> List[Dict[str, Any]]: for i in range(1, 3): try: if not self.__connected_ctl: self.__connect_to_ctl() - cur = self._ctl_connection.cursor() + cur: Cursor = self._ctl_connection.cursor() cur.execute(""" SELECT COLUMN_NAME, COLLATION_NAME, CHARACTER_SET_NAME, @@ -655,7 +660,7 @@ def __get_table_information(self, schema, table): WHERE table_schema = %s AND table_name = %s """, (schema, table)) - result = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) + result: List = sorted(cur.fetchall(), key=lambda x: x['ORDINAL_POSITION']) cur.close() return result @@ -667,5 +672,5 @@ def __get_table_information(self, schema, table): else: raise error - def __iter__(self): + def __iter__(self) -> Iterator[Union[BinLogPacketWrapper, None]]: return iter(self.fetchone, None) From b9b241b97169f5f93e7d1a7f27f900ecfcb4d02d Mon Sep 17 00:00:00 2001 From: starcat37 Date: Sun, 3 Sep 2023 00:57:54 +0900 Subject: [PATCH 9/9] Fix: modify the typing of allowed_events_in_packet --- pymysqlreplication/binlogstream.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pymysqlreplication/binlogstream.py b/pymysqlreplication/binlogstream.py index a2b7cf92..d5becf95 100644 --- a/pymysqlreplication/binlogstream.py +++ b/pymysqlreplication/binlogstream.py @@ -8,9 +8,10 @@ from pymysql.cursors import Cursor, DictCursor from pymysql.connections import Connection, MysqlPacket +from . import event from .constants.BINLOG import TABLE_MAP_EVENT, ROTATE_EVENT, FORMAT_DESCRIPTION_EVENT from .event import ( - QueryEvent, RotateEvent, FormatDescriptionEvent, + BinLogEvent, QueryEvent, RotateEvent, FormatDescriptionEvent, XidEvent, GtidEvent, StopEvent, XAPrepareEvent, BeginLoadQueryEvent, ExecuteLoadQueryEvent, HeartbeatLogEvent, NotImplementedEvent, MariadbGtidEvent, @@ -212,7 +213,7 @@ def __init__(self, connection_settings: Dict, server_id: int, # We can't filter on packet level TABLE_MAP and rotate event because # we need them for handling other operations - self.__allowed_events_in_packet: FrozenSet[str] = frozenset( + self.__allowed_events_in_packet: FrozenSet[Type[BinLogEvent]] = frozenset( [TableMapEvent, RotateEvent]).union(self.__allowed_events) self.__server_id: int = server_id