|
21 | 21 | import sys |
22 | 22 | import time |
23 | 23 | import traceback |
24 | | -from typing import Dict, Iterable, Optional, Set |
| 24 | +from types import TracebackType |
| 25 | +from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast |
25 | 26 |
|
26 | 27 | import yaml |
27 | 28 | from matrix_common.versionstring import get_distribution_version_string |
28 | 29 |
|
29 | | -from twisted.internet import defer, reactor |
| 30 | +from twisted.internet import defer, reactor as reactor_ |
30 | 31 |
|
31 | 32 | from synapse.config.database import DatabaseConnectionConfig |
32 | 33 | from synapse.config.homeserver import HomeServerConfig |
|
66 | 67 | from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore |
67 | 68 | from synapse.storage.engines import create_engine |
68 | 69 | from synapse.storage.prepare_database import prepare_database |
| 70 | +from synapse.types import ISynapseReactor |
69 | 71 | from synapse.util import Clock |
70 | 72 |
|
| 73 | +# Cast safety: Twisted does some naughty magic which replaces the |
| 74 | +# twisted.internet.reactor module with a Reactor instance at runtime. |
| 75 | +reactor = cast(ISynapseReactor, reactor_) |
71 | 76 | logger = logging.getLogger("synapse_port_db") |
72 | 77 |
|
73 | 78 |
|
|
159 | 164 |
|
160 | 165 | # Error returned by the run function. Used at the top-level part of the script to |
161 | 166 | # handle errors and return codes. |
162 | | -end_error = None # type: Optional[str] |
| 167 | +end_error: Optional[str] = None |
163 | 168 | # The exec_info for the error, if any. If error is defined but not exec_info the script |
164 | 169 | # will show only the error message without the stacktrace, if exec_info is defined but |
165 | 170 | # not the error then the script will show nothing outside of what's printed in the run |
166 | 171 | # function. If both are defined, the script will print both the error and the stacktrace. |
167 | | -end_error_exec_info = None |
| 172 | +end_error_exec_info: Optional[ |
| 173 | + Tuple[Type[BaseException], BaseException, TracebackType] |
| 174 | +] = None |
168 | 175 |
|
169 | 176 |
|
170 | 177 | class Store( |
@@ -236,9 +243,12 @@ def get_instance_name(self): |
236 | 243 | return "master" |
237 | 244 |
|
238 | 245 |
|
239 | | -class Porter(object): |
240 | | - def __init__(self, **kwargs): |
241 | | - self.__dict__.update(kwargs) |
| 246 | +class Porter: |
| 247 | + def __init__(self, sqlite_config, progress, batch_size, hs_config): |
| 248 | + self.sqlite_config = sqlite_config |
| 249 | + self.progress = progress |
| 250 | + self.batch_size = batch_size |
| 251 | + self.hs_config = hs_config |
242 | 252 |
|
243 | 253 | async def setup_table(self, table): |
244 | 254 | if table in APPEND_ONLY_TABLES: |
@@ -323,7 +333,7 @@ def _get_constraints(txn): |
323 | 333 | """ |
324 | 334 | txn.execute(sql) |
325 | 335 |
|
326 | | - results = {} |
| 336 | + results: Dict[str, Set[str]] = {} |
327 | 337 | for table, foreign_table in txn: |
328 | 338 | results.setdefault(table, set()).add(foreign_table) |
329 | 339 | return results |
@@ -540,7 +550,8 @@ def build_db_store( |
540 | 550 | db_conn, allow_outdated_version=allow_outdated_version |
541 | 551 | ) |
542 | 552 | prepare_database(db_conn, engine, config=self.hs_config) |
543 | | - store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) |
| 553 | + # Type safety: ignore that we're using Mock homeservers here. |
| 554 | + store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type] |
544 | 555 | db_conn.commit() |
545 | 556 |
|
546 | 557 | return store |
@@ -724,7 +735,9 @@ def alter_table(txn): |
724 | 735 | except Exception as e: |
725 | 736 | global end_error_exec_info |
726 | 737 | end_error = str(e) |
727 | | - end_error_exec_info = sys.exc_info() |
| 738 | + # Type safety: we're in an exception handler, so the exc_info() tuple |
| 739 | + # will not be (None, None, None). |
| 740 | + end_error_exec_info = sys.exc_info() # type: ignore[assignment] |
728 | 741 | logger.exception("") |
729 | 742 | finally: |
730 | 743 | reactor.stop() |
@@ -1023,7 +1036,7 @@ def __init__(self, stdscr): |
1023 | 1036 | curses.init_pair(1, curses.COLOR_RED, -1) |
1024 | 1037 | curses.init_pair(2, curses.COLOR_GREEN, -1) |
1025 | 1038 |
|
1026 | | - self.last_update = 0 |
| 1039 | + self.last_update = 0.0 |
1027 | 1040 |
|
1028 | 1041 | self.finished = False |
1029 | 1042 |
|
@@ -1082,8 +1095,7 @@ def render(self, force=False): |
1082 | 1095 | left_margin = 5 |
1083 | 1096 | middle_space = 1 |
1084 | 1097 |
|
1085 | | - items = self.tables.items() |
1086 | | - items = sorted(items, key=lambda i: (i[1]["perc"], i[0])) |
| 1098 | + items = sorted(self.tables.items(), key=lambda i: (i[1]["perc"], i[0])) |
1087 | 1099 |
|
1088 | 1100 | for i, (table, data) in enumerate(items): |
1089 | 1101 | if i + 2 >= rows: |
@@ -1179,15 +1191,11 @@ def main(): |
1179 | 1191 |
|
1180 | 1192 | args = parser.parse_args() |
1181 | 1193 |
|
1182 | | - logging_config = { |
1183 | | - "level": logging.DEBUG if args.v else logging.INFO, |
1184 | | - "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", |
1185 | | - } |
1186 | | - |
1187 | | - if args.curses: |
1188 | | - logging_config["filename"] = "port-synapse.log" |
1189 | | - |
1190 | | - logging.basicConfig(**logging_config) |
| 1194 | + logging.basicConfig( |
| 1195 | + level=logging.DEBUG if args.v else logging.INFO, |
| 1196 | + format="%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s", |
| 1197 | + filename="port-synapse.log" if args.curses else None, |
| 1198 | + ) |
1191 | 1199 |
|
1192 | 1200 | sqlite_config = { |
1193 | 1201 | "name": "sqlite3", |
@@ -1218,6 +1226,7 @@ def main(): |
1218 | 1226 | config.parse_config_dict(hs_config, "", "") |
1219 | 1227 |
|
1220 | 1228 | def start(stdscr=None): |
| 1229 | + progress: Progress |
1221 | 1230 | if stdscr: |
1222 | 1231 | progress = CursesProgress(stdscr) |
1223 | 1232 | else: |
|
0 commit comments