From 8627451b820f882c42f00fea7d539864404be02c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 08:29:07 -0600 Subject: [PATCH 01/20] PYTHON-3064 Add typings to test package --- mypy.ini | 9 + pymongo/socket_checker.py | 2 +- pymongo/typings.py | 4 +- test/__init__.py | 33 +++- test/auth_aws/test_auth_aws.py | 1 + test/mockupdb/test_cursor_namespace.py | 6 + test/mockupdb/test_getmore_sharded.py | 2 +- test/mockupdb/test_handshake.py | 62 +------ test/mockupdb/test_mixed_version_sharded.py | 4 +- test/mockupdb/test_op_msg.py | 2 + test/mockupdb/test_op_msg_read_preference.py | 5 +- test/mockupdb/test_query_read_pref_sharded.py | 2 +- test/mockupdb/test_reset_and_request_check.py | 12 +- test/mockupdb/test_slave_okay_sharded.py | 2 +- test/performance/perf_test.py | 11 +- test/qcheck.py | 2 +- test/test_auth.py | 6 +- test/test_binary.py | 6 + test/test_bson.py | 83 +++++---- test/test_bulk.py | 162 +++++------------- test/test_change_stream.py | 40 +++++ test/test_client.py | 14 +- test/test_cmap.py | 8 +- test/test_code.py | 5 +- test/test_collation.py | 6 + test/test_collection.py | 63 ++++--- test/test_command_monitoring_legacy.py | 2 + test/test_common.py | 10 +- ...nnections_survive_primary_stepdown_spec.py | 6 +- test/test_crud_v1.py | 6 +- test/test_cursor.py | 29 ++-- test/test_custom_types.py | 59 +++++-- test/test_data_lake.py | 4 +- test/test_database.py | 27 +-- test/test_dbref.py | 7 +- test/test_decimal128.py | 1 + test/test_discovery_and_monitoring.py | 2 + test/test_dns.py | 1 + test/test_encryption.py | 67 +++++--- test/test_examples.py | 14 +- test/test_grid_file.py | 5 +- test/test_gridfs.py | 35 ++-- test/test_gridfs_bucket.py | 6 +- test/test_gridfs_spec.py | 3 + test/test_json_util.py | 4 +- test/test_max_staleness.py | 2 +- test/test_monitor.py | 2 +- test/test_monitoring.py | 14 +- test/test_objectid.py | 4 +- test/test_ocsp_cache.py | 12 +- test/test_raw_bson.py | 3 + test/test_read_concern.py | 1 + test/test_read_preferences.py | 22 ++- test/test_read_write_concern_spec.py | 17 +- test/test_retryable_writes.py | 5 + test/test_sdam_monitoring_spec.py | 5 + test/test_server_selection.py | 5 +- test/test_server_selection_in_window.py | 2 +- test/test_session.py | 25 ++- test/test_son.py | 38 ++-- test/test_srv_polling.py | 16 +- test/test_ssl.py | 34 ++-- test/test_streaming_protocol.py | 2 +- test/test_transactions.py | 18 +- test/test_uri_parser.py | 4 +- test/test_write_concern.py | 2 +- test/unified_format.py | 13 +- test/utils.py | 19 +- test/utils_spec_runner.py | 17 +- 69 files changed, 655 insertions(+), 467 deletions(-) diff --git a/mypy.ini b/mypy.ini index 926bf95745..5772c1355f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,9 @@ warn_unused_configs = true warn_unused_ignores = true warn_redundant_casts = true +[mypy-gevent.*] +ignore_missing_imports = True + [mypy-kerberos.*] ignore_missing_imports = True @@ -29,5 +32,11 @@ ignore_missing_imports = True [mypy-snappy.*] ignore_missing_imports = True +[mypy-test.*] +allow_redefinition = true + [mypy-winkerberos.*] ignore_missing_imports = True + +[mypy-xmlrunner.*] +ignore_missing_imports = True diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index 9eb3d5f084..8b7418725b 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -43,7 +43,7 @@ def __init__(self) -> None: else: self._poller = None - def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool: + def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Union[float, int] = 0) -> bool: """Select for reads or writes with a timeout in seconds (or None). Returns True if the socket is readable/writable, False on timeout. diff --git a/pymongo/typings.py b/pymongo/typings.py index ae5aec3213..e234f59903 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -16,8 +16,10 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar, Union) + if TYPE_CHECKING: from bson.raw_bson import RawBSONDocument + from bson.son import SON from pymongo.collation import Collation @@ -26,4 +28,4 @@ _CollationIn = Union[Mapping[str, Any], "Collation"] _DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"] _Pipeline = List[Mapping[str, Any]] -_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any]) +_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any], "SON") diff --git a/test/__init__.py b/test/__init__.py index ab53b7fdc5..5ed23b7442 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -40,6 +40,7 @@ from contextlib import contextmanager from functools import wraps +from typing import Dict, no_type_check from unittest import SkipTest import pymongo @@ -48,7 +49,9 @@ from bson.son import SON from pymongo import common, message from pymongo.common import partition_node +from pymongo.database import Database from pymongo.hello import HelloCompat +from pymongo.mongo_client import MongoClient from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl from pymongo.uri_parser import parse_uri @@ -86,7 +89,7 @@ os.path.join(CERT_PATH, 'client.pem')) CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem')) -TLS_OPTIONS = dict(tls=True) +TLS_OPTIONS: Dict = dict(tls=True) if CLIENT_PEM: TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM if CA_PEM: @@ -102,13 +105,13 @@ # Remove after PYTHON-2712 from pymongo import pool pool._MOCK_SERVICE_ID = True - res = parse_uri(SINGLE_MONGOS_LB_URI) + res = parse_uri(SINGLE_MONGOS_LB_URI or "") host, port = res['nodelist'][0] db_user = res['username'] or db_user db_pwd = res['password'] or db_pwd elif TEST_SERVERLESS: TEST_LOADBALANCER = True - res = parse_uri(SINGLE_MONGOS_LB_URI) + res = parse_uri(SINGLE_MONGOS_LB_URI or "") host, port = res['nodelist'][0] db_user = res['username'] or db_user db_pwd = res['password'] or db_pwd @@ -184,6 +187,7 @@ def enable(self): def __enter__(self): self.enable() + @no_type_check def disable(self): common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval @@ -224,6 +228,8 @@ def _all_users(db): class ClientContext(object): + client: MongoClient + MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI def __init__(self): @@ -247,9 +253,9 @@ def __init__(self): self.tls = False self.tlsCertificateKeyFile = False self.server_is_resolvable = is_server_resolvable() - self.default_client_options = {} + self.default_client_options: Dict = {} self.sessions_enabled = False - self.client = None + self.client = None # type: ignore self.conn_lock = threading.Lock() self.is_data_lake = False self.load_balancer = TEST_LOADBALANCER @@ -340,6 +346,7 @@ def _init_client(self): try: self.cmd_line = self.client.admin.command('getCmdLineOpts') except pymongo.errors.OperationFailure as e: + assert e.details is not None msg = e.details.get('errmsg', '') if e.code == 13 or 'unauthorized' in msg or 'login' in msg: # Unauthorized. @@ -418,6 +425,7 @@ def _init_client(self): else: self.server_parameters = self.client.admin.command( 'getParameter', '*') + assert self.cmd_line is not None if 'enableTestCommands=1' in self.cmd_line['argv']: self.test_commands_enabled = True elif 'parsed' in self.cmd_line: @@ -436,7 +444,8 @@ def _init_client(self): self.mongoses.append(address) if not self.serverless: # Check for another mongos on the next port. - next_address = address[0], address[1] + 1 + assert address is not None + next_address = address[0], address[1] + 1 mongos_client = self._connect( *next_address, **self.default_client_options) if mongos_client: @@ -479,7 +488,7 @@ def has_secondaries(self): @property def storage_engine(self): try: - return self.server_status.get("storageEngine", {}).get("name") + return self.server_status.get("storageEngine", {}).get("name") # type: ignore[union-attr] except AttributeError: # Raised if self.server_status is None. return None @@ -496,6 +505,7 @@ def _check_user_provided(self): try: return db_user in _all_users(client.admin) except pymongo.errors.OperationFailure as e: + assert e.details is not None msg = e.details.get('errmsg', '') if e.code == 18 or 'auth fails' in msg: # Auth failed. @@ -505,6 +515,7 @@ def _check_user_provided(self): def _server_started_with_auth(self): # MongoDB >= 2.0 + assert self.cmd_line is not None if 'parsed' in self.cmd_line: parsed = self.cmd_line['parsed'] # MongoDB >= 2.6 @@ -525,6 +536,7 @@ def _server_started_with_ipv6(self): if not socket.has_ipv6: return False + assert self.cmd_line is not None if 'parsed' in self.cmd_line: if not self.cmd_line['parsed'].get('net', {}).get('ipv6'): return False @@ -642,7 +654,7 @@ def supports_secondary_read_pref(self): if self.has_secondaries: return True if self.is_mongos: - shard = self.client.config.shards.find_one()['host'] + shard = self.client.config.shards.find_one()['host'] # type: ignore[index] num_members = shard.count(',') + 1 return num_members > 1 return False @@ -932,6 +944,9 @@ def fail_point(self, command_args): class IntegrationTest(PyMongoTestCase): """Base class for TestCases that need a connection to MongoDB to pass.""" + client: MongoClient + db: Database + credentials: dict[str, str] @classmethod @client_context.require_connection @@ -1073,7 +1088,7 @@ def run(self, test): if HAVE_XML: - class PymongoXMLTestRunner(XMLTestRunner): + class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc] def run(self, test): setup() result = super(PymongoXMLTestRunner, self).run(test) diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index 0522201097..f096d0569a 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -26,6 +26,7 @@ class TestAuthAWS(unittest.TestCase): + uri: str @classmethod def setUpClass(cls): diff --git a/test/mockupdb/test_cursor_namespace.py b/test/mockupdb/test_cursor_namespace.py index 10605601cf..a52e2fb4e7 100644 --- a/test/mockupdb/test_cursor_namespace.py +++ b/test/mockupdb/test_cursor_namespace.py @@ -21,6 +21,9 @@ class TestCursorNamespace(unittest.TestCase): + server: MockupDB + client: MongoClient + @classmethod def setUpClass(cls): cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6}) @@ -69,6 +72,9 @@ def op(): class TestKillCursorsNamespace(unittest.TestCase): + server: MockupDB + client: MongoClient + @classmethod def setUpClass(cls): cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6}) diff --git a/test/mockupdb/test_getmore_sharded.py b/test/mockupdb/test_getmore_sharded.py index 5461a13e35..0d91583378 100644 --- a/test/mockupdb/test_getmore_sharded.py +++ b/test/mockupdb/test_getmore_sharded.py @@ -27,7 +27,7 @@ def test_getmore_sharded(self): servers = [MockupDB(), MockupDB()] # Collect queries to either server in one queue. - q = Queue() + q: Queue = Queue() for server in servers: server.subscribe(q.put) server.autoresponds('ismaster', ismaster=True, msg='isdbgrid', diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 34028a637f..621f01728f 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -12,58 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mockupdb import (MockupDB, OpReply, OpMsg, OpMsgReply, OpQuery, absent, - Command, go) +from mockupdb import MockupDB, OpReply, OpMsg, absent, Command, go from pymongo import MongoClient, version as pymongo_version from pymongo.errors import OperationFailure -from pymongo.server_api import ServerApi, ServerApiVersion -from bson.objectid import ObjectId import unittest -def test_hello_with_option(self, protocol, **kwargs): - hello = "ismaster" if isinstance(protocol(), OpQuery) else "hello" - # `db.command("hello"|"ismaster")` commands are the same for primaries and - # secondaries, so we only need one server. - primary = MockupDB() - # Set up a custom handler to save the first request from the driver. - self.handshake_req = None - def respond(r): - # Only save the very first request from the driver. - if self.handshake_req == None: - self.handshake_req = r - load_balanced_kwargs = {"serviceId": ObjectId()} if kwargs.get( - "loadBalanced") else {} - return r.reply(OpMsgReply(minWireVersion=0, maxWireVersion=13, - **kwargs, **load_balanced_kwargs)) - primary.autoresponds(respond) - primary.run() - self.addCleanup(primary.stop) - - # We need a special dict because MongoClient uses "server_api" and all - # of the commands use "apiVersion". - k_map = {("apiVersion", "1"):("server_api", ServerApi( - ServerApiVersion.V1))} - client = MongoClient("mongodb://"+primary.address_string, - appname='my app', # For _check_handshake_data() - **dict([k_map.get((k, v), (k, v)) for k, v - in kwargs.items()])) - - self.addCleanup(client.close) - - # We have an autoresponder luckily, so no need for `go()`. - assert client.db.command(hello) - - # We do this checking here rather than in the autoresponder `respond()` - # because it runs in another Python thread so there are some funky things - # with error handling within that thread, and we want to be able to use - # self.assertRaises(). - self.handshake_req.assert_matches(protocol(hello, **kwargs)) - _check_handshake_data(self.handshake_req) - - def _check_handshake_data(request): assert 'client' in request data = request['client'] @@ -200,22 +156,6 @@ def test_client_handshake_saslSupportedMechs(self): future() return - def test_handshake_load_balanced(self): - test_hello_with_option(self, OpMsg, loadBalanced=True) - with self.assertRaisesRegex(AssertionError, "does not match"): - test_hello_with_option(self, Command, loadBalanced=True) - - def test_handshake_versioned_api(self): - test_hello_with_option(self, OpMsg, apiVersion="1") - with self.assertRaisesRegex(AssertionError, "does not match"): - test_hello_with_option(self, Command, apiVersion="1") - - def test_handshake_not_either(self): - # If we don't specify either option then it should be using - # OP_QUERY for the initial step of the handshake. - test_hello_with_option(self, Command) - with self.assertRaisesRegex(AssertionError, "does not match"): - test_hello_with_option(self, OpMsg) if __name__ == '__main__': unittest.main() diff --git a/test/mockupdb/test_mixed_version_sharded.py b/test/mockupdb/test_mixed_version_sharded.py index 2b6ea6a513..c3af907404 100644 --- a/test/mockupdb/test_mixed_version_sharded.py +++ b/test/mockupdb/test_mixed_version_sharded.py @@ -30,7 +30,7 @@ def setup_server(self, upgrade): self.mongos_old, self.mongos_new = MockupDB(), MockupDB() # Collect queries to either server in one queue. - self.q = Queue() + self.q: Queue = Queue() for server in self.mongos_old, self.mongos_new: server.subscribe(self.q.put) server.autoresponds('getlasterror') @@ -59,7 +59,7 @@ def create_mixed_version_sharded_test(upgrade): def test(self): self.setup_server(upgrade) start = time.time() - servers_used = set() + servers_used: set = set() while len(servers_used) < 2: go(upgrade.function, self.client) request = self.q.get(timeout=1) diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py index 35e70cebfc..78397a3336 100755 --- a/test/mockupdb/test_op_msg.py +++ b/test/mockupdb/test_op_msg.py @@ -233,6 +233,8 @@ class TestOpMsg(unittest.TestCase): + server: MockupDB + client: MongoClient @classmethod def setUpClass(cls): diff --git a/test/mockupdb/test_op_msg_read_preference.py b/test/mockupdb/test_op_msg_read_preference.py index d9adfe17eb..eb3a14fa01 100644 --- a/test/mockupdb/test_op_msg_read_preference.py +++ b/test/mockupdb/test_op_msg_read_preference.py @@ -14,6 +14,7 @@ import copy import itertools +from typing import Any from mockupdb import MockupDB, going, CommandBase from pymongo import MongoClient, ReadPreference @@ -27,6 +28,8 @@ class OpMsgReadPrefBase(unittest.TestCase): single_mongod = False + primary: MockupDB + secondary: MockupDB @classmethod def setUpClass(cls): @@ -142,7 +145,7 @@ def test(self): tag_sets=None) client = self.setup_client(read_preference=pref) - + expected_pref: Any if operation.op_type == 'always-use-secondary': expected_server = self.secondary expected_pref = ReadPreference.SECONDARY diff --git a/test/mockupdb/test_query_read_pref_sharded.py b/test/mockupdb/test_query_read_pref_sharded.py index 88dcdd8351..20bc1e9155 100644 --- a/test/mockupdb/test_query_read_pref_sharded.py +++ b/test/mockupdb/test_query_read_pref_sharded.py @@ -50,7 +50,7 @@ def test_query_and_read_mode_sharded_op_msg(self): for pref in read_prefs: collection = client.db.get_collection('test', read_preference=pref) - cursor = collection.find(query.copy()) + cursor = collection.find(query.copy()) # type: ignore[attr-defined] with going(next, cursor): request = server.receives() # Command is not nested in $query. diff --git a/test/mockupdb/test_reset_and_request_check.py b/test/mockupdb/test_reset_and_request_check.py index 86c2085e39..48f9486544 100755 --- a/test/mockupdb/test_reset_and_request_check.py +++ b/test/mockupdb/test_reset_and_request_check.py @@ -27,7 +27,7 @@ class TestResetAndRequestCheck(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestResetAndRequestCheck, self).__init__(*args, **kwargs) - self.ismaster_time = 0 + self.ismaster_time = 0.0 self.client = None self.server = None @@ -45,7 +45,7 @@ def responder(request): kwargs = {'socketTimeoutMS': 100} # Disable retryable reads when pymongo supports it. kwargs['retryReads'] = False - self.client = MongoClient(self.server.uri, **kwargs) + self.client = MongoClient(self.server.uri, **kwargs) # type: ignore wait_until(lambda: self.client.nodes, 'connect to standalone') def tearDown(self): @@ -56,6 +56,8 @@ def _test_disconnect(self, operation): # Application operation fails. Test that client resets server # description and does *not* schedule immediate check. self.setup_server() + assert self.server is not None + assert self.client is not None # Network error on application operation. with self.assertRaises(ConnectionFailure): @@ -81,6 +83,8 @@ def _test_timeout(self, operation): # Application operation times out. Test that client does *not* reset # server description and does *not* schedule immediate check. self.setup_server() + assert self.server is not None + assert self.client is not None with self.assertRaises(ConnectionFailure): with going(operation.function, self.client): @@ -91,6 +95,7 @@ def _test_timeout(self, operation): # Server is *not* Unknown. topology = self.client._topology server = topology.select_server_by_address(self.server.address, 0) + assert server is not None self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type) after = self.ismaster_time @@ -99,6 +104,8 @@ def _test_timeout(self, operation): def _test_not_master(self, operation): # Application operation gets a "not master" error. self.setup_server() + assert self.server is not None + assert self.client is not None with self.assertRaises(ConnectionFailure): with going(operation.function, self.client): @@ -110,6 +117,7 @@ def _test_not_master(self, operation): # Server is rediscovered. topology = self.client._topology server = topology.select_server_by_address(self.server.address, 0) + assert server is not None self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type) after = self.ismaster_time diff --git a/test/mockupdb/test_slave_okay_sharded.py b/test/mockupdb/test_slave_okay_sharded.py index 63bb0fe303..07e05bfece 100644 --- a/test/mockupdb/test_slave_okay_sharded.py +++ b/test/mockupdb/test_slave_okay_sharded.py @@ -37,7 +37,7 @@ def setup_server(self): self.mongos1, self.mongos2 = MockupDB(), MockupDB() # Collect queries to either server in one queue. - self.q = Queue() + self.q: Queue = Queue() for server in self.mongos1, self.mongos2: server.subscribe(self.q.put) server.run() diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index 84c6baf60d..f25ff675e8 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -25,7 +25,7 @@ try: import simplejson as json except ImportError: - import json # type: ignore[no-redef] + import json # type: ignore sys.path[0:0] = [""] @@ -67,6 +67,10 @@ def __exit__(self, *args): class PerformanceTest(object): + dataset: Any + data_size: Any + do_task: Any + fail: Any @classmethod def setUpClass(cls): @@ -386,6 +390,7 @@ def mp_map(map_func, files): def insert_json_file(filename): + assert proc_client is not None with open(filename, 'r') as data: coll = proc_client.perftest.corpus coll.insert_many([json.loads(line) for line in data]) @@ -398,11 +403,13 @@ def insert_json_file_with_file_id(filename): doc = json.loads(line) doc['file'] = filename documents.append(doc) + assert proc_client is not None coll = proc_client.perftest.corpus coll.insert_many(documents) def read_json_file(filename): + assert proc_client is not None coll = proc_client.perftest.corpus temp = tempfile.TemporaryFile(mode='w') try: @@ -414,6 +421,7 @@ def read_json_file(filename): def insert_gridfs_file(filename): + assert proc_client is not None bucket = GridFSBucket(proc_client.perftest) with open(filename, 'rb') as gfile: @@ -421,6 +429,7 @@ def insert_gridfs_file(filename): def read_gridfs_file(filename): + assert proc_client is not None bucket = GridFSBucket(proc_client.perftest) temp = tempfile.TemporaryFile() diff --git a/test/qcheck.py b/test/qcheck.py index 57e0940b72..0066ac6e57 100644 --- a/test/qcheck.py +++ b/test/qcheck.py @@ -189,7 +189,7 @@ def simplify(case): # TODO this is a hack simplified[key] = value return (success, success and simplified or case) if isinstance(case, list): - simplified = list(case) + simplified = list(case) # type: ignore[assignment] if random.choice([True, False]): # delete if not len(simplified): diff --git a/test/test_auth.py b/test/test_auth.py index 35f198574b..5b4ef0c51f 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -76,6 +76,8 @@ def run(self): class TestGSSAPI(unittest.TestCase): + mech_properties: str + service_realm_required: bool @classmethod def setUpClass(cls): @@ -116,6 +118,7 @@ def test_credentials_hashing(self): @ignore_deprecations def test_gssapi_simple(self): + assert GSSAPI_PRINCIPAL is not None if GSSAPI_PASS is not None: uri = ('mongodb://%s:%s@%s:%d/?authMechanism=' 'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL), @@ -264,6 +267,8 @@ def test_sasl_plain(self): authMechanism='PLAIN') client.ldap.test.find_one() + assert SASL_USER is not None + assert SASL_PASS is not None uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;' 'authSource=%s' % (quote_plus(SASL_USER), quote_plus(SASL_PASS), @@ -540,7 +545,6 @@ def test_cache(self): self.assertIsInstance(iterations, int) def test_scram_threaded(self): - coll = client_context.client.db.test coll.drop() coll.insert_one({'_id': 1}) diff --git a/test/test_binary.py b/test/test_binary.py index e6b681fc51..4bbda0c9d4 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -41,6 +41,8 @@ class TestBinary(unittest.TestCase): + csharp_data: bytes + java_data: bytes @classmethod def setUpClass(cls): @@ -354,6 +356,8 @@ def test_buffer_protocol(self): class TestUuidSpecExplicitCoding(unittest.TestCase): + uuid: uuid.UUID + @classmethod def setUpClass(cls): super(TestUuidSpecExplicitCoding, cls).setUpClass() @@ -457,6 +461,8 @@ def test_decoding_4(self): class TestUuidSpecImplicitCoding(IntegrationTest): + uuid: uuid.UUID + @classmethod def setUpClass(cls): super(TestUuidSpecImplicitCoding, cls).setUpClass() diff --git a/test/test_bson.py b/test/test_bson.py index eb4f4e47c2..6e619f27d5 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -186,7 +186,7 @@ def test_encode_then_decode_any_mapping_legacy(self): decoder=lambda *args: BSON(args[0]).decode(*args[1:])) def test_encoding_defaultdict(self): - dct = collections.defaultdict(dict, [('foo', 'bar')]) + dct = collections.defaultdict(dict, [('foo', 'bar')]) # type: ignore[arg-type] encode(dct) self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')])) @@ -302,7 +302,7 @@ def test_basic_decode(self): def test_decode_all_buffer_protocol(self): docs = [{'foo': 'bar'}, {}] - bs = b"".join(map(encode, docs)) + bs = b"".join(map(encode, docs)) # type: ignore[arg-type] self.assertEqual(docs, decode_all(bytearray(bs))) self.assertEqual(docs, decode_all(memoryview(bs))) self.assertEqual(docs, decode_all(memoryview(b'1' + bs + b'1')[1:-1])) @@ -530,7 +530,7 @@ def test_large_datetime_truncation(self): def test_aware_datetime(self): aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) - as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) + as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) # type: ignore self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc), as_utc) after = decode(encode({"date": aware}), CodecOptions(tz_aware=True))[ @@ -591,7 +591,7 @@ def test_local_datetime(self): def test_naive_decode(self): aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) - naive_utc = (aware - aware.utcoffset()).replace(tzinfo=None) + naive_utc = (aware - aware.utcoffset()).replace(tzinfo=None) # type: ignore self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45), naive_utc) after = decode(encode({"date": aware}))["date"] self.assertEqual(None, after.tzinfo) @@ -603,9 +603,9 @@ def test_dst(self): @unittest.skip('Disabled due to http://bugs.python.org/issue25222') def test_bad_encode(self): - evil_list = {'a': []} + evil_list: dict = {'a': []} evil_list['a'].append(evil_list) - evil_dict = {} + evil_dict: dict = {} evil_dict['a'] = evil_dict for evil_data in [evil_dict, evil_list]: self.assertRaises(Exception, encode, evil_data) @@ -994,32 +994,57 @@ def test_decode_all_defaults(self): def test_unicode_decode_error_handler(self): enc = encode({"keystr": "foobar"}) - # Test handling of bad key value, bad string value, and both. + # Test handling of bad key value. invalid_key = enc[:7] + b'\xe9' + enc[8:] - invalid_val = enc[:18] + b'\xe9' + enc[19:] + replaced_key = b'ke\xe9str'.decode('utf-8', 'replace') + ignored_key = b'ke\xe9str'.decode('utf-8', 'ignore') + + dec = decode(invalid_key, + CodecOptions(unicode_decode_error_handler="replace")) + self.assertEqual(dec, {replaced_key: "foobar"}) + + dec = decode(invalid_key, + CodecOptions(unicode_decode_error_handler="ignore")) + self.assertEqual(dec, {ignored_key: "foobar"}) + + self.assertRaises(InvalidBSON, decode, invalid_key, CodecOptions( + unicode_decode_error_handler="strict")) + self.assertRaises(InvalidBSON, decode, invalid_key, CodecOptions()) + self.assertRaises(InvalidBSON, decode, invalid_key) + + # Test handing of bad string value. + invalid_val = BSON(enc[:18] + b'\xe9' + enc[19:]) + replaced_val = b'fo\xe9bar'.decode('utf-8', 'replace') + ignored_val = b'fo\xe9bar'.decode('utf-8', 'ignore') + + dec = decode(invalid_val, + CodecOptions(unicode_decode_error_handler="replace")) + self.assertEqual(dec, {"keystr": replaced_val}) + + dec = decode(invalid_val, + CodecOptions(unicode_decode_error_handler="ignore")) + self.assertEqual(dec, {"keystr": ignored_val}) + + self.assertRaises(InvalidBSON, decode, invalid_val, CodecOptions( + unicode_decode_error_handler="strict")) + self.assertRaises(InvalidBSON, decode, invalid_val, CodecOptions()) + self.assertRaises(InvalidBSON, decode, invalid_val) + + # Test handing bad key + bad value. invalid_both = enc[:7] + b'\xe9' + enc[8:18] + b'\xe9' + enc[19:] - # Ensure that strict mode raises an error. - for invalid in [invalid_key, invalid_val, invalid_both]: - self.assertRaises(InvalidBSON, decode, invalid, CodecOptions( - unicode_decode_error_handler="strict")) - self.assertRaises(InvalidBSON, decode, invalid, CodecOptions()) - self.assertRaises(InvalidBSON, decode, invalid) - - # Test all other error handlers. - for handler in ['replace', 'backslashreplace', 'surrogateescape', - 'ignore']: - expected_key = b'ke\xe9str'.decode('utf-8', handler) - expected_val = b'fo\xe9bar'.decode('utf-8', handler) - doc = decode(invalid_key, - CodecOptions(unicode_decode_error_handler=handler)) - self.assertEqual(doc, {expected_key: "foobar"}) - doc = decode(invalid_val, - CodecOptions(unicode_decode_error_handler=handler)) - self.assertEqual(doc, {"keystr": expected_val}) - doc = decode(invalid_both, - CodecOptions(unicode_decode_error_handler=handler)) - self.assertEqual(doc, {expected_key: expected_val}) + dec = decode(invalid_both, + CodecOptions(unicode_decode_error_handler="replace")) + self.assertEqual(dec, {replaced_key: replaced_val}) + + dec = decode(invalid_both, + CodecOptions(unicode_decode_error_handler="ignore")) + self.assertEqual(dec, {ignored_key: ignored_val}) + + self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions( + unicode_decode_error_handler="strict")) + self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions()) + self.assertRaises(InvalidBSON, decode, invalid_both) # Test handling bad error mode. dec = decode(enc, diff --git a/test/test_bulk.py b/test/test_bulk.py index 08740a437e..dcbc6d1efa 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -15,14 +15,13 @@ """Test the bulk API.""" import sys -import uuid -from bson.binary import UuidRepresentation -from bson.codec_options import CodecOptions + +from pymongo.mongo_client import MongoClient sys.path[0:0] = [""] -from bson import Binary from bson.objectid import ObjectId +from pymongo.collection import Collection from pymongo.common import partition_node from pymongo.errors import (BulkWriteError, ConfigurationError, @@ -40,6 +39,8 @@ class BulkTestBase(IntegrationTest): + coll: Collection + coll_w0: Collection @classmethod def setUpClass(cls): @@ -280,6 +281,7 @@ def test_upsert(self): upsert=True)]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.upserted_count) + assert result.upserted_ids is not None self.assertEqual(1, len(result.upserted_ids)) self.assertTrue(isinstance(result.upserted_ids.get(0), ObjectId)) @@ -341,11 +343,11 @@ def test_bulk_write_invalid_arguments(self): # The requests argument must be a list. generator = (InsertOne({}) for _ in range(10)) with self.assertRaises(TypeError): - self.coll.bulk_write(generator) + self.coll.bulk_write(generator) # type: ignore[arg-type] # Document is not wrapped in a bulk write operation. with self.assertRaises(TypeError): - self.coll.bulk_write([{}]) + self.coll.bulk_write([{}]) # type: ignore[list-item] def test_upsert_large(self): big = 'a' * (client_context.max_bson_size - 37) @@ -380,78 +382,6 @@ def test_client_generated_upsert_id(self): {'index': 2, '_id': 2}]}, result.bulk_api_result) - def test_upsert_uuid_standard(self): - options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - coll = self.coll.with_options(codec_options=options) - uuids = [uuid.uuid4() for _ in range(3)] - result = coll.bulk_write([ - UpdateOne({'_id': uuids[0]}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': uuids[1]}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': uuids[2]}, {'_id': uuids[2]}, upsert=True), - ]) - self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': uuids[0]}, - {'index': 1, '_id': uuids[1]}, - {'index': 2, '_id': uuids[2]}]}, - result.bulk_api_result) - - def test_upsert_uuid_unspecified(self): - options = CodecOptions(uuid_representation=UuidRepresentation.UNSPECIFIED) - coll = self.coll.with_options(codec_options=options) - uuids = [Binary.from_uuid(uuid.uuid4()) for _ in range(3)] - result = coll.bulk_write([ - UpdateOne({'_id': uuids[0]}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': uuids[1]}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': uuids[2]}, {'_id': uuids[2]}, upsert=True), - ]) - self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': uuids[0]}, - {'index': 1, '_id': uuids[1]}, - {'index': 2, '_id': uuids[2]}]}, - result.bulk_api_result) - - def test_upsert_uuid_standard_subdocuments(self): - options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - coll = self.coll.with_options(codec_options=options) - ids = [ - {'f': Binary(bytes(i)), 'f2': uuid.uuid4()} - for i in range(3) - ] - - result = coll.bulk_write([ - UpdateOne({'_id': ids[0]}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': ids[1]}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': ids[2]}, {'_id': ids[2]}, upsert=True), - ]) - - # The `Binary` values are returned as `bytes` objects. - for _id in ids: - _id['f'] = bytes(_id['f']) - - self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': ids[0]}, - {'index': 1, '_id': ids[1]}, - {'index': 2, '_id': ids[2]}]}, - result.bulk_api_result) - def test_single_ordered_batch(self): result = self.coll.bulk_write([ InsertOne({'a': 1}), @@ -472,7 +402,7 @@ def test_single_ordered_batch(self): def test_single_error_ordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) - requests = [ + requests: list = [ InsertOne({'b': 1, 'a': 1}), UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True), InsertOne({'b': 3, 'a': 2}), @@ -506,7 +436,7 @@ def test_single_error_ordered_batch(self): def test_multiple_error_ordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) - requests = [ + requests: list = [ InsertOne({'b': 1, 'a': 1}), UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True), UpdateOne({'b': 3}, {'$set': {'a': 2}}, upsert=True), @@ -542,7 +472,7 @@ def test_multiple_error_ordered_batch(self): result) def test_single_unordered_batch(self): - requests = [ + requests: list = [ InsertOne({'a': 1}), UpdateOne({'a': 1}, {'$set': {'b': 1}}), UpdateOne({'a': 2}, {'$set': {'b': 2}}, upsert=True), @@ -564,7 +494,7 @@ def test_single_unordered_batch(self): def test_single_error_unordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) - requests = [ + requests: list = [ InsertOne({'b': 1, 'a': 1}), UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True), InsertOne({'b': 3, 'a': 2}), @@ -599,7 +529,7 @@ def test_single_error_unordered_batch(self): def test_multiple_error_unordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) - requests = [ + requests: list = [ InsertOne({'b': 1, 'a': 1}), UpdateOne({'b': 2}, {'$set': {'a': 3}}, upsert=True), UpdateOne({'b': 3}, {'$set': {'a': 4}}, upsert=True), @@ -662,7 +592,7 @@ def test_large_inserts_ordered(self): self.coll.delete_many({}) big = 'x' * (1024 * 1024 * 4) - result = self.coll.bulk_write([ + write_result = self.coll.bulk_write([ InsertOne({'a': 1, 'big': big}), InsertOne({'a': 2, 'big': big}), InsertOne({'a': 3, 'big': big}), @@ -671,7 +601,7 @@ def test_large_inserts_ordered(self): InsertOne({'a': 6, 'big': big}), ]) - self.assertEqual(6, result.inserted_count) + self.assertEqual(6, write_result.inserted_count) self.assertEqual(6, self.coll.count_documents({})) def test_large_inserts_unordered(self): @@ -685,12 +615,12 @@ def test_large_inserts_unordered(self): try: self.coll.bulk_write(requests, ordered=False) except BulkWriteError as exc: - result = exc.details + details = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") - self.assertEqual(2, result['nInserted']) + self.assertEqual(2, details['nInserted']) self.coll.delete_many({}) @@ -741,7 +671,7 @@ def tearDown(self): self.coll.delete_many({}) def test_no_results_ordered_success(self): - requests = [ + requests: list = [ InsertOne({'a': 1}), UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True), InsertOne({'a': 2}), @@ -755,7 +685,7 @@ def test_no_results_ordered_success(self): 'removed {"_id": 1}') def test_no_results_ordered_failure(self): - requests = [ + requests: list = [ InsertOne({'_id': 1}), UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True), InsertOne({'_id': 2}), @@ -771,7 +701,7 @@ def test_no_results_ordered_failure(self): self.assertEqual({'_id': 1}, self.coll.find_one({'_id': 1})) def test_no_results_unordered_success(self): - requests = [ + requests: list = [ InsertOne({'a': 1}), UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True), InsertOne({'a': 2}), @@ -785,7 +715,7 @@ def test_no_results_unordered_success(self): 'removed {"_id": 1}') def test_no_results_unordered_failure(self): - requests = [ + requests: list = [ InsertOne({'_id': 1}), UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True), InsertOne({'_id': 2}), @@ -832,13 +762,15 @@ def test_no_remove(self): class TestBulkWriteConcern(BulkTestBase): + w: Optional[int] + secondary: MongoClient @classmethod def setUpClass(cls): super(TestBulkWriteConcern, cls).setUpClass() cls.w = client_context.w - cls.secondary = None - if cls.w > 1: + cls.secondary = None # type: ignore[assignment] + if cls.w is not None and cls.w > 1: for member in client_context.hello['hosts']: if member != client_context.hello['primary']: cls.secondary = single_client(*partition_node(member)) @@ -886,7 +818,7 @@ def test_write_concern_failure_ordered(self): try: self.cause_wtimeout(requests, ordered=True) except BulkWriteError as exc: - result = exc.details + details = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") @@ -899,13 +831,13 @@ def test_write_concern_failure_ordered(self): 'nRemoved': 0, 'upserted': [], 'writeErrors': []}, - result) + details) # When talking to legacy servers there will be a # write concern error for each operation. - self.assertTrue(len(result['writeConcernErrors']) > 0) + self.assertTrue(len(details['writeConcernErrors']) > 0) - failed = result['writeConcernErrors'][0] + failed = details['writeConcernErrors'][0] self.assertEqual(64, failed['code']) self.assertTrue(isinstance(failed['errmsg'], str)) @@ -924,7 +856,7 @@ def test_write_concern_failure_ordered(self): try: self.cause_wtimeout(requests, ordered=True) except BulkWriteError as exc: - result = exc.details + details = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") @@ -941,10 +873,10 @@ def test_write_concern_failure_ordered(self): 'code': 11000, 'errmsg': '...', 'op': {'_id': '...', 'a': 1}}]}, - result) + details) - self.assertTrue(len(result['writeConcernErrors']) > 1) - failed = result['writeErrors'][0] + self.assertTrue(len(details['writeConcernErrors']) > 1) + failed = details['writeErrors'][0] self.assertTrue("duplicate" in failed['errmsg']) @client_context.require_replica_set @@ -966,17 +898,17 @@ def test_write_concern_failure_unordered(self): try: self.cause_wtimeout(requests, ordered=False) except BulkWriteError as exc: - result = exc.details + details = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") - self.assertEqual(2, result['nInserted']) - self.assertEqual(1, result['nUpserted']) - self.assertEqual(0, len(result['writeErrors'])) + self.assertEqual(2, details['nInserted']) + self.assertEqual(1, details['nUpserted']) + self.assertEqual(0, len(details['writeErrors'])) # When talking to legacy servers there will be a # write concern error for each operation. - self.assertTrue(len(result['writeConcernErrors']) > 1) + self.assertTrue(len(details['writeConcernErrors']) > 1) self.coll.delete_many({}) self.coll.create_index('a', unique=True) @@ -984,7 +916,7 @@ def test_write_concern_failure_unordered(self): # Fail due to write concern support as well # as duplicate key error on unordered batch. - requests = [ + requests: list = [ InsertOne({'a': 1}), UpdateOne({'a': 3}, {'$set': {'a': 3, 'b': 1}}, upsert=True), InsertOne({'a': 1}), @@ -993,29 +925,29 @@ def test_write_concern_failure_unordered(self): try: self.cause_wtimeout(requests, ordered=False) except BulkWriteError as exc: - result = exc.details + details = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") - self.assertEqual(2, result['nInserted']) - self.assertEqual(1, result['nUpserted']) - self.assertEqual(1, len(result['writeErrors'])) + self.assertEqual(2, details['nInserted']) + self.assertEqual(1, details['nUpserted']) + self.assertEqual(1, len(details['writeErrors'])) # When talking to legacy servers there will be a # write concern error for each operation. - self.assertTrue(len(result['writeConcernErrors']) > 1) + self.assertTrue(len(details['writeConcernErrors']) > 1) - failed = result['writeErrors'][0] + failed = details['writeErrors'][0] self.assertEqual(2, failed['index']) self.assertEqual(11000, failed['code']) self.assertTrue(isinstance(failed['errmsg'], str)) self.assertEqual(1, failed['op']['a']) - failed = result['writeConcernErrors'][0] + failed = details['writeConcernErrors'][0] self.assertEqual(64, failed['code']) self.assertTrue(isinstance(failed['errmsg'], str)) - upserts = result['upserted'] + upserts = details['upserted'] self.assertEqual(1, len(upserts)) self.assertEqual(1, upserts[0]['index']) self.assertTrue(upserts[0].get('_id')) diff --git a/test/test_change_stream.py b/test/test_change_stream.py index a49f6972b2..655b99e801 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -24,6 +24,7 @@ import uuid from itertools import product +from typing import no_type_check sys.path[0:0] = [''] @@ -121,6 +122,7 @@ def kill_change_stream_cursor(self, change_stream): class APITestsMixin(object): + @no_type_check def test_watch(self): with self.change_stream( [{'$project': {'foo': 0}}], full_document='updateLookup', @@ -145,6 +147,7 @@ def test_watch(self): with self.change_stream(resume_after=resume_token): pass + @no_type_check def test_try_next(self): # ChangeStreams only read majority committed data so use w:majority. coll = self.watched_collection().with_options( @@ -161,6 +164,7 @@ def test_try_next(self): wait_until(lambda: stream.try_next() is not None, "get change from try_next") + @no_type_check def test_try_next_runs_one_getmore(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) @@ -216,6 +220,7 @@ def test_try_next_runs_one_getmore(self): set(["getMore"])) self.assertIsNone(stream.try_next()) + @no_type_check def test_batch_size_is_honored(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) @@ -245,6 +250,7 @@ def test_batch_size_is_honored(self): self.assertEqual(expected[key], cmd[key]) # $changeStream.startAtOperationTime was added in 4.0.0. + @no_type_check @client_context.require_version_min(4, 0, 0) def test_start_at_operation_time(self): optime = self.get_start_at_operation_time() @@ -258,6 +264,7 @@ def test_start_at_operation_time(self): for i in range(ndocs): cs.next() + @no_type_check def _test_full_pipeline(self, expected_cs_stage): client, listener = self.client_with_listener("aggregate") results = listener.results @@ -273,12 +280,14 @@ def _test_full_pipeline(self, expected_cs_stage): {'$project': {'foo': 0}}], command.command['pipeline']) + @no_type_check def test_full_pipeline(self): """$changeStream must be the first stage in a change stream pipeline sent to the server. """ self._test_full_pipeline({}) + @no_type_check def test_iteration(self): with self.change_stream(batch_size=2) as change_stream: num_inserted = 10 @@ -292,6 +301,7 @@ def test_iteration(self): break self._test_invalidate_stops_iteration(change_stream) + @no_type_check def _test_next_blocks(self, change_stream): inserted_doc = {'_id': ObjectId()} changes = [] @@ -311,18 +321,21 @@ def _test_next_blocks(self, change_stream): self.assertEqual(changes[0]['operationType'], 'insert') self.assertEqual(changes[0]['fullDocument'], inserted_doc) + @no_type_check def test_next_blocks(self): """Test that next blocks until a change is readable""" # Use a short await time to speed up the test. with self.change_stream(max_await_time_ms=250) as change_stream: self._test_next_blocks(change_stream) + @no_type_check def test_aggregate_cursor_blocks(self): """Test that an aggregate cursor blocks until a change is readable.""" with self.watched_collection().aggregate( [{'$changeStream': {}}], maxAwaitTimeMS=250) as change_stream: self._test_next_blocks(change_stream) + @no_type_check def test_concurrent_close(self): """Ensure a ChangeStream can be closed from another thread.""" # Use a short await time to speed up the test. @@ -338,6 +351,7 @@ def iterate_cursor(): t.join(3) self.assertFalse(t.is_alive()) + @no_type_check def test_unknown_full_document(self): """Must rely on the server to raise an error on unknown fullDocument. """ @@ -347,6 +361,7 @@ def test_unknown_full_document(self): except OperationFailure: pass + @no_type_check def test_change_operations(self): """Test each operation type.""" expected_ns = {'db': self.watched_collection().database.name, @@ -393,6 +408,7 @@ def test_change_operations(self): # Invalidate. self._test_get_invalidate_event(change_stream) + @no_type_check @client_context.require_version_min(4, 1, 1) def test_start_after(self): resume_token = self.get_resume_token(invalidate=True) @@ -408,6 +424,7 @@ def test_start_after(self): self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['fullDocument'], {'_id': 2}) + @no_type_check @client_context.require_version_min(4, 1, 1) def test_start_after_resume_process_with_changes(self): resume_token = self.get_resume_token(invalidate=True) @@ -427,6 +444,7 @@ def test_start_after_resume_process_with_changes(self): self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['fullDocument'], {'_id': 3}) + @no_type_check @client_context.require_no_mongos # Remove after SERVER-41196 @client_context.require_version_min(4, 1, 1) def test_start_after_resume_process_without_changes(self): @@ -444,12 +462,14 @@ def test_start_after_resume_process_without_changes(self): class ProseSpecTestsMixin(object): + @no_type_check def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener + @no_type_check def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3): self.watched_collection().insert_many( [{"data": k} for k in range(batch_size)]) @@ -485,6 +505,7 @@ def _get_expected_resume_token(self, stream, listener, response = listener.results['succeeded'][-1].reply return response['cursor']['postBatchResumeToken'] + @no_type_check def _test_raises_error_on_missing_id(self, expected_exception): """ChangeStream will raise an exception if the server response is missing the resume token. @@ -497,6 +518,7 @@ def _test_raises_error_on_missing_id(self, expected_exception): with self.assertRaises(StopIteration): next(change_stream) + @no_type_check def _test_update_resume_token(self, expected_rt_getter): """ChangeStream must continuously track the last seen resumeToken.""" client, listener = self._client_with_listener("aggregate", "getMore") @@ -536,6 +558,7 @@ def test_raises_error_on_missing_id_418minus(self): self._test_raises_error_on_missing_id(InvalidOperation) # Prose test no. 3 + @no_type_check def test_resume_on_error(self): with self.change_stream() as change_stream: self.insert_one_and_check(change_stream, {'_id': 1}) @@ -544,6 +567,7 @@ def test_resume_on_error(self): self.insert_one_and_check(change_stream, {'_id': 2}) # Prose test no. 4 + @no_type_check @client_context.require_failCommand_fail_point def test_no_resume_attempt_if_aggregate_command_fails(self): # Set non-retryable error on aggregate command. @@ -568,6 +592,7 @@ def test_no_resume_attempt_if_aggregate_command_fails(self): # each operation which ensure compliance with this prose test. # Prose test no. 7 + @no_type_check def test_initial_empty_batch(self): with self.change_stream() as change_stream: # The first batch should be empty. @@ -579,6 +604,7 @@ def test_initial_empty_batch(self): self.assertEqual(cursor_id, change_stream._cursor.cursor_id) # Prose test no. 8 + @no_type_check def test_kill_cursors(self): def raise_error(): raise ServerSelectionTimeoutError('mock error') @@ -591,6 +617,7 @@ def raise_error(): self.insert_one_and_check(change_stream, {'_id': 2}) # Prose test no. 9 + @no_type_check @client_context.require_version_min(4, 0, 0) @client_context.require_version_max(4, 0, 7) def test_start_at_operation_time_caching(self): @@ -619,6 +646,7 @@ def test_start_at_operation_time_caching(self): # This test is identical to prose test no. 3. # Prose test no. 11 + @no_type_check @client_context.require_version_min(4, 0, 7) def test_resumetoken_empty_batch(self): client, listener = self._client_with_listener("getMore") @@ -631,6 +659,7 @@ def test_resumetoken_empty_batch(self): response["cursor"]["postBatchResumeToken"]) # Prose test no. 11 + @no_type_check @client_context.require_version_min(4, 0, 7) def test_resumetoken_exhausted_batch(self): client, listener = self._client_with_listener("getMore") @@ -643,6 +672,7 @@ def test_resumetoken_exhausted_batch(self): response["cursor"]["postBatchResumeToken"]) # Prose test no. 12 + @no_type_check @client_context.require_version_max(4, 0, 7) def test_resumetoken_empty_batch_legacy(self): resume_point = self.get_resume_token() @@ -659,6 +689,7 @@ def test_resumetoken_empty_batch_legacy(self): self.assertEqual(resume_token, resume_point) # Prose test no. 12 + @no_type_check @client_context.require_version_max(4, 0, 7) def test_resumetoken_exhausted_batch_legacy(self): # Resume token is _id of last change. @@ -673,6 +704,7 @@ def test_resumetoken_exhausted_batch_legacy(self): self.assertEqual(change_stream.resume_token, change["_id"]) # Prose test no. 13 + @no_type_check def test_resumetoken_partially_iterated_batch(self): # When batch has been iterated up to but not including the last element. # Resume token should be _id of previous change document. @@ -686,6 +718,7 @@ def test_resumetoken_partially_iterated_batch(self): self.assertEqual(resume_token, change["_id"]) + @no_type_check def _test_resumetoken_uniterated_nonempty_batch(self, resume_option): # When the batch is not empty and hasn't been iterated at all. # Resume token should be same as the resume option used. @@ -704,17 +737,20 @@ def _test_resumetoken_uniterated_nonempty_batch(self, resume_option): self.assertEqual(resume_token, resume_point) # Prose test no. 14 + @no_type_check @client_context.require_no_mongos def test_resumetoken_uniterated_nonempty_batch_resumeafter(self): self._test_resumetoken_uniterated_nonempty_batch("resume_after") # Prose test no. 14 + @no_type_check @client_context.require_no_mongos @client_context.require_version_min(4, 1, 1) def test_resumetoken_uniterated_nonempty_batch_startafter(self): self._test_resumetoken_uniterated_nonempty_batch("start_after") # Prose test no. 17 + @no_type_check @client_context.require_version_min(4, 1, 1) def test_startafter_resume_uses_startafter_after_empty_getMore(self): # Resume should use startAfter after no changes have been returned. @@ -735,6 +771,7 @@ def test_startafter_resume_uses_startafter_after_empty_getMore(self): response.command["pipeline"][0]["$changeStream"].get("startAfter")) # Prose test no. 18 + @no_type_check @client_context.require_version_min(4, 1, 1) def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self): # Resume should use resumeAfter after some changes have been returned. @@ -757,6 +794,8 @@ def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self): class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): + dbs: list + @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_no_mmap @@ -1045,6 +1084,7 @@ def test_read_concern(self): class TestAllLegacyScenarios(IntegrationTest): RUN_ON_LOAD_BALANCER = True + listener: AllowListEventListener @classmethod @client_context.require_connection diff --git a/test/test_client.py b/test/test_client.py index 8db1cb5621..1c89631f5e 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -28,6 +28,8 @@ import threading import warnings +from typing import no_type_check + sys.path[0:0] = [""] from bson import encode @@ -99,6 +101,7 @@ class ClientUnitTest(unittest.TestCase): """MongoClient tests that don't require a server.""" + client: MongoClient @classmethod @client_context.require_connection @@ -614,7 +617,7 @@ def test_constants(self): port are not overloaded. """ host, port = client_context.host, client_context.port - kwargs = client_context.default_client_options.copy() + kwargs: dict = client_context.default_client_options.copy() if client_context.auth_enabled: kwargs['username'] = db_user kwargs['password'] = db_pwd @@ -1111,6 +1114,7 @@ def test_socketKeepAlive(self): socket.SO_KEEPALIVE) self.assertTrue(keepalive) + @no_type_check def test_tz_aware(self): self.assertRaises(ValueError, MongoClient, tz_aware='foo') @@ -1140,7 +1144,7 @@ def test_ipv6(self): uri = "mongodb://%s[::1]:%d" % (auth_str, client_context.port) if client_context.is_rs: - uri += '/?replicaSet=' + client_context.replica_set_name + uri += '/?replicaSet=' + (client_context.replica_set_name or "") client = rs_or_single_client_noauth(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) @@ -1379,7 +1383,7 @@ def init(self, *args): heartbeat_times.append(time.time()) try: - ServerHeartbeatStartedEvent.__init__ = init + ServerHeartbeatStartedEvent.__init__ = init # type: ignore listener = HeartbeatStartedListener() uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % ( client_context.host, client_context.port) @@ -1394,7 +1398,7 @@ def init(self, *args): client.close() finally: - ServerHeartbeatStartedEvent.__init__ = old_init + ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore def test_small_heartbeat_frequency_ms(self): uri = "mongodb://example/?heartbeatFrequencyMS=499" @@ -1847,7 +1851,7 @@ def test(collection): lazy_client_trial(reset, delete_one, test, self._get_client) def test_find_one(self): - results = [] + results: list = [] def reset(collection): collection.drop() diff --git a/test/test_cmap.py b/test/test_cmap.py index 20ed7f31ec..bfc600f19f 100644 --- a/test/test_cmap.py +++ b/test/test_cmap.py @@ -213,11 +213,11 @@ def set_fail_point(self, command_args): def run_scenario(self, scenario_def, test): """Run a CMAP spec test.""" - self.logs = [] + self.logs: list = [] self.assertEqual(scenario_def['version'], 1) self.assertIn(scenario_def['style'], ['unit', 'integration']) self.listener = CMAPListener() - self._ops = [] + self._ops: list = [] # Configure the fail point before creating the client. if 'failPoint' in test: @@ -259,9 +259,9 @@ def run_scenario(self, scenario_def, test): self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. - self.targets = dict() + self.targets: dict = dict() # Map of label names to Connection objects - self.labels = dict() + self.labels: dict = dict() def cleanup(): for t in self.targets.values(): diff --git a/test/test_code.py b/test/test_code.py index c5e190f363..f392865ab5 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -17,6 +17,7 @@ """Tests for the Code wrapper.""" import sys +from typing import Any, cast sys.path[0:0] = [""] from bson.code import Code @@ -35,7 +36,7 @@ def test_read_only(self): c = Code("blah") def set_c(): - c.scope = 5 + cast(Any, c).scope = 5 self.assertRaises(AttributeError, set_c) def test_code(self): @@ -59,7 +60,7 @@ def test_code(self): def test_repr(self): c = Code("hello world", {}) self.assertEqual(repr(c), "Code('hello world', {})") - c.scope["foo"] = "bar" + cast(Any, c).scope["foo"] = "bar" self.assertEqual(repr(c), "Code('hello world', {'foo': 'bar'})") c = Code("hello world", {"blah": 3}) self.assertEqual(repr(c), "Code('hello world', {'blah': 3})") diff --git a/test/test_collation.py b/test/test_collation.py index f0139b4a22..9c4f4f6576 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -17,6 +17,8 @@ import functools import warnings +from typing import Any + from pymongo.collation import ( Collation, CollationCaseFirst, CollationStrength, CollationAlternate, @@ -78,6 +80,10 @@ def test_constructor(self): class TestCollation(IntegrationTest): + listener: EventListener + warn_context: Any + collation: Collation + @classmethod @client_context.require_connection def setUpClass(cls): diff --git a/test/test_collection.py b/test/test_collection.py index 4a167bacb3..98e1ca6bb3 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -20,8 +20,11 @@ import re import sys -from codecs import utf_8_decode +from codecs import utf_8_decode # type: ignore from collections import defaultdict +from typing import no_type_check + +from pymongo.database import Database sys.path[0:0] = [""] @@ -66,6 +69,7 @@ class TestCollectionNoConnect(unittest.TestCase): """Test Collection features on a client that does not connect. """ + db: Database @classmethod def setUpClass(cls): @@ -116,11 +120,12 @@ def test_iteration(self): class TestCollection(IntegrationTest): + w: int @classmethod def setUpClass(cls): super(TestCollection, cls).setUpClass() - cls.w = client_context.w + cls.w = client_context.w # type: ignore @classmethod def tearDownClass(cls): @@ -726,7 +731,7 @@ def test_insert_many(self): db = self.db db.test.drop() - docs = [{} for _ in range(5)] + docs: list = [{} for _ in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(isinstance(result.inserted_ids, list)) @@ -759,7 +764,7 @@ def test_insert_many(self): db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) - docs = [{} for _ in range(5)] + docs: list = [{} for _ in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertFalse(result.acknowledged) @@ -792,11 +797,11 @@ def test_insert_many_invalid(self): with self.assertRaisesRegex( TypeError, "documents must be a non-empty list"): - db.test.insert_many(1) + db.test.insert_many(1) # type: ignore[arg-type] with self.assertRaisesRegex( TypeError, "documents must be a non-empty list"): - db.test.insert_many(RawBSONDocument(encode({'_id': 2}))) + db.test.insert_many(RawBSONDocument(encode({'_id': 2}))) # type: ignore[arg-type] def test_delete_one(self): self.db.test.drop() @@ -1064,7 +1069,7 @@ def test_bypass_document_validation_bulk_write(self): db_w0 = self.db.client.get_database( self.db.name, write_concern=WriteConcern(w=0)) - ops = [InsertOne({"a": -10}), + ops: list = [InsertOne({"a": -10}), InsertOne({"a": -11}), InsertOne({"a": -12}), UpdateOne({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), @@ -1087,7 +1092,7 @@ def test_bypass_document_validation_bulk_write(self): def test_find_by_default_dct(self): db = self.db db.test.insert_one({'foo': 'bar'}) - dct = defaultdict(dict, [('foo', 'bar')]) + dct = defaultdict(dict, [('foo', 'bar')]) # type: ignore[arg-type] self.assertIsNotNone(db.test.find_one(dct)) self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')])) @@ -1117,6 +1122,7 @@ def test_find_w_fields(self): doc = next(db.test.find({}, ["mike"])) self.assertFalse("extra thing" in doc) + @no_type_check def test_fields_specifier_as_dict(self): db = self.db db.test.delete_many({}) @@ -1333,7 +1339,7 @@ def test_replace_one(self): self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"y": 1})) self.assertEqual(0, db.test.count_documents({"x": 1})) - self.assertEqual(db.test.find_one(id1)["y"], 1) + self.assertEqual(db.test.find_one(id1)["y"], 1) # type: ignore replacement = RawBSONDocument(encode({"_id": id1, "z": 1})) result = db.test.replace_one({"y": 1}, replacement, True) @@ -1344,7 +1350,7 @@ def test_replace_one(self): self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"z": 1})) self.assertEqual(0, db.test.count_documents({"y": 1})) - self.assertEqual(db.test.find_one(id1)["z"], 1) + self.assertEqual(db.test.find_one(id1)["z"], 1) # type: ignore result = db.test.replace_one({"x": 2}, {"y": 2}, True) self.assertTrue(isinstance(result, UpdateResult)) @@ -1377,7 +1383,7 @@ def test_update_one(self): self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) - self.assertEqual(db.test.find_one(id1)["x"], 6) + self.assertEqual(db.test.find_one(id1)["x"], 6) # type: ignore id2 = db.test.insert_one({"x": 1}).inserted_id result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}}) @@ -1386,8 +1392,8 @@ def test_update_one(self): self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) - self.assertEqual(db.test.find_one(id1)["x"], 7) - self.assertEqual(db.test.find_one(id2)["x"], 1) + self.assertEqual(db.test.find_one(id1)["x"], 7) # type: ignore + self.assertEqual(db.test.find_one(id2)["x"], 1) # type: ignore result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True) self.assertTrue(isinstance(result, UpdateResult)) @@ -1587,12 +1593,12 @@ def test_aggregation_cursor(self): # Test that batchSize is handled properly. cursor = db.test.aggregate([], batchSize=5) - self.assertEqual(5, len(cursor._CommandCursor__data)) + self.assertEqual(5, len(cursor._CommandCursor__data)) # type: ignore # Force a getMore - cursor._CommandCursor__data.clear() + cursor._CommandCursor__data.clear() # type: ignore next(cursor) # batchSize - 1 - self.assertEqual(4, len(cursor._CommandCursor__data)) + self.assertEqual(4, len(cursor._CommandCursor__data)) # type: ignore # Exhaust the cursor. There shouldn't be any errors. for doc in cursor: pass @@ -1679,6 +1685,7 @@ def test_rename(self): with self.write_concern_collection() as coll: coll.rename('foo') + @no_type_check def test_find_one(self): db = self.db db.drop_collection("test") @@ -1727,8 +1734,8 @@ def test_find_one_with_find_args(self): db.test.insert_many([{"x": i} for i in range(1, 4)]) - self.assertEqual(1, db.test.find_one()["x"]) - self.assertEqual(2, db.test.find_one(skip=1, limit=2)["x"]) + self.assertEqual(1, db.test.find_one()["x"]) # type: ignore[index] + self.assertEqual(2, db.test.find_one(skip=1, limit=2)["x"]) # type: ignore[index] def test_find_with_sort(self): db = self.db @@ -1736,9 +1743,9 @@ def test_find_with_sort(self): db.test.insert_many([{"x": 2}, {"x": 1}, {"x": 3}]) - self.assertEqual(2, db.test.find_one()["x"]) - self.assertEqual(1, db.test.find_one(sort=[("x", 1)])["x"]) - self.assertEqual(3, db.test.find_one(sort=[("x", -1)])["x"]) + self.assertEqual(2, db.test.find_one()["x"]) # type: ignore[index] + self.assertEqual(1, db.test.find_one(sort=[("x", 1)])["x"]) # type: ignore[index] + self.assertEqual(3, db.test.find_one(sort=[("x", -1)])["x"]) # type: ignore[index] def to_list(things): return [thing["x"] for thing in things] @@ -1973,17 +1980,17 @@ def __getattr__(self, name): bad = BadGetAttr([('foo', 'bar')]) c.insert_one({'bad': bad}) - self.assertEqual('bar', c.find_one()['bad']['foo']) + self.assertEqual('bar', c.find_one()['bad']['foo']) # type: ignore def test_array_filters_validation(self): # array_filters must be a list. c = self.db.test with self.assertRaises(TypeError): - c.update_one({}, {'$set': {'a': 1}}, array_filters={}) + c.update_one({}, {'$set': {'a': 1}}, array_filters={}) # type: ignore[arg-type] with self.assertRaises(TypeError): - c.update_many({}, {'$set': {'a': 1}}, array_filters={}) + c.update_many({}, {'$set': {'a': 1}}, array_filters={} ) # type: ignore[arg-type] with self.assertRaises(TypeError): - c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={}) + c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={}) # type: ignore[arg-type] def test_array_filters_unacknowledged(self): c_w0 = self.db.test.with_options(write_concern=WriteConcern(w=0)) @@ -2158,7 +2165,7 @@ def test_find_regex(self): c.drop() c.insert_one({'r': re.compile('.*')}) - self.assertTrue(isinstance(c.find_one()['r'], Regex)) + self.assertTrue(isinstance(c.find_one()['r'], Regex)) # type: ignore for doc in c.find(): self.assertTrue(isinstance(doc['r'], Regex)) @@ -2189,9 +2196,9 @@ def test_helpers_with_let(self): for helper, args in helpers: with self.assertRaisesRegex(TypeError, "let must be an instance of dict"): - helper(*args, let=let) + helper(*args, let=let) # type: ignore for helper, args in helpers: - helper(*args, let={}) + helper(*args, let={}) # type: ignore if __name__ == "__main__": diff --git a/test/test_command_monitoring_legacy.py b/test/test_command_monitoring_legacy.py index 7ff80d75e5..a05dbd9668 100644 --- a/test/test_command_monitoring_legacy.py +++ b/test/test_command_monitoring_legacy.py @@ -43,6 +43,8 @@ def camel_to_snake(camel): class TestAllScenarios(unittest.TestCase): + listener: EventListener + client: MongoClient @classmethod @client_context.require_connection diff --git a/test/test_common.py b/test/test_common.py index dcd618c509..7d7a26c278 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -50,13 +50,13 @@ def test_uuid_representation(self): "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) legacy_opts = coll.codec_options coll.insert_one({'uu': uu}) - self.assertEqual(uu, coll.find_one({'uu': uu})['uu']) + self.assertEqual(uu, coll.find_one({'uu': uu})['uu']) # type: ignore coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=STANDARD)) self.assertEqual(STANDARD, coll.codec_options.uuid_representation) self.assertEqual(None, coll.find_one({'uu': uu})) uul = Binary.from_uuid(uu, PYTHON_LEGACY) - self.assertEqual(uul, coll.find_one({'uu': uul})['uu']) + self.assertEqual(uul, coll.find_one({'uu': uul})['uu']) # type: ignore # Test count_documents self.assertEqual(0, coll.count_documents({'uu': uu})) @@ -81,9 +81,9 @@ def test_uuid_representation(self): coll.update_one({'_id': uu}, {'$set': {'i': 2}}) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - self.assertEqual(1, coll.find_one({'_id': uu})['i']) + self.assertEqual(1, coll.find_one({'_id': uu})['i']) # type: ignore coll.update_one({'_id': uu}, {'$set': {'i': 2}}) - self.assertEqual(2, coll.find_one({'_id': uu})['i']) + self.assertEqual(2, coll.find_one({'_id': uu})['i']) # type: ignore # Test Cursor.distinct self.assertEqual([2], coll.find({'_id': uu}).distinct('i')) @@ -98,7 +98,7 @@ def test_uuid_representation(self): "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) self.assertEqual(2, coll.find_one_and_update({'_id': uu}, {'$set': {'i': 5}})['i']) - self.assertEqual(5, coll.find_one({'_id': uu})['i']) + self.assertEqual(5, coll.find_one({'_id': uu})['i']) # type: ignore # Test command self.assertEqual(5, self.db.command( diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 894b14becd..e683974b03 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -20,6 +20,7 @@ from bson import SON from pymongo import monitoring +from pymongo.collection import Collection from pymongo.errors import NotPrimaryError from pymongo.write_concern import WriteConcern @@ -33,6 +34,9 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): + listener: CMAPListener + coll: Collection + @classmethod @client_context.require_replica_set def setUpClass(cls): @@ -111,7 +115,7 @@ def run_scenario(self, error_code, retry, pool_status_checker): # Insert record and verify failure. with self.assertRaises(NotPrimaryError) as exc: self.coll.insert_one({"test": 1}) - self.assertEqual(exc.exception.details['code'], error_code) + self.assertEqual(exc.exception.details['code'], error_code) # type: ignore # Retry before CMAPListener assertion if retry_before=True. if retry: self.coll.insert_one({"test": 1}) diff --git a/test/test_crud_v1.py b/test/test_crud_v1.py index 5a63e030fe..4399d9f223 100644 --- a/test/test_crud_v1.py +++ b/test/test_crud_v1.py @@ -53,7 +53,7 @@ def check_result(self, expected_result, result): # SPEC-869: Only BulkWriteResult has upserted_count. if (prop == "upserted_count" and not isinstance(result, BulkWriteResult)): - if result.upserted_id is not None: + if result.upserted_id is not None: # type: ignore upserted_count = 1 else: upserted_count = 0 @@ -69,14 +69,14 @@ def check_result(self, expected_result, result): ids = expected_result[res] if isinstance(ids, dict): ids = [ids[str(i)] for i in range(len(ids))] - self.assertEqual(ids, result.inserted_ids, msg) + self.assertEqual(ids, result.inserted_ids, msg) # type: ignore elif prop == "upserted_ids": # Convert indexes from strings to integers. ids = expected_result[res] expected_ids = {} for str_index in ids: expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, msg) + self.assertEqual(expected_ids, result.upserted_ids, msg) # type: ignore else: self.assertEqual( getattr(result, prop), expected_result[res], msg) diff --git a/test/test_cursor.py b/test/test_cursor.py index 0b8ba049c2..fc804035a8 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -22,6 +22,8 @@ import time import threading +from typing import no_type_check + sys.path[0:0] = [""] from bson import decode_all @@ -57,8 +59,9 @@ def test_deepcopy_cursor_littered_with_regexes(self): re.compile("^key.*"): {"a": [re.compile("^hm.*")]}}) cursor2 = copy.deepcopy(cursor) - self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec) + self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec) # type: ignore + @no_type_check def test_add_remove_option(self): cursor = self.db.test.find() self.assertEqual(0, cursor._Cursor__query_flags) @@ -125,6 +128,7 @@ def test_add_remove_option(self): cursor.remove_option(128) self.assertEqual(0, cursor._Cursor__query_flags) + @no_type_check def test_add_remove_option_exhaust(self): # Exhaust - which mongos doesn't support if client_context.is_mongos: @@ -149,9 +153,9 @@ def test_allow_disk_use(self): self.assertRaises(TypeError, coll.find().allow_disk_use, 'baz') cursor = coll.find().allow_disk_use(True) - self.assertEqual(True, cursor._Cursor__allow_disk_use) + self.assertEqual(True, cursor._Cursor__allow_disk_use) # type: ignore cursor = coll.find().allow_disk_use(False) - self.assertEqual(False, cursor._Cursor__allow_disk_use) + self.assertEqual(False, cursor._Cursor__allow_disk_use) # type: ignore def test_max_time_ms(self): db = self.db @@ -165,15 +169,15 @@ def test_max_time_ms(self): coll.find().max_time_ms(1) cursor = coll.find().max_time_ms(999) - self.assertEqual(999, cursor._Cursor__max_time_ms) + self.assertEqual(999, cursor._Cursor__max_time_ms) # type: ignore cursor = coll.find().max_time_ms(10).max_time_ms(1000) - self.assertEqual(1000, cursor._Cursor__max_time_ms) + self.assertEqual(1000, cursor._Cursor__max_time_ms) # type: ignore cursor = coll.find().max_time_ms(999) c2 = cursor.clone() - self.assertEqual(999, c2._Cursor__max_time_ms) - self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec()) - self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec()) + self.assertEqual(999, c2._Cursor__max_time_ms) # type: ignore + self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec()) # type: ignore + self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec()) # type: ignore self.assertTrue(coll.find_one(max_time_ms=1000)) @@ -199,6 +203,7 @@ def test_max_time_ms(self): "maxTimeAlwaysTimeOut", mode="off") + @no_type_check def test_max_await_time_ms(self): db = self.db db.pymongo_test.drop() @@ -528,6 +533,7 @@ def test_min_max_without_hint(self): with self.assertRaises(InvalidOperation): list(coll.find().max([("j", 3)])) + @no_type_check def test_batch_size(self): db = self.db db.test.drop() @@ -581,6 +587,7 @@ def cursor_count(cursor, expected_count): next(cur) self.assertEqual(0, len(cur._Cursor__data)) + @no_type_check def test_limit_and_batch_size(self): db = self.db db.test.drop() @@ -829,6 +836,7 @@ def test_rewind(self): # oplog_reply, and snapshot are all deprecated. @ignore_deprecations + @no_type_check def test_clone(self): self.db.test.insert_many([{"x": i} for i in range(1, 4)]) @@ -889,7 +897,7 @@ def test_clone(self): # Every attribute should be the same. cursor2 = cursor.clone() - self.assertDictEqual(cursor.__dict__, cursor2.__dict__) + self.assertEqual(cursor.__dict__, cursor2.__dict__) # Shallow copies can so can mutate cursor2 = copy.copy(cursor) @@ -1002,6 +1010,7 @@ def test_getitem_slice_index(self): self.assertRaises(IndexError, lambda: self.db.test.find()[10:8]) + @no_type_check def test_getitem_numeric_index(self): self.db.drop_collection("test") self.db.test.insert_many([{"i": i} for i in range(100)]) @@ -1025,7 +1034,7 @@ def test_properties(self): self.assertEqual(self.db.test, self.db.test.find().collection) def set_coll(): - self.db.test.find().collection = "hello" + self.db.test.find().collection = "hello" # type: ignore self.assertRaises(AttributeError, set_coll) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 5db208ab7e..dc0a7f55e9 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -21,6 +21,7 @@ from collections import OrderedDict from decimal import Decimal from random import random +from typing import Any, Tuple, Type, no_type_check sys.path[0:0] = [""] @@ -89,7 +90,7 @@ def __eq__(self, other): class UndecipherableIntDecoder(TypeDecoder): - bson_type = Int64 + bson_type = Int64 # type: ignore[assignment] def transform_bson(self, value): return UndecipherableInt64Type(value) @@ -109,7 +110,7 @@ def transform_python(self, value): class UppercaseTextDecoder(TypeDecoder): - bson_type = str + bson_type = str # type: ignore[assignment] def transform_bson(self, value): return value.upper() @@ -127,6 +128,7 @@ def transform_bson(self, value): class CustomBSONTypeTests(object): + @no_type_check def roundtrip(self, doc): bsonbytes = encode(doc, codec_options=self.codecopts) rt_document = decode(bsonbytes, codec_options=self.codecopts) @@ -139,6 +141,7 @@ def test_encode_decode_roundtrip(self): self.roundtrip({'average': [[Decimal('56.47')]]}) self.roundtrip({'average': [{'b': Decimal('56.47')}]}) + @no_type_check def test_decode_all(self): documents = [] for dec in range(3): @@ -151,12 +154,14 @@ def test_decode_all(self): self.assertEqual( decode_all(bsonstream, self.codecopts), documents) + @no_type_check def test__bson_to_dict(self): document = {'average': Decimal('56.47')} rawbytes = encode(document, codec_options=self.codecopts) decoded_document = _bson_to_dict(rawbytes, self.codecopts) self.assertEqual(document, decoded_document) + @no_type_check def test__dict_to_bson(self): document = {'average': Decimal('56.47')} rawbytes = encode(document, codec_options=self.codecopts) @@ -172,12 +177,14 @@ def _generate_multidocument_bson_stream(self): bsonstream += encode(doc) return edocs, bsonstream + @no_type_check def test_decode_iter(self): expected, bson_data = self._generate_multidocument_bson_stream() for expected_doc, decoded_doc in zip( expected, decode_iter(bson_data, self.codecopts)): self.assertEqual(expected_doc, decoded_doc) + @no_type_check def test_decode_file_iter(self): expected, bson_data = self._generate_multidocument_bson_stream() fileobj = tempfile.TemporaryFile() @@ -195,7 +202,7 @@ class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): - cls.codecopts = DECIMAL_CODECOPTS + cls.codecopts = DECIMAL_CODECOPTS # type: ignore[attr-defined] class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, @@ -204,7 +211,7 @@ class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, def setUpClass(cls): codec_options = CodecOptions( type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder()))) - cls.codecopts = codec_options + cls.codecopts = codec_options # type: ignore[attr-defined] class TestBSONFallbackEncoder(unittest.TestCase): @@ -293,6 +300,15 @@ def test_type_checks(self): class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): + + TypeA: Any + TypeB: Any + fallback_encoder_A2B: Any + fallback_encoder_A2BSON: Any + B2BSON: Type[TypeEncoder] + B2A: Type[TypeEncoder] + A2B: Type[TypeEncoder] + @classmethod def setUpClass(cls): class TypeA(object): @@ -378,6 +394,10 @@ def test_infinite_loop_exceeds_max_recursion_depth(self): class TestTypeRegistry(unittest.TestCase): + types: Tuple[object, object] + codecs: Tuple[Type[TypeCodec], Type[TypeCodec]] + fallback_encoder: Any + @classmethod def setUpClass(cls): class MyIntType(object): @@ -451,47 +471,47 @@ def assert_proper_initialization(type_registry, codec_instances): codec_instances_copy = list(codec_instances) codec_instances.pop(0) self.assertListEqual( - type_registry._TypeRegistry__type_codecs, codec_instances_copy) + type_registry._TypeRegistry__type_codecs, codec_instances_copy) # type: ignore[attr-defined] def test_simple_separate_codecs(self): class MyIntEncoder(TypeEncoder): - python_type = self.types[0] + python_type = self.types[0] # type: ignore[assignment] def transform_python(self, value): return value.x class MyIntDecoder(TypeDecoder): - bson_type = int + bson_type = int # type: ignore[assignment] def transform_bson(self, value): - return self.types[0](value) + return self.types[0](value) # type: ignore[attr-defined] - codec_instances = [MyIntDecoder(), MyIntEncoder()] + codec_instances: list = [MyIntDecoder(), MyIntEncoder()] type_registry = TypeRegistry(codec_instances) self.assertEqual( type_registry._encoder_map, - {MyIntEncoder.python_type: codec_instances[1].transform_python}) + {MyIntEncoder.python_type: codec_instances[1].transform_python}) # type: ignore self.assertEqual( type_registry._decoder_map, - {MyIntDecoder.bson_type: codec_instances[0].transform_bson}) + {MyIntDecoder.bson_type: codec_instances[0].transform_bson}) # type: ignore def test_initialize_fail(self): err_msg = ("Expected an instance of TypeEncoder, TypeDecoder, " "or TypeCodec, got .* instead") with self.assertRaisesRegex(TypeError, err_msg): - TypeRegistry(self.codecs) + TypeRegistry(self.codecs) # type: ignore[arg-type] with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry([type('AnyType', (object,), {})()]) err_msg = "fallback_encoder %r is not a callable" % (True,) with self.assertRaisesRegex(TypeError, err_msg): - TypeRegistry([], True) + TypeRegistry([], True) # type: ignore[arg-type] err_msg = "fallback_encoder %r is not a callable" % ('hello',) with self.assertRaisesRegex(TypeError, err_msg): - TypeRegistry(fallback_encoder='hello') + TypeRegistry(fallback_encoder='hello') # type: ignore[arg-type] def test_type_registry_repr(self): codec_instances = [codec() for codec in self.codecs] @@ -525,7 +545,7 @@ def run_test(base, attrs): if pytype in [bool, type(None), RE_TYPE,]: continue - class MyType(pytype): + class MyType(pytype): # type: ignore pass attrs.update({'python_type': MyType, 'transform_python': lambda x: x}) @@ -598,7 +618,7 @@ def test_aggregate_w_custom_type_decoder(self): test = db.get_collection( 'test', codec_options=UNINT_DECODER_CODECOPTS) - pipeline = [ + pipeline: list = [ {'$match': {'status': 'complete'}}, {'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},] result = test.aggregate(pipeline) @@ -680,15 +700,18 @@ def test_grid_out_custom_opts(self): class ChangeStreamsWCustomTypesTestMixin(object): + @no_type_check def change_stream(self, *args, **kwargs): return self.watched_target.watch(*args, **kwargs) + @no_type_check def insert_and_check(self, change_stream, insert_doc, expected_doc): self.input_target.insert_one(insert_doc) change = next(change_stream) self.assertEqual(change['fullDocument'], expected_doc) + @no_type_check def kill_change_stream_cursor(self, change_stream): # Cause a cursor not found error on the next getMore. cursor = change_stream._cursor @@ -696,6 +719,7 @@ def kill_change_stream_cursor(self, change_stream): client = self.input_target.database.client client._close_cursor_now(cursor.cursor_id, address) + @no_type_check def test_simple(self): codecopts = CodecOptions(type_registry=TypeRegistry([ UndecipherableIntEncoder(), UppercaseTextDecoder()])) @@ -718,6 +742,7 @@ def test_simple(self): self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, input_docs[2], expected_docs[2]) + @no_type_check def test_custom_type_in_pipeline(self): codecopts = CodecOptions(type_registry=TypeRegistry([ UndecipherableIntEncoder(), UppercaseTextDecoder()])) @@ -741,6 +766,7 @@ def test_custom_type_in_pipeline(self): self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, input_docs[2], expected_docs[1]) + @no_type_check def test_break_resume_token(self): # Get one document from a change stream to determine resumeToken type. self.create_targets() @@ -766,6 +792,7 @@ def test_break_resume_token(self): self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, docs[2], docs[2]) + @no_type_check def test_document_class(self): def run_test(doc_cls): codecopts = CodecOptions(type_registry=TypeRegistry([ diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 2954efe651..0e52950250 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -94,8 +94,8 @@ def test_3(self): class DataLakeTestSpec(TestCrudV2): # Default test database and collection names. - TEST_DB = 'test' - TEST_COLLECTION = 'driverdata' + TEST_DB = 'test' # type: ignore + TEST_COLLECTION = 'driverdata' # type: ignore @classmethod @client_context.require_data_lake diff --git a/test/test_database.py b/test/test_database.py index 4adccc1b58..675148fc3b 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -17,6 +17,7 @@ import datetime import re import sys +from typing import Any sys.path[0:0] = [""] @@ -57,6 +58,7 @@ class TestDatabaseNoConnect(unittest.TestCase): """Test Database features on a client that does not connect. """ + client: MongoClient @classmethod def setUpClass(cls): @@ -143,7 +145,7 @@ def test_create_collection(self): test = db.create_collection("test") self.assertTrue("test" in db.list_collection_names()) test.insert_one({"hello": "world"}) - self.assertEqual(db.test.find_one()["hello"], "world") + self.assertEqual(db.test.find_one()["hello"], "world") # type: ignore db.drop_collection("test.foo") db.create_collection("test.foo") @@ -198,6 +200,7 @@ def test_list_collection_names_filter(self): self.assertNotIn("nameOnly", results["started"][0].command) # Should send nameOnly (except on 2.6). + filter: Any for filter in (None, {}, {'name': {'$in': ['capped', 'non_capped']}}): results.clear() names = db.list_collection_names(filter=filter) @@ -225,7 +228,7 @@ def test_list_collections(self): self.assertTrue("$" not in coll) # Duplicate check. - coll_cnt = {} + coll_cnt: dict = {} for coll in colls: try: # Found duplicate. @@ -233,7 +236,7 @@ def test_list_collections(self): self.assertTrue(False) except KeyError: coll_cnt[coll] = 1 - coll_cnt = {} + coll_cnt: dict = {} # Checking if is there any collection which don't exists. if (len(set(colls) - set(["test","test.mike"])) == 0 or @@ -466,6 +469,7 @@ def test_insert_find_one(self): self.assertEqual(None, db.test.find_one({"hello": "test"})) b = db.test.find_one() + assert b is not None b["hello"] = "mike" db.test.replace_one({"_id": b["_id"]}, b) @@ -482,12 +486,12 @@ def test_long(self): db = self.client.pymongo_test db.test.drop() db.test.insert_one({"x": 9223372036854775807}) - retrieved = db.test.find_one()['x'] + retrieved = db.test.find_one()['x'] # type: ignore self.assertEqual(Int64(9223372036854775807), retrieved) self.assertIsInstance(retrieved, Int64) db.test.delete_many({}) db.test.insert_one({"x": Int64(1)}) - retrieved = db.test.find_one()['x'] + retrieved = db.test.find_one()['x'] # type: ignore self.assertEqual(Int64(1), retrieved) self.assertIsInstance(retrieved, Int64) @@ -509,8 +513,8 @@ def test_delete(self): length += 1 self.assertEqual(length, 2) - db.test.delete_one(db.test.find_one()) - db.test.delete_one(db.test.find_one()) + db.test.delete_one(db.test.find_one()) # type: ignore[arg-type] + db.test.delete_one(db.test.find_one()) # type: ignore[arg-type] self.assertEqual(db.test.find_one(), None) db.test.insert_one({"x": 1}) @@ -585,7 +589,7 @@ def test_command_max_time_ms(self): db.command('count', 'test') self.assertRaises(ExecutionTimeout, db.command, 'count', 'test', maxTimeMS=1) - pipeline = [{'$project': {'name': 1, 'count': 1}}] + pipeline: list = [{'$project': {'name': 1, 'count': 1}}] # Database command helper. db.command('aggregate', 'test', pipeline=pipeline, cursor={}) self.assertRaises(ExecutionTimeout, db.command, @@ -625,7 +629,7 @@ def test_with_options(self): 'read_preference': ReadPreference.PRIMARY, 'write_concern': WriteConcern(w=1), 'read_concern': ReadConcern(level="local")} - db2 = db1.with_options(**newopts) + db2 = db1.with_options(**newopts) # type: ignore[arg-type] for opt in newopts: self.assertEqual( getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) @@ -633,7 +637,7 @@ def test_with_options(self): class TestDatabaseAggregation(IntegrationTest): def setUp(self): - self.pipeline = [{"$listLocalSessions": {}}, + self.pipeline: list = [{"$listLocalSessions": {}}, {"$limit": 1}, {"$addFields": {"dummy": "dummy field"}}, {"$project": {"_id": 0, "dummy": 1}}] @@ -648,6 +652,7 @@ def test_database_aggregation(self): @client_context.require_no_mongos def test_database_aggregation_fake_cursor(self): coll_name = "test_output" + write_stage: dict if client_context.version < (4, 3): db_name = "admin" write_stage = {"$out": coll_name} @@ -662,7 +667,7 @@ def test_database_aggregation_fake_cursor(self): self.addCleanup(output_coll.drop) admin = self.admin.with_options(write_concern=WriteConcern(w=0)) - pipeline = self.pipeline[:] + pipeline: list = self.pipeline[:] pipeline.append(write_stage) with admin.aggregate(pipeline) as cursor: with self.assertRaises(StopIteration): diff --git a/test/test_dbref.py b/test/test_dbref.py index 964947351e..348b1d14de 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -16,6 +16,7 @@ import pickle import sys +from typing import Any sys.path[0:0] = [""] from bson import encode, decode @@ -44,10 +45,10 @@ def test_read_only(self): a = DBRef("coll", ObjectId()) def foo(): - a.collection = "blah" + a.collection = "blah" # type: ignore[misc] def bar(): - a.id = "aoeu" + a.id = "aoeu" # type: ignore[misc] self.assertEqual("coll", a.collection) a.id @@ -136,6 +137,7 @@ def test_dbref_hash(self): # https://github.com/mongodb/specifications/blob/master/source/dbref.rst#test-plan class TestDBRefSpec(unittest.TestCase): def test_decoding_1_2_3(self): + doc: Any for doc in [ # 1, Valid documents MUST be decoded to a DBRef: {"$ref": "coll0", "$id": ObjectId("60a6fe9a54f4180c86309efa")}, @@ -183,6 +185,7 @@ def test_decoding_4_5(self): self.assertIsInstance(dbref, dict) def test_encoding_1_2(self): + doc: Any for doc in [ # 1, Encoding DBRefs with basic fields: {"$ref": "coll0", "$id": ObjectId("60a6fe9a54f4180c86309efa")}, diff --git a/test/test_decimal128.py b/test/test_decimal128.py index 4ff25935dd..3988a4559a 100644 --- a/test/test_decimal128.py +++ b/test/test_decimal128.py @@ -35,6 +35,7 @@ def test_round_trip(self): b'\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0') coll.insert_one({'dec128': dec128}) doc = coll.find_one({'dec128': dec128}) + assert doc is not None self.assertIsNotNone(doc) self.assertEqual(doc['dec128'], dec128) diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 107168f294..c3a50709ac 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -364,10 +364,12 @@ def _event_count(self, event): def marked_unknown(e): return (isinstance(e, monitoring.ServerDescriptionChangedEvent) and not e.new_description.is_server_type_known) + assert self.server_listener is not None return len(self.server_listener.matching(marked_unknown)) # Only support CMAP events for now. self.assertTrue(event.startswith('Pool') or event.startswith('Conn')) event_type = getattr(monitoring, event) + assert self.pool_listener is not None return self.pool_listener.event_count(event_type) def assert_event_count(self, event, count): diff --git a/test/test_dns.py b/test/test_dns.py index 8404c2aa69..218e8bbef1 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -18,6 +18,7 @@ import json import os import sys +from pymongo import client_options sys.path[0:0] = [""] diff --git a/test/test_encryption.py b/test/test_encryption.py index 8e47d44525..7ed9d92663 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -25,6 +25,10 @@ import traceback import uuid +from typing import Any + +from pymongo.collection import Collection + sys.path[0:0] = [""] from bson import encode, json_util @@ -126,6 +130,7 @@ def test_init_kms_tls_options(self): with self.assertRaisesRegex( TypeError, r'kms_tls_options\["kmip"\] must be a dict'): AutoEncryptionOpts({}, 'k.d', kms_tls_options={'kmip': 1}) + tls_opts: Any for tls_opts in [ {'kmip': {'tls': True, 'tlsInsecure': True}}, {'kmip': {'tls': True, 'tlsAllowInvalidCertificates': True}}, @@ -138,6 +143,7 @@ def test_init_kms_tls_options(self): AutoEncryptionOpts({}, 'k.d', kms_tls_options={ 'kmip': {'tlsCAFile': 'does-not-exist'}}) # Success cases: + tls_opts: Any for tls_opts in [None, {}]: opts = AutoEncryptionOpts({}, 'k.d', kms_tls_options=tls_opts) self.assertEqual(opts._kms_ssl_contexts, {}) @@ -432,14 +438,14 @@ def test_validation(self): msg = 'value to decrypt must be a bson.binary.Binary with subtype 6' with self.assertRaisesRegex(TypeError, msg): - client_encryption.decrypt('str') + client_encryption.decrypt('str') # type: ignore[arg-type] with self.assertRaisesRegex(TypeError, msg): client_encryption.decrypt(Binary(b'123')) msg = 'key_id must be a bson.binary.Binary with subtype 4' algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic with self.assertRaisesRegex(TypeError, msg): - client_encryption.encrypt('str', algo, key_id=uuid.uuid4()) + client_encryption.encrypt('str', algo, key_id=uuid.uuid4()) # type: ignore[arg-type] with self.assertRaisesRegex(TypeError, msg): client_encryption.encrypt('str', algo, key_id=Binary(b'123')) @@ -459,7 +465,7 @@ def test_bson_errors(self): def test_codec_options(self): with self.assertRaisesRegex(TypeError, 'codec_options must be'): ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None) + KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None) # type: ignore[arg-type] opts = CodecOptions(uuid_representation=JAVA_LEGACY) client_encryption_legacy = ClientEncryption( @@ -708,6 +714,10 @@ def create_key_vault(vault, *data_keys): class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): + client_encrypted: MongoClient + client_encryption: ClientEncryption + listener: OvertCommandListener + vault: Any KMS_PROVIDERS = ALL_KMS_PROVIDERS @@ -776,7 +786,7 @@ def setUp(self): def run_test(self, provider_name): # Create data key. - master_key = self.MASTER_KEYS[provider_name] + master_key: Any = self.MASTER_KEYS[provider_name] datakey_id = self.client_encryption.create_data_key( provider_name, master_key=master_key, key_alt_names=['%s_altname' % (provider_name,)]) @@ -798,7 +808,7 @@ def run_test(self, provider_name): {'_id': provider_name, 'value': encrypted}) doc_decrypted = self.client_encrypted.db.coll.find_one( {'_id': provider_name}) - self.assertEqual(doc_decrypted['value'], 'hello %s' % (provider_name,)) + self.assertEqual(doc_decrypted['value'], 'hello %s' % (provider_name,)) # type: ignore # Encrypt by key_alt_name. encrypted_altname = self.client_encryption.encrypt( @@ -876,7 +886,7 @@ def _test_external_key_vault(self, with_external_key_vault): client_encrypted.db.coll.insert_one({"encrypted": "test"}) # AuthenticationFailed error. self.assertIsInstance(ctx.exception.cause, OperationFailure) - self.assertEqual(ctx.exception.cause.code, 18) + self.assertEqual(ctx.exception.cause.code, 18) # type: ignore[attr-defined] else: client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -889,7 +899,7 @@ def _test_external_key_vault(self, with_external_key_vault): key_id=LOCAL_KEY_ID) # AuthenticationFailed error. self.assertIsInstance(ctx.exception.cause, OperationFailure) - self.assertEqual(ctx.exception.cause.code, 18) + self.assertEqual(ctx.exception.cause.code, 18) # type: ignore[attr-defined] else: client_encryption.encrypt( "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, @@ -985,7 +995,7 @@ def _test_corpus(self, opts): self.addCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data('corpus', 'corpus.json')) - corpus_copied = SON() + corpus_copied: SON = SON() for key, value in corpus.items(): corpus_copied[key] = copy.deepcopy(value) if key in ('_id', 'altname_aws', 'altname_azure', 'altname_gcp', @@ -1021,7 +1031,7 @@ def _test_corpus(self, opts): try: encrypted_val = client_encryption.encrypt( - value['value'], algo, **kwargs) + value['value'], algo, **kwargs) # type: ignore[arg-type] if not value['allowed']: self.fail('encrypt should have failed: %r: %r' % ( key, value)) @@ -1082,6 +1092,10 @@ def test_corpus_local_schema(self): class TestBsonSizeBatches(EncryptionIntegrationTest): """Prose tests for BSON size limits and batch splitting.""" + coll: Collection + coll_encrypted: Collection + client_encrypted: MongoClient + listener: OvertCommandListener @classmethod def setUpClass(cls): @@ -1388,6 +1402,7 @@ class AzureGCPEncryptionTestMixin(object): KMS_PROVIDER_MAP = None KEYVAULT_DB = 'keyvault' KEYVAULT_COLL = 'datakeys' + client: MongoClient def setUp(self): keyvault = self.client.get_database( @@ -1397,19 +1412,19 @@ def setUp(self): def _test_explicit(self, expectation): client_encryption = ClientEncryption( - self.KMS_PROVIDER_MAP, + self.KMS_PROVIDER_MAP, # type: ignore[arg-type] '.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, OPTS) - self.addCleanup(client_encryption.close) + self.addCleanup(client_encryption.close) # type: ignore[attr-defined] ciphertext = client_encryption.encrypt( 'string0', algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=Binary.from_uuid(self.DEK['_id'], STANDARD)) + key_id=Binary.from_uuid(self.DEK['_id'], STANDARD)) # type: ignore[index] - self.assertEqual(bytes(ciphertext), base64.b64decode(expectation)) - self.assertEqual(client_encryption.decrypt(ciphertext), 'string0') + self.assertEqual(bytes(ciphertext), base64.b64decode(expectation)) # type: ignore[attr-defined] + self.assertEqual(client_encryption.decrypt(ciphertext), 'string0') # type: ignore[attr-defined] def _test_automatic(self, expectation_extjson, payload): encrypted_db = "db" @@ -1417,15 +1432,15 @@ def _test_automatic(self, expectation_extjson, payload): keyvault_namespace = '.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) encryption_opts = AutoEncryptionOpts( - self.KMS_PROVIDER_MAP, + self.KMS_PROVIDER_MAP, # type: ignore[arg-type] keyvault_namespace, - schema_map=self.SCHEMA_MAP) + schema_map=self.SCHEMA_MAP) # type: ignore[attr-defined] insert_listener = AllowListEventListener('insert') client = rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]) - self.addCleanup(client.close) + self.addCleanup(client.close) # type: ignore[attr-defined] coll = client.get_database(encrypted_db).get_collection( encrypted_coll, codec_options=OPTS, @@ -1440,11 +1455,11 @@ def _test_automatic(self, expectation_extjson, payload): inserted_doc = event.command['documents'][0] for key, value in expected_document.items(): - self.assertEqual(value, inserted_doc[key]) + self.assertEqual(value, inserted_doc[key]) # type: ignore[attr-defined] output_doc = coll.find_one({}) for key, value in payload.items(): - self.assertEqual(output_doc[key], value) + self.assertEqual(output_doc[key], value) # type: ignore[attr-defined] class TestAzureEncryption(AzureGCPEncryptionTestMixin, @@ -1453,9 +1468,9 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, @unittest.skipUnless(any(AZURE_CREDS.values()), 'Azure environment credentials are not set') def setUpClass(cls): - cls.KMS_PROVIDER_MAP = {'azure': AZURE_CREDS} + cls.KMS_PROVIDER_MAP = {'azure': AZURE_CREDS} # type: ignore[assignment] cls.DEK = json_data(BASE, 'custom', 'azure-dek.json') - cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') + cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') # type: ignore[attr-defined] super(TestAzureEncryption, cls).setUpClass() def test_explicit(self): @@ -1479,9 +1494,9 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, @unittest.skipUnless(any(GCP_CREDS.values()), 'GCP environment credentials are not set') def setUpClass(cls): - cls.KMS_PROVIDER_MAP = {'gcp': GCP_CREDS} + cls.KMS_PROVIDER_MAP = {'gcp': GCP_CREDS} # type: ignore[assignment] cls.DEK = json_data(BASE, 'custom', 'gcp-dek.json') - cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') + cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') # type: ignore[attr-defined] super(TestGCPEncryption, cls).setUpClass() def test_explicit(self): @@ -1809,7 +1824,7 @@ class TestKmsTLSOptions(EncryptionIntegrationTest): def setUp(self): super(TestKmsTLSOptions, self).setUp() # 1, create client with only tlsCAFile. - providers = copy.deepcopy(ALL_KMS_PROVIDERS) + providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS) providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8002' providers['gcp']['endpoint'] = '127.0.0.1:8002' kms_tls_opts_ca_only = { @@ -1831,7 +1846,7 @@ def setUp(self): kms_tls_options=kms_tls_opts) self.addCleanup(self.client_encryption_with_tls.close) # 3, update endpoints to expired host. - providers = copy.deepcopy(providers) + providers: dict = copy.deepcopy(providers) providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8000' providers['gcp']['endpoint'] = '127.0.0.1:8000' providers['kmip']['endpoint'] = '127.0.0.1:8000' @@ -1840,7 +1855,7 @@ def setUp(self): kms_tls_options=kms_tls_opts_ca_only) self.addCleanup(self.client_encryption_expired.close) # 3, update endpoints to invalid host. - providers = copy.deepcopy(providers) + providers: dict = copy.deepcopy(providers) providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8001' providers['gcp']['endpoint'] = '127.0.0.1:8001' providers['kmip']['endpoint'] = '127.0.0.1:8001' diff --git a/test/test_examples.py b/test/test_examples.py index dcf9dd2de3..29e5267109 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -693,7 +693,7 @@ def insert_docs(): # End Changestream Example 3 # Start Changestream Example 4 - pipeline = [ + pipeline: list = [ {'$match': {'fullDocument.username': 'alice'}}, {'$addFields': {'newField': 'this is an added field!'}} ] @@ -890,6 +890,7 @@ def update_employee_info(session): update_employee_info(session) employee = employees.find_one({"employee": 3}) + assert employee is not None self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Inactive') @@ -916,6 +917,7 @@ def run_transaction_with_retry(txn_func, session): run_transaction_with_retry(update_employee_info, session) employee = employees.find_one({"employee": 3}) + assert employee is not None self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Inactive') @@ -954,12 +956,13 @@ def _insert_employee_retry_commit(session): run_transaction_with_retry(_insert_employee_retry_commit, session) employee = employees.find_one({"employee": 4}) + assert employee is not None self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Active') # Start Transactions Retry Example 3 - def run_transaction_with_retry(txn_func, session): + def run_transaction_with_retry(txn_func, session): # type: ignore[no-redef] while True: try: txn_func(session) # performs transaction @@ -973,7 +976,7 @@ def run_transaction_with_retry(txn_func, session): else: raise - def commit_with_retry(session): + def commit_with_retry(session): # type: ignore[no-redef] while True: try: # Commit uses write concern set at transaction start. @@ -992,7 +995,7 @@ def commit_with_retry(session): # Updates two collections in a transactions - def update_employee_info(session): + def update_employee_info(session): # type: ignore[no-redef] employees_coll = session.client.hr.employees events_coll = session.client.reporting.events @@ -1021,6 +1024,7 @@ def update_employee_info(session): # End Transactions Retry Example 3 employee = employees.find_one({"employee": 3}) + assert employee is not None self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Inactive') @@ -1091,6 +1095,8 @@ def test_causal_consistency(self): # Start Causal Consistency Example 2 with client.start_session(causal_consistency=True) as s2: + assert s1.cluster_time is not None + assert s1.operation_time is not None s2.advance_cluster_time(s1.cluster_time) s2.advance_operation_time(s1.operation_time) diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 6d7cc7ba3b..1c3094c7db 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -24,6 +24,8 @@ from io import BytesIO +from pymongo.database import Database + sys.path[0:0] = [""] from bson.objectid import ObjectId @@ -47,6 +49,7 @@ class TestGridFileNoConnect(unittest.TestCase): """Test GridFile features on a client that does not connect. """ + db: Database @classmethod def setUpClass(cls): @@ -238,7 +241,7 @@ def test_grid_out_cursor_options(self): cursor_dict.pop('_Cursor__session') cursor_clone_dict = cursor_clone.__dict__.copy() cursor_clone_dict.pop('_Cursor__session') - self.assertDictEqual(cursor_dict, cursor_clone_dict) + self.assertEqual(cursor_dict, cursor_clone_dict) self.assertRaises(NotImplementedError, cursor.add_option, 0) self.assertRaises(NotImplementedError, cursor.remove_option, 0) diff --git a/test/test_gridfs.py b/test/test_gridfs.py index d7d5a74e5f..3d8a7d8f6b 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -27,6 +27,7 @@ sys.path[0:0] = [""] from bson.binary import Binary +from pymongo.database import Database from pymongo.mongo_client import MongoClient from pymongo.errors import (ConfigurationError, NotPrimaryError, @@ -78,6 +79,7 @@ def run(self): class TestGridfsNoConnect(unittest.TestCase): + db: Database @classmethod def setUpClass(cls): @@ -89,6 +91,8 @@ def test_gridfs(self): class TestGridfs(IntegrationTest): + fs: gridfs.GridFS + alt: gridfs.GridFS @classmethod def setUpClass(cls): @@ -152,6 +156,7 @@ def test_empty_file(self): self.assertEqual(0, self.db.fs.chunks.count_documents({})) raw = self.db.fs.files.find_one() + assert raw is not None self.assertEqual(0, raw["length"]) self.assertEqual(oid, raw["_id"]) self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) @@ -213,7 +218,7 @@ def test_threaded_reads(self): self.fs.put(b"hello", _id="test") threads = [] - results = [] + results: list = [] for i in range(10): threads.append(JustRead(self.fs, 10, results)) threads[i].start() @@ -396,6 +401,7 @@ def test_missing_length_iter(self): # Test fix that guards against PHP-237 self.fs.put(b"", filename="empty") doc = self.db.fs.files.find_one({"filename": "empty"}) + assert doc is not None doc.pop("length") self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) f = self.fs.get_last_version(filename="empty") @@ -447,23 +453,32 @@ def test_delete_not_initialized(self): # but will still call __del__. cursor = GridOutCursor.__new__(GridOutCursor) # Skip calling __init__ with self.assertRaises(TypeError): - cursor.__init__(self.db.fs.files, {}, {"_id": True}) + cursor.__init__(self.db.fs.files, {}, {"_id": True}) # type: ignore cursor.__del__() # no error def test_gridfs_find_one(self): self.assertEqual(None, self.fs.find_one()) id1 = self.fs.put(b'test1', filename='file1') - self.assertEqual(b'test1', self.fs.find_one().read()) + res = self.fs.find_one() + assert res is not None + self.assertEqual(b'test1', res.read()) id2 = self.fs.put(b'test2', filename='file2', meta='data') - self.assertEqual(b'test1', self.fs.find_one(id1).read()) - self.assertEqual(b'test2', self.fs.find_one(id2).read()) - - self.assertEqual(b'test1', - self.fs.find_one({'filename': 'file1'}).read()) - - self.assertEqual('data', self.fs.find_one(id2).meta) + res1 = self.fs.find_one(id1) + assert res1 is not None + self.assertEqual(b'test1', res1.read()) + res2 = self.fs.find_one(id2) + assert res2 is not None + self.assertEqual(b'test2', res2.read()) + + res3 = self.fs.find_one({'filename': 'file1'}) + assert res3 is not None + self.assertEqual(b'test1', res3.read()) + + res4 = self.fs.find_one(id2) + assert res4 is not None + self.assertEqual('data', res4.meta) def test_grid_in_non_int_chunksize(self): # Lua, and perhaps other buggy GridFS clients, store size as a float. diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 499643f673..53f94991d3 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -77,6 +77,8 @@ def run(self): class TestGridfs(IntegrationTest): + fs: gridfs.GridFSBucket + alt: gridfs.GridFSBucket @classmethod def setUpClass(cls): @@ -123,6 +125,7 @@ def test_empty_file(self): self.assertEqual(0, self.db.fs.chunks.count_documents({})) raw = self.db.fs.files.find_one() + assert raw is not None self.assertEqual(0, raw["length"]) self.assertEqual(oid, raw["_id"]) self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) @@ -208,7 +211,7 @@ def test_threaded_reads(self): self.fs.upload_from_stream("test", b"hello") threads = [] - results = [] + results: list = [] for i in range(10): threads.append(JustRead(self.fs, 10, results)) threads[i].start() @@ -322,6 +325,7 @@ def test_missing_length_iter(self): # Test fix that guards against PHP-237 self.fs.upload_from_stream("empty", b"") doc = self.db.fs.files.find_one({"filename": "empty"}) + assert doc is not None doc.pop("length") self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) fstr = self.fs.open_download_stream_by_name("empty") diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index 86449db370..057a7b4841 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -55,6 +55,9 @@ def camel_to_snake(camel): class TestAllScenarios(IntegrationTest): + fs: gridfs.GridFSBucket + str_to_cmd: dict + @classmethod def setUpClass(cls): super(TestAllScenarios, cls).setUpClass() diff --git a/test/test_json_util.py b/test/test_json_util.py index dbf4f1c26a..16c7d96a2f 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -20,6 +20,8 @@ import sys import uuid +from typing import Any, List, MutableMapping + sys.path[0:0] = [""] from bson import json_util, EPOCH_AWARE, EPOCH_NAIVE, SON @@ -466,7 +468,7 @@ def test_cursor(self): db = self.db db.drop_collection("test") - docs = [ + docs: List[MutableMapping[str, Any]] = [ {'foo': [1, 2]}, {'bar': {'hello': 'world'}}, {'code': Code("function x() { return 1; }")}, diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 1fd82884f1..5c484fe334 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -35,7 +35,7 @@ 'max_staleness') -class TestAllScenarios(create_selection_tests(_TEST_PATH)): +class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore pass diff --git a/test/test_monitor.py b/test/test_monitor.py index 61e2057b52..ed0d4543f8 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -59,7 +59,7 @@ def test_cleanup_executors_on_client_del(self): # Each executor stores a weakref to itself in _EXECUTORS. executor_refs = [ - (r, r()._name) for r in _EXECUTORS.copy() if r() in executors] + (r, r()._name) for r in _EXECUTORS.copy() if r() in executors] # type: ignore del executors del client diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 0d925b04bf..4e513c5c69 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -16,6 +16,7 @@ import datetime import sys import time +from typing import Any import warnings sys.path[0:0] = [""] @@ -43,6 +44,7 @@ class TestCommandMonitoring(IntegrationTest): + listener: EventListener @classmethod @client_context.require_connection @@ -754,7 +756,7 @@ def test_non_bulk_writes(self): # delete_one self.listener.results.clear() - res = coll.delete_one({'x': 3}) + res2 = coll.delete_one({'x': 3}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] @@ -1091,6 +1093,8 @@ def test_sensitive_commands(self): class TestGlobalListener(IntegrationTest): + listener: EventListener + saved_listeners: Any @classmethod @client_context.require_connection @@ -1167,13 +1171,13 @@ def test_server_heartbeat_event_repr(self): "") delta = 0.1 event = monitoring.ServerHeartbeatSucceededEvent( - delta, {'ok': 1}, connection_id) + delta, {'ok': 1}, connection_id) # type: ignore[arg-type] self.assertEqual( repr(event), "") event = monitoring.ServerHeartbeatFailedEvent( - delta, 'ERROR', connection_id) + delta, 'ERROR', connection_id) # type: ignore[arg-type] self.assertEqual( repr(event), "") event = monitoring.ServerDescriptionChangedEvent( - 'PREV', 'NEW', server_address, topology_id) + 'PREV', 'NEW', server_address, topology_id) # type: ignore[arg-type] self.assertEqual( repr(event), "") event = monitoring.TopologyDescriptionChangedEvent( - 'PREV', 'NEW', topology_id) + 'PREV', 'NEW', topology_id) # type: ignore[arg-type] self.assertEqual( repr(event), " Date: Thu, 3 Feb 2022 08:57:23 -0600 Subject: [PATCH 02/20] cleanup typings --- pymongo/socket_checker.py | 3 +-- pymongo/typings.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index 8b7418725b..4544c8925a 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -16,9 +16,8 @@ import errno import select -import socket import sys -from typing import Any, Optional +from typing import Any, Optional, Union # PYTHON-2320: Jython does not fully support poll on SSL sockets, # https://bugs.jython.org/issue2900 diff --git a/pymongo/typings.py b/pymongo/typings.py index e234f59903..dfd4f6cbb5 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -28,4 +28,4 @@ _CollationIn = Union[Mapping[str, Any], "Collation"] _DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"] _Pipeline = List[Mapping[str, Any]] -_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any], "SON") +_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any] ) From a385dffe1d23f0f5e8566e435ec7e5f08f86acd5 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 09:03:50 -0600 Subject: [PATCH 03/20] fix dict type --- test/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/__init__.py b/test/__init__.py index 5ed23b7442..d4748d12af 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -946,7 +946,7 @@ class IntegrationTest(PyMongoTestCase): """Base class for TestCases that need a connection to MongoDB to pass.""" client: MongoClient db: Database - credentials: dict[str, str] + credentials: Dict[str, str] @classmethod @client_context.require_connection From d435dad04cfb316bd96fa191752f052c92651d1e Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 09:31:13 -0600 Subject: [PATCH 04/20] fix son tests --- test/test_son.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_son.py b/test/test_son.py index b4e4b77e63..c3603d6596 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -156,14 +156,14 @@ def test_iteration(self): # test success case test_son = SON([("1", 100), ("2", 200), ("3", 300)]) for ele in test_son: - self.assertEqual(ele * 100, test_son[ele]) + self.assertEqual(int(ele) * 100, test_son[ele]) def test_contains_has(self): """ has_key and __contains__ """ test_son = SON([("1", 100), ("2", 200), ("3", 300)]) - self.assertIn(1, test_son) + self.assertIn("1", test_son) self.assertTrue("2" in test_son, "in failed") self.assertFalse("22" in test_son, "in succeeded when it shouldn't") self.assertTrue(test_son.has_key("2"), "has_key failed") @@ -175,7 +175,7 @@ def test_clears(self): """ test_son: SON = SON([("1", 100), ("2", 200), ("3", 300)]) test_son.clear() - self.assertNotIn(1, test_son) + self.assertNotIn("1", test_son) self.assertEqual(0, len(test_son)) self.assertEqual(0, len(test_son.keys())) self.assertEqual({}, test_son.to_dict()) From 6009bf650ab41022084cda48ba0bc5650fb5ebc4 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 09:40:27 -0600 Subject: [PATCH 05/20] fix python 3.6 syntax --- test/test_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_session.py b/test/test_session.py index f451670e27..4b53ca22c8 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -20,6 +20,7 @@ import time from io import BytesIO +from typing import Set from pymongo.mongo_client import MongoClient @@ -67,7 +68,7 @@ def session_ids(client): class TestSession(IntegrationTest): client2: MongoClient - sensitive_commands: set[str] + sensitive_commands: Set[str] @classmethod @client_context.require_sessions From 699e696700afee588f04ceab7f92c2c60760c12d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 09:57:26 -0600 Subject: [PATCH 06/20] fixups --- pymongo/socket_checker.py | 2 +- pymongo/typings.py | 4 +-- test/mockupdb/test_handshake.py | 62 ++++++++++++++++++++++++++++++++- test/test_client.py | 4 +-- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index 4544c8925a..ff7b32a78b 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -42,7 +42,7 @@ def __init__(self) -> None: else: self._poller = None - def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Union[float, int] = 0) -> bool: + def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[Union[float, int]] = 0) -> bool: """Select for reads or writes with a timeout in seconds (or None). Returns True if the socket is readable/writable, False on timeout. diff --git a/pymongo/typings.py b/pymongo/typings.py index dfd4f6cbb5..ae5aec3213 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -16,10 +16,8 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar, Union) - if TYPE_CHECKING: from bson.raw_bson import RawBSONDocument - from bson.son import SON from pymongo.collation import Collation @@ -28,4 +26,4 @@ _CollationIn = Union[Mapping[str, Any], "Collation"] _DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"] _Pipeline = List[Mapping[str, Any]] -_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any] ) +_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any]) diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 621f01728f..844ba30ce2 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -12,14 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mockupdb import (MockupDB, OpReply, OpMsg, OpMsgReply, OpQuery, absent, + Command, go) -from mockupdb import MockupDB, OpReply, OpMsg, absent, Command, go from pymongo import MongoClient, version as pymongo_version from pymongo.errors import OperationFailure +from pymongo.server_api import ServerApi, ServerApiVersion +from bson.objectid import ObjectId import unittest +def test_hello_with_option(self, protocol, **kwargs): + hello = "ismaster" if isinstance(protocol(), OpQuery) else "hello" + # `db.command("hello"|"ismaster")` commands are the same for primaries and + # secondaries, so we only need one server. + primary = MockupDB() + # Set up a custom handler to save the first request from the driver. + self.handshake_req = None + def respond(r): + # Only save the very first request from the driver. + if self.handshake_req == None: + self.handshake_req = r + load_balanced_kwargs = {"serviceId": ObjectId()} if kwargs.get( + "loadBalanced") else {} + return r.reply(OpMsgReply(minWireVersion=0, maxWireVersion=13, + **kwargs, **load_balanced_kwargs)) + primary.autoresponds(respond) + primary.run() + self.addCleanup(primary.stop) + + # We need a special dict because MongoClient uses "server_api" and all + # of the commands use "apiVersion". + k_map = {("apiVersion", "1"):("server_api", ServerApi( + ServerApiVersion.V1))} + client = MongoClient("mongodb://"+primary.address_string, + appname='my app', # For _check_handshake_data() + **dict([k_map.get((k, v), (k, v)) for k, v # type: ignore[arg-type] + in kwargs.items()])) + + self.addCleanup(client.close) + + # We have an autoresponder luckily, so no need for `go()`. + assert client.db.command(hello) + + # We do this checking here rather than in the autoresponder `respond()` + # because it runs in another Python thread so there are some funky things + # with error handling within that thread, and we want to be able to use + # self.assertRaises(). + self.handshake_req.assert_matches(protocol(hello, **kwargs)) # type: ignore[attr-defined] + _check_handshake_data(self.handshake_req) + + def _check_handshake_data(request): assert 'client' in request data = request['client'] @@ -156,6 +200,22 @@ def test_client_handshake_saslSupportedMechs(self): future() return + def test_handshake_load_balanced(self): + test_hello_with_option(self, OpMsg, loadBalanced=True) + with self.assertRaisesRegex(AssertionError, "does not match"): + test_hello_with_option(self, Command, loadBalanced=True) + + def test_handshake_versioned_api(self): + test_hello_with_option(self, OpMsg, apiVersion="1") + with self.assertRaisesRegex(AssertionError, "does not match"): + test_hello_with_option(self, Command, apiVersion="1") + + def test_handshake_not_either(self): + # If we don't specify either option then it should be using + # OP_QUERY for the initial step of the handshake. + test_hello_with_option(self, Command) + with self.assertRaisesRegex(AssertionError, "does not match"): + test_hello_with_option(self, OpMsg) if __name__ == '__main__': unittest.main() diff --git a/test/test_client.py b/test/test_client.py index 1c89631f5e..9ca9989052 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -28,7 +28,7 @@ import threading import warnings -from typing import no_type_check +from typing import no_type_check, Type sys.path[0:0] = [""] @@ -344,7 +344,7 @@ def transform_python(self, value): return int(value) # Ensure codec options are passed in correctly - document_class = SON + document_class: Type[SON] = SON type_registry = TypeRegistry([MyFloatAsIntEncoder()]) tz_aware = True uuid_representation_label = 'javaLegacy' From 3f651f0127bc098a3cc793b814d7e5a2be156028 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 10:01:09 -0600 Subject: [PATCH 07/20] fixups --- test/performance/perf_test.py | 2 +- test/test_bson.py | 87 +++++++++++++---------------------- 2 files changed, 34 insertions(+), 55 deletions(-) diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index f25ff675e8..7effa1c1ee 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -25,7 +25,7 @@ try: import simplejson as json except ImportError: - import json # type: ignore + import json # type: ignore[no-redef] sys.path[0:0] = [""] diff --git a/test/test_bson.py b/test/test_bson.py index 6e619f27d5..7052042ca8 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -530,7 +530,9 @@ def test_large_datetime_truncation(self): def test_aware_datetime(self): aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) - as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) # type: ignore + offset = aware.utcoffset() + assert offset is not None + as_utc = (aware - offset).replace(tzinfo=utc) self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc), as_utc) after = decode(encode({"date": aware}), CodecOptions(tz_aware=True))[ @@ -591,7 +593,9 @@ def test_local_datetime(self): def test_naive_decode(self): aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) - naive_utc = (aware - aware.utcoffset()).replace(tzinfo=None) # type: ignore + offset = aware.utcoffset() + assert offset is not None + naive_utc = (aware - offset).replace(tzinfo=None) self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45), naive_utc) after = decode(encode({"date": aware}))["date"] self.assertEqual(None, after.tzinfo) @@ -994,57 +998,32 @@ def test_decode_all_defaults(self): def test_unicode_decode_error_handler(self): enc = encode({"keystr": "foobar"}) - # Test handling of bad key value. + # Test handling of bad key value, bad string value, and both. invalid_key = enc[:7] + b'\xe9' + enc[8:] - replaced_key = b'ke\xe9str'.decode('utf-8', 'replace') - ignored_key = b'ke\xe9str'.decode('utf-8', 'ignore') - - dec = decode(invalid_key, - CodecOptions(unicode_decode_error_handler="replace")) - self.assertEqual(dec, {replaced_key: "foobar"}) - - dec = decode(invalid_key, - CodecOptions(unicode_decode_error_handler="ignore")) - self.assertEqual(dec, {ignored_key: "foobar"}) - - self.assertRaises(InvalidBSON, decode, invalid_key, CodecOptions( - unicode_decode_error_handler="strict")) - self.assertRaises(InvalidBSON, decode, invalid_key, CodecOptions()) - self.assertRaises(InvalidBSON, decode, invalid_key) - - # Test handing of bad string value. - invalid_val = BSON(enc[:18] + b'\xe9' + enc[19:]) - replaced_val = b'fo\xe9bar'.decode('utf-8', 'replace') - ignored_val = b'fo\xe9bar'.decode('utf-8', 'ignore') - - dec = decode(invalid_val, - CodecOptions(unicode_decode_error_handler="replace")) - self.assertEqual(dec, {"keystr": replaced_val}) - - dec = decode(invalid_val, - CodecOptions(unicode_decode_error_handler="ignore")) - self.assertEqual(dec, {"keystr": ignored_val}) - - self.assertRaises(InvalidBSON, decode, invalid_val, CodecOptions( - unicode_decode_error_handler="strict")) - self.assertRaises(InvalidBSON, decode, invalid_val, CodecOptions()) - self.assertRaises(InvalidBSON, decode, invalid_val) - - # Test handing bad key + bad value. + invalid_val = enc[:18] + b'\xe9' + enc[19:] invalid_both = enc[:7] + b'\xe9' + enc[8:18] + b'\xe9' + enc[19:] - dec = decode(invalid_both, - CodecOptions(unicode_decode_error_handler="replace")) - self.assertEqual(dec, {replaced_key: replaced_val}) - - dec = decode(invalid_both, - CodecOptions(unicode_decode_error_handler="ignore")) - self.assertEqual(dec, {ignored_key: ignored_val}) - - self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions( - unicode_decode_error_handler="strict")) - self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions()) - self.assertRaises(InvalidBSON, decode, invalid_both) + # Ensure that strict mode raises an error. + for invalid in [invalid_key, invalid_val, invalid_both]: + self.assertRaises(InvalidBSON, decode, invalid, CodecOptions( + unicode_decode_error_handler="strict")) + self.assertRaises(InvalidBSON, decode, invalid, CodecOptions()) + self.assertRaises(InvalidBSON, decode, invalid) + + # Test all other error handlers. + for handler in ['replace', 'backslashreplace', 'surrogateescape', + 'ignore']: + expected_key = b'ke\xe9str'.decode('utf-8', handler) + expected_val = b'fo\xe9bar'.decode('utf-8', handler) + doc = decode(invalid_key, + CodecOptions(unicode_decode_error_handler=handler)) + self.assertEqual(doc, {expected_key: "foobar"}) + doc = decode(invalid_val, + CodecOptions(unicode_decode_error_handler=handler)) + self.assertEqual(doc, {"keystr": expected_val}) + doc = decode(invalid_both, + CodecOptions(unicode_decode_error_handler=handler)) + self.assertEqual(doc, {expected_key: expected_val}) # Test handling bad error mode. dec = decode(enc, @@ -1064,8 +1043,8 @@ def round_trip_pickle(self, obj, pickled_with_older): def test_regex_pickling(self): reg = Regex(".?") - pickled_with_3 = (b'\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n' - b'bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}' + pickled_with_3 = (b'\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n' + b'bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}' b'\x94(\x8c\x07pattern\x94\x8c\x02.?\x94\x8c\x05flag' b's\x94K\x00ub.') self.round_trip_pickle(reg, pickled_with_3) @@ -1108,8 +1087,8 @@ def test_minkey_pickling(self): def test_maxkey_pickling(self): maxk = MaxKey() - pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c' - b'\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)' + pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c' + b'\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)' b'\x81\x94.') self.round_trip_pickle(maxk, pickled_with_3) From 6247b9ad77b426174b784219e47c8fa92b2c3ed1 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 10:12:42 -0600 Subject: [PATCH 08/20] fixups --- test/test_bulk.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/test/test_bulk.py b/test/test_bulk.py index dcbc6d1efa..7e11b84959 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -15,11 +15,14 @@ """Test the bulk API.""" import sys +import uuid from pymongo.mongo_client import MongoClient sys.path[0:0] = [""] +from bson.binary import Binary, UuidRepresentation +from bson.codec_options import CodecOptions from bson.objectid import ObjectId from pymongo.collection import Collection from pymongo.common import partition_node @@ -382,6 +385,78 @@ def test_client_generated_upsert_id(self): {'index': 2, '_id': 2}]}, result.bulk_api_result) + def test_upsert_uuid_standard(self): + options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + coll = self.coll.with_options(codec_options=options) + uuids = [uuid.uuid4() for _ in range(3)] + result = coll.bulk_write([ + UpdateOne({'_id': uuids[0]}, {'$set': {'a': 0}}, upsert=True), + ReplaceOne({'a': 1}, {'_id': uuids[1]}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({'_id': uuids[2]}, {'_id': uuids[2]}, upsert=True), + ]) + self.assertEqualResponse( + {'nMatched': 0, + 'nModified': 0, + 'nUpserted': 3, + 'nInserted': 0, + 'nRemoved': 0, + 'upserted': [{'index': 0, '_id': uuids[0]}, + {'index': 1, '_id': uuids[1]}, + {'index': 2, '_id': uuids[2]}]}, + result.bulk_api_result) + + def test_upsert_uuid_unspecified(self): + options = CodecOptions(uuid_representation=UuidRepresentation.UNSPECIFIED) + coll = self.coll.with_options(codec_options=options) + uuids = [Binary.from_uuid(uuid.uuid4()) for _ in range(3)] + result = coll.bulk_write([ + UpdateOne({'_id': uuids[0]}, {'$set': {'a': 0}}, upsert=True), + ReplaceOne({'a': 1}, {'_id': uuids[1]}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({'_id': uuids[2]}, {'_id': uuids[2]}, upsert=True), + ]) + self.assertEqualResponse( + {'nMatched': 0, + 'nModified': 0, + 'nUpserted': 3, + 'nInserted': 0, + 'nRemoved': 0, + 'upserted': [{'index': 0, '_id': uuids[0]}, + {'index': 1, '_id': uuids[1]}, + {'index': 2, '_id': uuids[2]}]}, + result.bulk_api_result) + + def test_upsert_uuid_standard_subdocuments(self): + options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + coll = self.coll.with_options(codec_options=options) + ids: list = [ + {'f': Binary(bytes(i)), 'f2': uuid.uuid4()} + for i in range(3) + ] + + result = coll.bulk_write([ + UpdateOne({'_id': ids[0]}, {'$set': {'a': 0}}, upsert=True), + ReplaceOne({'a': 1}, {'_id': ids[1]}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({'_id': ids[2]}, {'_id': ids[2]}, upsert=True), + ]) + + # The `Binary` values are returned as `bytes` objects. + for _id in ids: + _id['f'] = bytes(_id['f']) # + + self.assertEqualResponse( + {'nMatched': 0, + 'nModified': 0, + 'nUpserted': 3, + 'nInserted': 0, + 'nRemoved': 0, + 'upserted': [{'index': 0, '_id': ids[0]}, + {'index': 1, '_id': ids[1]}, + {'index': 2, '_id': ids[2]}]}, + result.bulk_api_result) + def test_single_ordered_batch(self): result = self.coll.bulk_write([ InsertOne({'a': 1}), From 44607b7b767b241f4391083c43712e1ca64b2caa Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 10:25:37 -0600 Subject: [PATCH 09/20] fixups --- test/test_code.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_code.py b/test/test_code.py index f392865ab5..5c719665c0 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -17,7 +17,7 @@ """Tests for the Code wrapper.""" import sys -from typing import Any, cast +from typing import Any sys.path[0:0] = [""] from bson.code import Code @@ -36,7 +36,7 @@ def test_read_only(self): c = Code("blah") def set_c(): - cast(Any, c).scope = 5 + c.scope = 5 # type: ignore self.assertRaises(AttributeError, set_c) def test_code(self): @@ -60,7 +60,7 @@ def test_code(self): def test_repr(self): c = Code("hello world", {}) self.assertEqual(repr(c), "Code('hello world', {})") - cast(Any, c).scope["foo"] = "bar" + c.scope["foo"] = "bar" # type: ignore[index] self.assertEqual(repr(c), "Code('hello world', {'foo': 'bar'})") c = Code("hello world", {"blah": 3}) self.assertEqual(repr(c), "Code('hello world', {'blah': 3})") From d1b77b71fa981dbc14265979cc7de3b52e2b70cd Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 10:36:31 -0600 Subject: [PATCH 10/20] fixups --- test/test_code.py | 2 +- test/test_dns.py | 1 - test/unified_format.py | 6 +++--- test/utils_spec_runner.py | 8 ++++---- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/test/test_code.py b/test/test_code.py index 5c719665c0..628a8a6560 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -17,7 +17,7 @@ """Tests for the Code wrapper.""" import sys -from typing import Any + sys.path[0:0] = [""] from bson.code import Code diff --git a/test/test_dns.py b/test/test_dns.py index 218e8bbef1..8404c2aa69 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -18,7 +18,6 @@ import json import os import sys -from pymongo import client_options sys.path[0:0] = [""] diff --git a/test/unified_format.py b/test/unified_format.py index b6607bd526..d24bd361ae 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -27,7 +27,7 @@ import types from collections import abc -from typing import Any, cast +from typing import Any from bson import json_util, Code, Decimal128, DBRef, SON, Int64, MaxKey, MinKey from bson.binary import Binary @@ -524,11 +524,11 @@ def _evaluate_if_special_operation(self, expectation, actual, nested = expectation[key_to_compare] if isinstance(nested, abc.Mapping) and len(nested) == 1: opname, spec = next(iter(nested.items())) - if cast(str, opname).startswith('$$'): + if opname.startswith('$$'): # type: ignore[attr-defined] is_special_op = True elif len(expectation) == 1: opname, spec = next(iter(expectation.items())) - if cast(str, opname).startswith('$$'): + if opname.startswith('$$'): # type: ignore[attr-defined] is_special_op = True key_to_compare = None diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 729212d10a..b9aa64947a 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -18,7 +18,7 @@ import threading from collections import abc -from typing import List, cast +from typing import List from bson import decode, encode from bson.binary import Binary @@ -207,7 +207,7 @@ def check_result(self, expected_result, result): # SPEC-869: Only BulkWriteResult has upserted_count. if (prop == "upserted_count" and not isinstance(result, BulkWriteResult)): - if cast(UpdateResult, result).upserted_id is not None: + if result.upserted_id is not None: # type: ignore[attr-defined] upserted_count = 1 else: upserted_count = 0 @@ -224,14 +224,14 @@ def check_result(self, expected_result, result): if isinstance(ids, dict): ids = [ids[str(i)] for i in range(len(ids))] - self.assertEqual(ids, cast(InsertManyResult, result).inserted_ids, prop) + self.assertEqual(ids, result.inserted_ids, prop) # type: ignore[attr-defined] elif prop == "upserted_ids": # Convert indexes from strings to integers. ids = expected_result[res] expected_ids = {} for str_index in ids: expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, cast(BulkWriteResult, result).upserted_ids, prop) + self.assertEqual(expected_ids, result.upserted_ids, prop) # type: ignore[attr-defined] else: self.assertEqual( getattr(result, prop), expected_result[res], prop) From 3c6ada063a5fc3196d564aa8edc5d66a3638f304 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 10:37:22 -0600 Subject: [PATCH 11/20] fixups --- test/utils_spec_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index b9aa64947a..d15bfaff1d 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -35,7 +35,7 @@ PyMongoError) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference -from pymongo.results import _WriteResult, BulkWriteResult, InsertManyResult, UpdateResult +from pymongo.results import _WriteResult, BulkWriteResult from pymongo.write_concern import WriteConcern from test import (client_context, From 887d804b60d22616102bad946165020732903506 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 16:13:13 -0600 Subject: [PATCH 12/20] test the package in CI --- .github/workflows/test-python.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 3ad5aa79fe..969718fc05 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -46,4 +46,4 @@ jobs: pip install -e ".[zstd, srv]" - name: Run mypy run: | - mypy --install-types --non-interactive bson gridfs tools + mypy --install-types --non-interactive bson gridfs tools test From d63bcc357fdb14ffa13bd9e1948ed19fcd41047f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 16:37:28 -0600 Subject: [PATCH 13/20] fixups --- mypy.ini | 3 +++ test/utils.py | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 5772c1355f..55003be11e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,9 @@ warn_unused_configs = true warn_unused_ignores = true warn_redundant_casts = true +[mypy-dns] +ignore_missing_imports = True + [mypy-gevent.*] ignore_missing_imports = True diff --git a/test/utils.py b/test/utils.py index 243eb73c7c..23b467a44b 100644 --- a/test/utils.py +++ b/test/utils.py @@ -29,6 +29,7 @@ from collections import abc, defaultdict from functools import partial +from typing import Any, cast from bson import json_util from bson.objectid import ObjectId @@ -869,7 +870,12 @@ def frequent_thread_switches(): sys.setswitchinterval(1e-6) else: interval = sys.getcheckinterval() # type: ignore - sys.setcheckinterval(1) # type: ignore + # Cast to any to support deprecated function. + # This function exists in Python 3.7, but + # not in newer versions of Python. + # We can't use "type: ignore" because it results + # in an error when the function is available. + cast(Any, sys).setcheckinterval(1) try: yield From 2b3796ee93c3e1273419460de36d000eae071d8b Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 16:47:43 -0600 Subject: [PATCH 14/20] make pipeline a sequence --- pymongo/typings.py | 4 ++-- test/test_database.py | 4 ++-- test/test_examples.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pymongo/typings.py b/pymongo/typings.py index ae5aec3213..767eed36c5 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -14,7 +14,7 @@ """Type aliases used by PyMongo""" from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional, - Tuple, Type, TypeVar, Union) + Sequence, Tuple, Type, TypeVar, Union) if TYPE_CHECKING: from bson.raw_bson import RawBSONDocument @@ -25,5 +25,5 @@ _Address = Tuple[str, Optional[int]] _CollationIn = Union[Mapping[str, Any], "Collation"] _DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"] -_Pipeline = List[Mapping[str, Any]] +_Pipeline = Sequence[Mapping[str, Any]] _DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any]) diff --git a/test/test_database.py b/test/test_database.py index 675148fc3b..cc4b1053a7 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -589,7 +589,7 @@ def test_command_max_time_ms(self): db.command('count', 'test') self.assertRaises(ExecutionTimeout, db.command, 'count', 'test', maxTimeMS=1) - pipeline: list = [{'$project': {'name': 1, 'count': 1}}] + pipeline = [{'$project': {'name': 1, 'count': 1}}] # Database command helper. db.command('aggregate', 'test', pipeline=pipeline, cursor={}) self.assertRaises(ExecutionTimeout, db.command, @@ -667,7 +667,7 @@ def test_database_aggregation_fake_cursor(self): self.addCleanup(output_coll.drop) admin = self.admin.with_options(write_concern=WriteConcern(w=0)) - pipeline: list = self.pipeline[:] + pipeline = self.pipeline[:] pipeline.append(write_stage) with admin.aggregate(pipeline) as cursor: with self.assertRaises(StopIteration): diff --git a/test/test_examples.py b/test/test_examples.py index 29e5267109..25039fccae 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -693,7 +693,7 @@ def insert_docs(): # End Changestream Example 3 # Start Changestream Example 4 - pipeline: list = [ + pipeline = [ {'$match': {'fullDocument.username': 'alice'}}, {'$addFields': {'newField': 'this is an added field!'}} ] From 937a36640da8ed56619951c9915f93334abd3927 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 16:53:48 -0600 Subject: [PATCH 15/20] fixup --- mypy.ini | 3 --- test/test_srv_polling.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mypy.ini b/mypy.ini index 55003be11e..5772c1355f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,9 +11,6 @@ warn_unused_configs = true warn_unused_ignores = true warn_redundant_casts = true -[mypy-dns] -ignore_missing_imports = True - [mypy-gevent.*] ignore_missing_imports = True diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 64581d83b7..1174b1511b 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -194,7 +194,7 @@ def test_replace_both_with_two(self): def test_dns_failures(self): from dns import exception - for exc in (exception.FormError, exception.TooBig, exception.Timeout): + for exc in (exception.FormError, exception.TooBig, exception.Timeout): # type: ignore[attr-defined] def response_callback(*args): raise exc("DNS Failure!") self.run_scenario(response_callback, False) From 7ca474fb875e8c82df8a7f4c64b8f2a503503a28 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 16:55:31 -0600 Subject: [PATCH 16/20] fixup --- test/test_database.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_database.py b/test/test_database.py index cc4b1053a7..096eb5b979 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -17,7 +17,7 @@ import datetime import re import sys -from typing import Any +from typing import Any, List, Mapping sys.path[0:0] = [""] @@ -637,7 +637,7 @@ def test_with_options(self): class TestDatabaseAggregation(IntegrationTest): def setUp(self): - self.pipeline: list = [{"$listLocalSessions": {}}, + self.pipeline: List[Mapping[str, Any]] = [{"$listLocalSessions": {}}, {"$limit": 1}, {"$addFields": {"dummy": "dummy field"}}, {"$project": {"_id": 0, "dummy": 1}}] From 1f3a2aaeb4bffe5f3f78e693e32e31c617042901 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 17:00:08 -0600 Subject: [PATCH 17/20] remove use of deprecated functions --- test/utils.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/test/utils.py b/test/utils.py index 23b467a44b..34cb1f0846 100644 --- a/test/utils.py +++ b/test/utils.py @@ -29,7 +29,6 @@ from collections import abc, defaultdict from functools import partial -from typing import Any, cast from bson import json_util from bson.objectid import ObjectId @@ -865,26 +864,14 @@ def frequent_thread_switches(): """Make concurrency bugs more likely to manifest.""" interval = None if not sys.platform.startswith('java'): - if hasattr(sys, 'getswitchinterval'): - interval = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - else: - interval = sys.getcheckinterval() # type: ignore - # Cast to any to support deprecated function. - # This function exists in Python 3.7, but - # not in newer versions of Python. - # We can't use "type: ignore" because it results - # in an error when the function is available. - cast(Any, sys).setcheckinterval(1) + interval = sys.getswitchinterval() + sys.setswitchinterval(1e-6) try: yield finally: if not sys.platform.startswith('java'): - if hasattr(sys, 'setswitchinterval'): - sys.setswitchinterval(interval) # type: ignore - else: - sys.setcheckinterval(interval) # type: ignore + sys.setswitchinterval(interval) def lazy_client_trial(reset, target, test, get_client): From 363c94755781ce5d8d9e36631f0b9a09a4d3787f Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 3 Feb 2022 20:39:59 -0600 Subject: [PATCH 18/20] address review --- .github/workflows/test-python.yml | 3 +- bson/son.py | 2 +- mypy.ini | 1 + pymongo/socket_checker.py | 2 +- pymongo/srv_resolver.py | 2 +- test/__init__.py | 2 +- test/mockupdb/test_handshake.py | 2 +- test/mockupdb/test_query_read_pref_sharded.py | 2 +- test/qcheck.py | 2 +- test/test_bulk.py | 4 +- test/test_cursor.py | 9 ---- test/test_custom_types.py | 16 +++---- test/test_data_lake.py | 4 +- test/test_encryption.py | 26 +++++------ test/test_session.py | 10 ++--- test/test_son.py | 44 +++++++++---------- test/test_srv_polling.py | 2 +- test/unified_format.py | 4 +- test/utils.py | 9 ++-- test/utils_spec_runner.py | 6 +-- 20 files changed, 70 insertions(+), 82 deletions(-) diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 969718fc05..27d721e34a 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -46,4 +46,5 @@ jobs: pip install -e ".[zstd, srv]" - name: Run mypy run: | - mypy --install-types --non-interactive bson gridfs tools test + mypy --install-types --non-interactive bson gridfs tools pymongo + mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment test diff --git a/bson/son.py b/bson/son.py index 7207367f3d..bb39644637 100644 --- a/bson/son.py +++ b/bson/son.py @@ -28,7 +28,7 @@ # This is essentially the same as re._pattern_type RE_TYPE: Type[Pattern[Any]] = type(re.compile("")) -_Key = TypeVar("_Key", bound=str) +_Key = TypeVar("_Key") _Value = TypeVar("_Value") _T = TypeVar("_T") diff --git a/mypy.ini b/mypy.ini index 5772c1355f..91b1121cd5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -34,6 +34,7 @@ ignore_missing_imports = True [mypy-test.*] allow_redefinition = true +allow_untyped_globals = true [mypy-winkerberos.*] ignore_missing_imports = True diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index ff7b32a78b..42db7b9373 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -42,7 +42,7 @@ def __init__(self) -> None: else: self._poller = None - def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[Union[float, int]] = 0) -> bool: + def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0) -> bool: """Select for reads or writes with a timeout in seconds (or None). Returns True if the socket is readable/writable, False on timeout. diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index d9ee7b7c8a..989e79131c 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -39,7 +39,7 @@ def maybe_decode(text): def _resolve(*args, **kwargs): if hasattr(resolver, 'resolve'): # dnspython >= 2 - return resolver.resolve(*args, **kwargs) # type: ignore + return resolver.resolve(*args, **kwargs) # dnspython 1.X return resolver.query(*args, **kwargs) diff --git a/test/__init__.py b/test/__init__.py index d4748d12af..da7e761a01 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -488,7 +488,7 @@ def has_secondaries(self): @property def storage_engine(self): try: - return self.server_status.get("storageEngine", {}).get("name") # type: ignore[union-attr] + return self.server_status.get("storageEngine", {}).get("name") except AttributeError: # Raised if self.server_status is None. return None diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 844ba30ce2..f1ba1c89d6 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -60,7 +60,7 @@ def respond(r): # because it runs in another Python thread so there are some funky things # with error handling within that thread, and we want to be able to use # self.assertRaises(). - self.handshake_req.assert_matches(protocol(hello, **kwargs)) # type: ignore[attr-defined] + self.handshake_req.assert_matches(protocol(hello, **kwargs)) _check_handshake_data(self.handshake_req) diff --git a/test/mockupdb/test_query_read_pref_sharded.py b/test/mockupdb/test_query_read_pref_sharded.py index 20bc1e9155..88dcdd8351 100644 --- a/test/mockupdb/test_query_read_pref_sharded.py +++ b/test/mockupdb/test_query_read_pref_sharded.py @@ -50,7 +50,7 @@ def test_query_and_read_mode_sharded_op_msg(self): for pref in read_prefs: collection = client.db.get_collection('test', read_preference=pref) - cursor = collection.find(query.copy()) # type: ignore[attr-defined] + cursor = collection.find(query.copy()) with going(next, cursor): request = server.receives() # Command is not nested in $query. diff --git a/test/qcheck.py b/test/qcheck.py index 0066ac6e57..57e0940b72 100644 --- a/test/qcheck.py +++ b/test/qcheck.py @@ -189,7 +189,7 @@ def simplify(case): # TODO this is a hack simplified[key] = value return (success, success and simplified or case) if isinstance(case, list): - simplified = list(case) # type: ignore[assignment] + simplified = list(case) if random.choice([True, False]): # delete if not len(simplified): diff --git a/test/test_bulk.py b/test/test_bulk.py index 7e11b84959..a895dfddc3 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -444,7 +444,7 @@ def test_upsert_uuid_standard_subdocuments(self): # The `Binary` values are returned as `bytes` objects. for _id in ids: - _id['f'] = bytes(_id['f']) # + _id['f'] = bytes(_id['f']) self.assertEqualResponse( {'nMatched': 0, @@ -844,7 +844,7 @@ class TestBulkWriteConcern(BulkTestBase): def setUpClass(cls): super(TestBulkWriteConcern, cls).setUpClass() cls.w = client_context.w - cls.secondary = None # type: ignore[assignment] + cls.secondary = None if cls.w is not None and cls.w > 1: for member in client_context.hello['hosts']: if member != client_context.hello['primary']: diff --git a/test/test_cursor.py b/test/test_cursor.py index fc804035a8..f741b8b0cc 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -22,8 +22,6 @@ import time import threading -from typing import no_type_check - sys.path[0:0] = [""] from bson import decode_all @@ -61,7 +59,6 @@ def test_deepcopy_cursor_littered_with_regexes(self): cursor2 = copy.deepcopy(cursor) self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec) # type: ignore - @no_type_check def test_add_remove_option(self): cursor = self.db.test.find() self.assertEqual(0, cursor._Cursor__query_flags) @@ -128,7 +125,6 @@ def test_add_remove_option(self): cursor.remove_option(128) self.assertEqual(0, cursor._Cursor__query_flags) - @no_type_check def test_add_remove_option_exhaust(self): # Exhaust - which mongos doesn't support if client_context.is_mongos: @@ -203,7 +199,6 @@ def test_max_time_ms(self): "maxTimeAlwaysTimeOut", mode="off") - @no_type_check def test_max_await_time_ms(self): db = self.db db.pymongo_test.drop() @@ -533,7 +528,6 @@ def test_min_max_without_hint(self): with self.assertRaises(InvalidOperation): list(coll.find().max([("j", 3)])) - @no_type_check def test_batch_size(self): db = self.db db.test.drop() @@ -587,7 +581,6 @@ def cursor_count(cursor, expected_count): next(cur) self.assertEqual(0, len(cur._Cursor__data)) - @no_type_check def test_limit_and_batch_size(self): db = self.db db.test.drop() @@ -836,7 +829,6 @@ def test_rewind(self): # oplog_reply, and snapshot are all deprecated. @ignore_deprecations - @no_type_check def test_clone(self): self.db.test.insert_many([{"x": i} for i in range(1, 4)]) @@ -1010,7 +1002,6 @@ def test_getitem_slice_index(self): self.assertRaises(IndexError, lambda: self.db.test.find()[10:8]) - @no_type_check def test_getitem_numeric_index(self): self.db.drop_collection("test") self.db.test.insert_many([{"i": i} for i in range(100)]) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index dc0a7f55e9..eee47b9d2b 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -90,7 +90,7 @@ def __eq__(self, other): class UndecipherableIntDecoder(TypeDecoder): - bson_type = Int64 # type: ignore[assignment] + bson_type = Int64 def transform_bson(self, value): return UndecipherableInt64Type(value) @@ -110,7 +110,7 @@ def transform_python(self, value): class UppercaseTextDecoder(TypeDecoder): - bson_type = str # type: ignore[assignment] + bson_type = str def transform_bson(self, value): return value.upper() @@ -202,7 +202,7 @@ class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): - cls.codecopts = DECIMAL_CODECOPTS # type: ignore[attr-defined] + cls.codecopts = DECIMAL_CODECOPTS class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, @@ -211,7 +211,7 @@ class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, def setUpClass(cls): codec_options = CodecOptions( type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder()))) - cls.codecopts = codec_options # type: ignore[attr-defined] + cls.codecopts = codec_options class TestBSONFallbackEncoder(unittest.TestCase): @@ -471,20 +471,20 @@ def assert_proper_initialization(type_registry, codec_instances): codec_instances_copy = list(codec_instances) codec_instances.pop(0) self.assertListEqual( - type_registry._TypeRegistry__type_codecs, codec_instances_copy) # type: ignore[attr-defined] + type_registry._TypeRegistry__type_codecs, codec_instances_copy) def test_simple_separate_codecs(self): class MyIntEncoder(TypeEncoder): - python_type = self.types[0] # type: ignore[assignment] + python_type = self.types[0] def transform_python(self, value): return value.x class MyIntDecoder(TypeDecoder): - bson_type = int # type: ignore[assignment] + bson_type = int def transform_bson(self, value): - return self.types[0](value) # type: ignore[attr-defined] + return self.types[0](value) codec_instances: list = [MyIntDecoder(), MyIntEncoder()] type_registry = TypeRegistry(codec_instances) diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 0e52950250..2954efe651 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -94,8 +94,8 @@ def test_3(self): class DataLakeTestSpec(TestCrudV2): # Default test database and collection names. - TEST_DB = 'test' # type: ignore - TEST_COLLECTION = 'driverdata' # type: ignore + TEST_DB = 'test' + TEST_COLLECTION = 'driverdata' @classmethod @client_context.require_data_lake diff --git a/test/test_encryption.py b/test/test_encryption.py index 7ed9d92663..855cb66c28 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -886,7 +886,7 @@ def _test_external_key_vault(self, with_external_key_vault): client_encrypted.db.coll.insert_one({"encrypted": "test"}) # AuthenticationFailed error. self.assertIsInstance(ctx.exception.cause, OperationFailure) - self.assertEqual(ctx.exception.cause.code, 18) # type: ignore[attr-defined] + self.assertEqual(ctx.exception.cause.code, 18) else: client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -899,7 +899,7 @@ def _test_external_key_vault(self, with_external_key_vault): key_id=LOCAL_KEY_ID) # AuthenticationFailed error. self.assertIsInstance(ctx.exception.cause, OperationFailure) - self.assertEqual(ctx.exception.cause.code, 18) # type: ignore[attr-defined] + self.assertEqual(ctx.exception.cause.code, 18) else: client_encryption.encrypt( "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, @@ -1416,15 +1416,15 @@ def _test_explicit(self, expectation): '.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, OPTS) - self.addCleanup(client_encryption.close) # type: ignore[attr-defined] + self.addCleanup(client_encryption.close) ciphertext = client_encryption.encrypt( 'string0', algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=Binary.from_uuid(self.DEK['_id'], STANDARD)) # type: ignore[index] - self.assertEqual(bytes(ciphertext), base64.b64decode(expectation)) # type: ignore[attr-defined] - self.assertEqual(client_encryption.decrypt(ciphertext), 'string0') # type: ignore[attr-defined] + self.assertEqual(bytes(ciphertext), base64.b64decode(expectation)) + self.assertEqual(client_encryption.decrypt(ciphertext), 'string0') def _test_automatic(self, expectation_extjson, payload): encrypted_db = "db" @@ -1434,13 +1434,13 @@ def _test_automatic(self, expectation_extjson, payload): encryption_opts = AutoEncryptionOpts( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] keyvault_namespace, - schema_map=self.SCHEMA_MAP) # type: ignore[attr-defined] + schema_map=self.SCHEMA_MAP) insert_listener = AllowListEventListener('insert') client = rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener]) - self.addCleanup(client.close) # type: ignore[attr-defined] + self.addCleanup(client.close) coll = client.get_database(encrypted_db).get_collection( encrypted_coll, codec_options=OPTS, @@ -1455,11 +1455,11 @@ def _test_automatic(self, expectation_extjson, payload): inserted_doc = event.command['documents'][0] for key, value in expected_document.items(): - self.assertEqual(value, inserted_doc[key]) # type: ignore[attr-defined] + self.assertEqual(value, inserted_doc[key]) output_doc = coll.find_one({}) for key, value in payload.items(): - self.assertEqual(output_doc[key], value) # type: ignore[attr-defined] + self.assertEqual(output_doc[key], value) class TestAzureEncryption(AzureGCPEncryptionTestMixin, @@ -1468,9 +1468,9 @@ class TestAzureEncryption(AzureGCPEncryptionTestMixin, @unittest.skipUnless(any(AZURE_CREDS.values()), 'Azure environment credentials are not set') def setUpClass(cls): - cls.KMS_PROVIDER_MAP = {'azure': AZURE_CREDS} # type: ignore[assignment] + cls.KMS_PROVIDER_MAP = {'azure': AZURE_CREDS} cls.DEK = json_data(BASE, 'custom', 'azure-dek.json') - cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') # type: ignore[attr-defined] + cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') super(TestAzureEncryption, cls).setUpClass() def test_explicit(self): @@ -1494,9 +1494,9 @@ class TestGCPEncryption(AzureGCPEncryptionTestMixin, @unittest.skipUnless(any(GCP_CREDS.values()), 'GCP environment credentials are not set') def setUpClass(cls): - cls.KMS_PROVIDER_MAP = {'gcp': GCP_CREDS} # type: ignore[assignment] + cls.KMS_PROVIDER_MAP = {'gcp': GCP_CREDS} cls.DEK = json_data(BASE, 'custom', 'gcp-dek.json') - cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') # type: ignore[attr-defined] + cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') super(TestGCPEncryption, cls).setUpClass() def test_explicit(self): diff --git a/test/test_session.py b/test/test_session.py index 4b53ca22c8..98eccbae36 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -318,12 +318,12 @@ def test_cursor_clone(self): next(cursor) # Session is "owned" by cursor. self.assertIsNone(cursor.session) - self.assertIsNotNone(cursor._Cursor__session) # type: ignore[attr-defined] + self.assertIsNotNone(cursor._Cursor__session) clone = cursor.clone() next(clone) self.assertIsNone(clone.session) - self.assertIsNotNone(clone._Cursor__session) # type: ignore[attr-defined] - self.assertFalse(cursor._Cursor__session is clone._Cursor__session) # type: ignore[attr-defined] + self.assertIsNotNone(clone._Cursor__session) + self.assertFalse(cursor._Cursor__session is clone._Cursor__session) cursor.close() clone.close() @@ -484,12 +484,12 @@ def test_gridfsbucket_cursor(self): cursor = bucket.find(batch_size=1) files = [cursor.next()] - s = cursor._Cursor__session # type: ignore[attr-defined] + s = cursor._Cursor__session self.assertFalse(s.has_ended) cursor.__del__() self.assertTrue(s.has_ended) - self.assertIsNone(cursor._Cursor__session) # type: ignore[attr-defined] + self.assertIsNone(cursor._Cursor__session) # Files are still valid, they use their own sessions. for f in files: diff --git a/test/test_son.py b/test/test_son.py index c3603d6596..edddd6b8b8 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -19,8 +19,6 @@ import re import sys -from typing import Any - sys.path[0:0] = [""] from bson.son import SON @@ -29,7 +27,7 @@ class TestSON(unittest.TestCase): def test_ordered_dict(self): - a1: SON = SON() + a1 = SON() a1["hello"] = "world" a1["mike"] = "awesome" a1["hello_"] = "mike" @@ -68,7 +66,7 @@ def test_equality(self): ('hello', 'world')))) # Embedded SON. - d4: SON = SON([('blah', {'foo': SON()})]) + d4 = SON([('blah', {'foo': SON()})]) self.assertEqual(d4, {'blah': {'foo': {}}}) self.assertEqual(d4, {'blah': {'foo': SON()}}) self.assertNotEqual(d4, {'blah': {'foo': []}}) @@ -77,10 +75,10 @@ def test_equality(self): self.assertEqual(SON, d4['blah']['foo'].__class__) def test_to_dict(self): - a1: SON = SON() - b2: SON = SON([("blah", SON())]) - c3: SON = SON([("blah", [SON()])]) - d4: SON = SON([("blah", {"foo": SON()})]) + a1 = SON() + b2 = SON([("blah", SON())]) + c3 = SON([("blah", [SON()])]) + d4 = SON([("blah", {"foo": SON()})]) self.assertEqual({}, a1.to_dict()) self.assertEqual({"blah": {}}, b2.to_dict()) self.assertEqual({"blah": [{}]}, c3.to_dict()) @@ -95,7 +93,7 @@ def test_to_dict(self): def test_pickle(self): - simple_son: SON = SON([]) + simple_son = SON([]) complex_son = SON([('son', simple_son), ('list', [simple_son, simple_son])]) @@ -116,7 +114,7 @@ def test_pickle_backwards_compatability(self): self.assertEqual(son_2_1_1, SON([])) def test_copying(self): - simple_son: SON = SON([]) + simple_son = SON([]) complex_son = SON([('son', simple_son), ('list', [simple_son, simple_son])]) regex_son = SON([("x", re.compile("^hello.*"))]) @@ -154,28 +152,28 @@ def test_iteration(self): Test __iter__ """ # test success case - test_son = SON([("1", 100), ("2", 200), ("3", 300)]) + test_son = SON([(1, 100), (2, 200), (3, 300)]) for ele in test_son: - self.assertEqual(int(ele) * 100, test_son[ele]) + self.assertEqual(ele * 100, test_son[ele]) def test_contains_has(self): """ has_key and __contains__ """ - test_son = SON([("1", 100), ("2", 200), ("3", 300)]) - self.assertIn("1", test_son) - self.assertTrue("2" in test_son, "in failed") - self.assertFalse("22" in test_son, "in succeeded when it shouldn't") - self.assertTrue(test_son.has_key("2"), "has_key failed") - self.assertFalse(test_son.has_key("22"), "has_key succeeded when it shouldn't") + test_son = SON([(1, 100), (2, 200), (3, 300)]) + self.assertIn(1, test_son) + self.assertTrue(2 in test_son, "in failed") + self.assertFalse(22 in test_son, "in succeeded when it shouldn't") + self.assertTrue(test_son.has_key(2), "has_key failed") + self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't") def test_clears(self): """ Test clear() """ - test_son: SON = SON([("1", 100), ("2", 200), ("3", 300)]) + test_son = SON([(1, 100), (2, 200), (3, 300)]) test_son.clear() - self.assertNotIn("1", test_son) + self.assertNotIn(1, test_son) self.assertEqual(0, len(test_son)) self.assertEqual(0, len(test_son.keys())) self.assertEqual({}, test_son.to_dict()) @@ -184,16 +182,16 @@ def test_len(self): """ Test len """ - test_son: SON = SON() + test_son = SON() self.assertEqual(0, len(test_son)) - test_son = SON([("1", 100), ("2", 200), ("3", 300)]) + test_son = SON([(1, 100), (2, 200), (3, 300)]) self.assertEqual(3, len(test_son)) test_son.popitem() self.assertEqual(2, len(test_son)) def test_keys(self): # Test to make sure that set operations do not throw an error - d: Any = SON().keys() + d = SON().keys() for i in [OrderedDict, dict]: try: d - i().keys() diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 1174b1511b..64581d83b7 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -194,7 +194,7 @@ def test_replace_both_with_two(self): def test_dns_failures(self): from dns import exception - for exc in (exception.FormError, exception.TooBig, exception.Timeout): # type: ignore[attr-defined] + for exc in (exception.FormError, exception.TooBig, exception.Timeout): def response_callback(*args): raise exc("DNS Failure!") self.run_scenario(response_callback, False) diff --git a/test/unified_format.py b/test/unified_format.py index d24bd361ae..9c38c47863 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -524,11 +524,11 @@ def _evaluate_if_special_operation(self, expectation, actual, nested = expectation[key_to_compare] if isinstance(nested, abc.Mapping) and len(nested) == 1: opname, spec = next(iter(nested.items())) - if opname.startswith('$$'): # type: ignore[attr-defined] + if opname.startswith('$$'): is_special_op = True elif len(expectation) == 1: opname, spec = next(iter(expectation.items())) - if opname.startswith('$$'): # type: ignore[attr-defined] + if opname.startswith('$$'): is_special_op = True key_to_compare = None diff --git a/test/utils.py b/test/utils.py index 34cb1f0846..b0b0c87c47 100644 --- a/test/utils.py +++ b/test/utils.py @@ -862,16 +862,13 @@ def run_threads(collection, target): @contextlib.contextmanager def frequent_thread_switches(): """Make concurrency bugs more likely to manifest.""" - interval = None - if not sys.platform.startswith('java'): - interval = sys.getswitchinterval() - sys.setswitchinterval(1e-6) + interval = sys.getswitchinterval() + sys.setswitchinterval(1e-6) try: yield finally: - if not sys.platform.startswith('java'): - sys.setswitchinterval(interval) + sys.setswitchinterval(interval) def lazy_client_trial(reset, target, test, get_client): diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index d15bfaff1d..8a53a365db 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -207,7 +207,7 @@ def check_result(self, expected_result, result): # SPEC-869: Only BulkWriteResult has upserted_count. if (prop == "upserted_count" and not isinstance(result, BulkWriteResult)): - if result.upserted_id is not None: # type: ignore[attr-defined] + if result.upserted_id is not None: upserted_count = 1 else: upserted_count = 0 @@ -224,14 +224,14 @@ def check_result(self, expected_result, result): if isinstance(ids, dict): ids = [ids[str(i)] for i in range(len(ids))] - self.assertEqual(ids, result.inserted_ids, prop) # type: ignore[attr-defined] + self.assertEqual(ids, result.inserted_ids, prop) elif prop == "upserted_ids": # Convert indexes from strings to integers. ids = expected_result[res] expected_ids = {} for str_index in ids: expected_ids[int(str_index)] = ids[str_index] - self.assertEqual(expected_ids, result.upserted_ids, prop) # type: ignore[attr-defined] + self.assertEqual(expected_ids, result.upserted_ids, prop) else: self.assertEqual( getattr(result, prop), expected_result[res], prop) From d12570ed76855554fb131cc3d38c1294fe91dac7 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 7 Feb 2022 16:49:22 -0600 Subject: [PATCH 19/20] address review --- .github/workflows/test-python.yml | 2 +- test/__init__.py | 2 +- test/test_code.py | 2 +- test/test_collection.py | 10 +++++----- test/test_encryption.py | 2 +- test/test_examples.py | 12 +++++++----- test/test_grid_file.py | 2 +- 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 27d721e34a..ca1845e2cd 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -47,4 +47,4 @@ jobs: - name: Run mypy run: | mypy --install-types --non-interactive bson gridfs tools pymongo - mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment test + mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test diff --git a/test/__init__.py b/test/__init__.py index da7e761a01..c02eb97949 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -654,7 +654,7 @@ def supports_secondary_read_pref(self): if self.has_secondaries: return True if self.is_mongos: - shard = self.client.config.shards.find_one()['host'] # type: ignore[index] + shard = self.client.config.shards.find_one()['host'] num_members = shard.count(',') + 1 return num_members > 1 return False diff --git a/test/test_code.py b/test/test_code.py index 628a8a6560..1c4b5be1fe 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -60,7 +60,7 @@ def test_code(self): def test_repr(self): c = Code("hello world", {}) self.assertEqual(repr(c), "Code('hello world', {})") - c.scope["foo"] = "bar" # type: ignore[index] + c.scope["foo"] = "bar" self.assertEqual(repr(c), "Code('hello world', {'foo': 'bar'})") c = Code("hello world", {"blah": 3}) self.assertEqual(repr(c), "Code('hello world', {'blah': 3})") diff --git a/test/test_collection.py b/test/test_collection.py index 98e1ca6bb3..3d4a107aa9 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1734,8 +1734,8 @@ def test_find_one_with_find_args(self): db.test.insert_many([{"x": i} for i in range(1, 4)]) - self.assertEqual(1, db.test.find_one()["x"]) # type: ignore[index] - self.assertEqual(2, db.test.find_one(skip=1, limit=2)["x"]) # type: ignore[index] + self.assertEqual(1, db.test.find_one()["x"]) + self.assertEqual(2, db.test.find_one(skip=1, limit=2)["x"]) def test_find_with_sort(self): db = self.db @@ -1743,9 +1743,9 @@ def test_find_with_sort(self): db.test.insert_many([{"x": 2}, {"x": 1}, {"x": 3}]) - self.assertEqual(2, db.test.find_one()["x"]) # type: ignore[index] - self.assertEqual(1, db.test.find_one(sort=[("x", 1)])["x"]) # type: ignore[index] - self.assertEqual(3, db.test.find_one(sort=[("x", -1)])["x"]) # type: ignore[index] + self.assertEqual(2, db.test.find_one()["x"]) + self.assertEqual(1, db.test.find_one(sort=[("x", 1)])["x"]) + self.assertEqual(3, db.test.find_one(sort=[("x", -1)])["x"]) def to_list(things): return [thing["x"] for thing in things] diff --git a/test/test_encryption.py b/test/test_encryption.py index 855cb66c28..4a279b8b6f 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -1421,7 +1421,7 @@ def _test_explicit(self, expectation): ciphertext = client_encryption.encrypt( 'string0', algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=Binary.from_uuid(self.DEK['_id'], STANDARD)) # type: ignore[index] + key_id=Binary.from_uuid(self.DEK['_id'], STANDARD)) self.assertEqual(bytes(ciphertext), base64.b64decode(expectation)) self.assertEqual(client_encryption.decrypt(ciphertext), 'string0') diff --git a/test/test_examples.py b/test/test_examples.py index 25039fccae..7a0bde00ca 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -962,7 +962,7 @@ def _insert_employee_retry_commit(session): # Start Transactions Retry Example 3 - def run_transaction_with_retry(txn_func, session): # type: ignore[no-redef] + def run_transaction_with_retry(txn_func, session): while True: try: txn_func(session) # performs transaction @@ -976,7 +976,7 @@ def run_transaction_with_retry(txn_func, session): # type: ignore[no-redef] else: raise - def commit_with_retry(session): # type: ignore[no-redef] + def commit_with_retry(session): while True: try: # Commit uses write concern set at transaction start. @@ -995,7 +995,7 @@ def commit_with_retry(session): # type: ignore[no-redef] # Updates two collections in a transactions - def update_employee_info(session): # type: ignore[no-redef] + def update_employee_info(session): employees_coll = session.client.hr.employees events_coll = session.client.reporting.events @@ -1093,10 +1093,12 @@ def test_causal_consistency(self): 'start': current_date}, session=s1) # End Causal Consistency Example 1 + assert s1.cluster_time is not None + assert s1.operation_time is not None + # Start Causal Consistency Example 2 with client.start_session(causal_consistency=True) as s2: - assert s1.cluster_time is not None - assert s1.operation_time is not None + s2.advance_cluster_time(s1.cluster_time) s2.advance_operation_time(s1.operation_time) diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 1c3094c7db..2208e97b42 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -241,7 +241,7 @@ def test_grid_out_cursor_options(self): cursor_dict.pop('_Cursor__session') cursor_clone_dict = cursor_clone.__dict__.copy() cursor_clone_dict.pop('_Cursor__session') - self.assertEqual(cursor_dict, cursor_clone_dict) + self.assertDictEqual(cursor_dict, cursor_clone_dict) self.assertRaises(NotImplementedError, cursor.add_option, 0) self.assertRaises(NotImplementedError, cursor.remove_option, 0) From 1920c4786204e50ba3556fcab56b77142c5a9e5c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 7 Feb 2022 17:21:01 -0600 Subject: [PATCH 20/20] address review --- test/test_examples.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_examples.py b/test/test_examples.py index 7a0bde00ca..ed12c8bcc1 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1098,7 +1098,6 @@ def test_causal_consistency(self): # Start Causal Consistency Example 2 with client.start_session(causal_consistency=True) as s2: - s2.advance_cluster_time(s1.cluster_time) s2.advance_operation_time(s1.operation_time)