diff --git a/bson/__init__.py b/bson/__init__.py index d95c511fc7..d0a8daa273 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -237,8 +237,8 @@ def get_data_and_view(data: Any) -> Tuple[Any, memoryview]: def _raise_unknown_type(element_type: int, element_name: str) -> NoReturn: """Unknown type helper.""" raise InvalidBSON( - "Detected unknown BSON type %r for fieldname '%s'. Are " - "you using the latest driver version?" % (chr(element_type).encode(), element_name) + "Detected unknown BSON type {!r} for fieldname '{}'. Are " + "you using the latest driver version?".format(chr(element_type).encode(), element_name) ) @@ -626,8 +626,7 @@ def gen_list_name() -> Generator[bytes, None, None]: The first 1000 keys are returned from a pre-built cache. All subsequent keys are generated on the fly. """ - for name in _LIST_NAMES: - yield name + yield from _LIST_NAMES counter = itertools.count(1000) while True: @@ -942,18 +941,18 @@ def _name_value_to_bson( name, fallback_encoder(value), check_keys, opts, in_fallback_call=True ) - raise InvalidDocument("cannot encode object: %r, of type: %r" % (value, type(value))) + raise InvalidDocument(f"cannot encode object: {value!r}, of type: {type(value)!r}") def _element_to_bson(key: Any, value: Any, check_keys: bool, opts: CodecOptions) -> bytes: """Encode a single key, value pair.""" if not isinstance(key, str): - raise InvalidDocument("documents must have only string keys, key was %r" % (key,)) + raise InvalidDocument(f"documents must have only string keys, key was {key!r}") if check_keys: if key.startswith("$"): - raise InvalidDocument("key %r must not start with '$'" % (key,)) + raise InvalidDocument(f"key {key!r} must not start with '$'") if "." in key: - raise InvalidDocument("key %r must not contain '.'" % (key,)) + raise InvalidDocument(f"key {key!r} must not contain '.'") name = _make_name(key) return _name_value_to_bson(name, value, check_keys, opts) @@ -971,7 +970,7 @@ def _dict_to_bson(doc: Any, check_keys: bool, opts: CodecOptions, top_level: boo if not top_level or key != "_id": elements.append(_element_to_bson(key, value, check_keys, opts)) except AttributeError: - raise TypeError("encoder expected a mapping type but got: %r" % (doc,)) + raise TypeError(f"encoder expected a mapping type but got: {doc!r}") encoded = b"".join(elements) return _PACK_INT(len(encoded) + 5) + encoded + b"\x00" diff --git a/bson/_helpers.py b/bson/_helpers.py index ee3b0f1099..5643d77c24 100644 --- a/bson/_helpers.py +++ b/bson/_helpers.py @@ -13,7 +13,7 @@ # limitations under the License. """Setstate and getstate functions for objects with __slots__, allowing - compatibility with default pickling protocol +compatibility with default pickling protocol """ from typing import Any, Mapping @@ -33,7 +33,7 @@ def _mangle_name(name: str, prefix: str) -> str: def _getstate_slots(self: Any) -> Mapping[Any, Any]: prefix = self.__class__.__name__ - ret = dict() + ret = {} for name in self.__slots__: mangled_name = _mangle_name(name, prefix) if hasattr(self, mangled_name): diff --git a/bson/binary.py b/bson/binary.py index a270eae8d2..77e3a3d478 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -306,7 +306,7 @@ def as_uuid(self, uuid_representation: int = UuidRepresentation.STANDARD) -> UUI .. versionadded:: 3.11 """ if self.subtype not in ALL_UUID_SUBTYPES: - raise ValueError("cannot decode subtype %s as a uuid" % (self.subtype,)) + raise ValueError(f"cannot decode subtype {self.subtype} as a uuid") if uuid_representation not in ALL_UUID_REPRESENTATIONS: raise ValueError( @@ -330,8 +330,7 @@ def as_uuid(self, uuid_representation: int = UuidRepresentation.STANDARD) -> UUI return UUID(bytes=self) raise ValueError( - "cannot decode subtype %s to %s" - % (self.subtype, UUID_REPRESENTATION_NAMES[uuid_representation]) + f"cannot decode subtype {self.subtype} to {UUID_REPRESENTATION_NAMES[uuid_representation]}" ) @property @@ -341,7 +340,7 @@ def subtype(self) -> int: def __getnewargs__(self) -> Tuple[bytes, int]: # type: ignore[override] # Work around http://bugs.python.org/issue7382 - data = super(Binary, self).__getnewargs__()[0] + data = super().__getnewargs__()[0] if not isinstance(data, bytes): data = data.encode("latin-1") return data, self.__subtype @@ -355,10 +354,10 @@ def __eq__(self, other: Any) -> bool: return False def __hash__(self) -> int: - return super(Binary, self).__hash__() ^ hash(self.__subtype) + return super().__hash__() ^ hash(self.__subtype) def __ne__(self, other: Any) -> bool: return not self == other def __repr__(self): - return "Binary(%s, %s)" % (bytes.__repr__(self), self.__subtype) + return f"Binary({bytes.__repr__(self)}, {self.__subtype})" diff --git a/bson/code.py b/bson/code.py index b732e82469..27ec588fae 100644 --- a/bson/code.py +++ b/bson/code.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tools for representing JavaScript code in BSON. -""" +"""Tools for representing JavaScript code in BSON.""" from collections.abc import Mapping as _Mapping from typing import Any, Mapping, Optional, Type, Union @@ -54,7 +53,7 @@ def __new__( cls: Type["Code"], code: Union[str, "Code"], scope: Optional[Mapping[str, Any]] = None, - **kwargs: Any + **kwargs: Any, ) -> "Code": if not isinstance(code, str): raise TypeError("code must be an instance of str") @@ -88,7 +87,7 @@ def scope(self) -> Optional[Mapping[str, Any]]: return self.__scope def __repr__(self): - return "Code(%s, %r)" % (str.__repr__(self), self.__scope) + return f"Code({str.__repr__(self)}, {self.__scope!r})" def __eq__(self, other: Any) -> bool: if isinstance(other, Code): diff --git a/bson/codec_options.py b/bson/codec_options.py index 096be85264..a0bdd0eeb9 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -63,12 +63,10 @@ class TypeEncoder(abc.ABC): @abc.abstractproperty def python_type(self) -> Any: """The Python type to be converted into something serializable.""" - pass @abc.abstractmethod def transform_python(self, value: Any) -> Any: """Convert the given Python object into something serializable.""" - pass class TypeDecoder(abc.ABC): @@ -84,12 +82,10 @@ class TypeDecoder(abc.ABC): @abc.abstractproperty def bson_type(self) -> Any: """The BSON type to be converted into our own type.""" - pass @abc.abstractmethod def transform_bson(self, value: Any) -> Any: """Convert the given BSON value into our own type.""" - pass class TypeCodec(TypeEncoder, TypeDecoder): @@ -105,14 +101,12 @@ class TypeCodec(TypeEncoder, TypeDecoder): See :ref:`custom-type-type-codec` documentation for an example. """ - pass - _Codec = Union[TypeEncoder, TypeDecoder, TypeCodec] _Fallback = Callable[[Any], Any] -class TypeRegistry(object): +class TypeRegistry: """Encapsulates type codecs used in encoding and / or decoding BSON, as well as the fallback encoder. Type registries cannot be modified after instantiation. @@ -164,8 +158,7 @@ def __init__( self._decoder_map[codec.bson_type] = codec.transform_bson if not is_valid_codec: raise TypeError( - "Expected an instance of %s, %s, or %s, got %r instead" - % (TypeEncoder.__name__, TypeDecoder.__name__, TypeCodec.__name__, codec) + f"Expected an instance of {TypeEncoder.__name__}, {TypeDecoder.__name__}, or {TypeCodec.__name__}, got {codec!r} instead" ) def _validate_type_encoder(self, codec: _Codec) -> None: @@ -175,12 +168,12 @@ def _validate_type_encoder(self, codec: _Codec) -> None: if issubclass(cast(TypeCodec, codec).python_type, pytype): err_msg = ( "TypeEncoders cannot change how built-in types are " - "encoded (encoder %s transforms type %s)" % (codec, pytype) + "encoded (encoder {} transforms type {})".format(codec, pytype) ) raise TypeError(err_msg) def __repr__(self): - return "%s(type_codecs=%r, fallback_encoder=%r)" % ( + return "{}(type_codecs={!r}, fallback_encoder={!r})".format( self.__class__.__name__, self.__type_codecs, self._fallback_encoder, @@ -446,10 +439,9 @@ def _arguments_repr(self) -> str: ) return ( - "document_class=%s, tz_aware=%r, uuid_representation=%s, " - "unicode_decode_error_handler=%r, tzinfo=%r, " - "type_registry=%r, datetime_conversion=%s" - % ( + "document_class={}, tz_aware={!r}, uuid_representation={}, " + "unicode_decode_error_handler={!r}, tzinfo={!r}, " + "type_registry={!r}, datetime_conversion={!s}".format( document_class_repr, self.tz_aware, uuid_rep_repr, @@ -474,7 +466,7 @@ def _options_dict(self) -> Dict[str, Any]: } def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, self._arguments_repr()) + return f"{self.__class__.__name__}({self._arguments_repr()})" def with_options(self, **kwargs: Any) -> "CodecOptions": """Make a copy of this CodecOptions, overriding some options:: diff --git a/bson/dbref.py b/bson/dbref.py index 7849435f23..491278e6f4 100644 --- a/bson/dbref.py +++ b/bson/dbref.py @@ -21,7 +21,7 @@ from bson.son import SON -class DBRef(object): +class DBRef: """A reference to a document stored in MongoDB.""" __slots__ = "__collection", "__id", "__database", "__kwargs" @@ -36,7 +36,7 @@ def __init__( id: Any, database: Optional[str] = None, _extra: Optional[Mapping[str, Any]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """Initialize a new :class:`DBRef`. @@ -102,10 +102,10 @@ def as_doc(self) -> SON[str, Any]: return doc def __repr__(self): - extra = "".join([", %s=%r" % (k, v) for k, v in self.__kwargs.items()]) + extra = "".join([f", {k}={v!r}" for k, v in self.__kwargs.items()]) if self.database is None: - return "DBRef(%r, %r%s)" % (self.collection, self.id, extra) - return "DBRef(%r, %r, %r%s)" % (self.collection, self.id, self.database, extra) + return f"DBRef({self.collection!r}, {self.id!r}{extra})" + return f"DBRef({self.collection!r}, {self.id!r}, {self.database!r}{extra})" def __eq__(self, other: Any) -> bool: if isinstance(other, DBRef): diff --git a/bson/decimal128.py b/bson/decimal128.py index bce5b251e9..0e24b5bbae 100644 --- a/bson/decimal128.py +++ b/bson/decimal128.py @@ -115,7 +115,7 @@ def _decimal_to_128(value: _VALUE_OPTIONS) -> Tuple[int, int]: return high, low -class Decimal128(object): +class Decimal128: """BSON Decimal128 type:: >>> Decimal128(Decimal("0.0005")) @@ -226,7 +226,7 @@ def __init__(self, value: _VALUE_OPTIONS) -> None: ) self.__high, self.__low = value # type: ignore else: - raise TypeError("Cannot convert %r to Decimal128" % (value,)) + raise TypeError(f"Cannot convert {value!r} to Decimal128") def to_decimal(self) -> decimal.Decimal: """Returns an instance of :class:`decimal.Decimal` for this @@ -297,7 +297,7 @@ def __str__(self) -> str: return str(dec) def __repr__(self): - return "Decimal128('%s')" % (str(self),) + return f"Decimal128('{str(self)}')" def __setstate__(self, value: Tuple[int, int]) -> None: self.__high, self.__low = value diff --git a/bson/json_util.py b/bson/json_util.py index 8842d5c74d..bc566fa982 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -288,7 +288,7 @@ def __new__( strict_uuid: Optional[bool] = None, json_mode: int = JSONMode.RELAXED, *args: Any, - **kwargs: Any + **kwargs: Any, ) -> "JSONOptions": kwargs["tz_aware"] = kwargs.get("tz_aware", False) if kwargs["tz_aware"]: @@ -303,7 +303,7 @@ def __new__( "JSONOptions.datetime_representation must be one of LEGACY, " "NUMBERLONG, or ISO8601 from DatetimeRepresentation." ) - self = cast(JSONOptions, super(JSONOptions, cls).__new__(cls, *args, **kwargs)) + self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL): raise ValueError( "JSONOptions.json_mode must be one of LEGACY, RELAXED, " @@ -350,21 +350,20 @@ def __new__( def _arguments_repr(self) -> str: return ( - "strict_number_long=%r, " - "datetime_representation=%r, " - "strict_uuid=%r, json_mode=%r, %s" - % ( + "strict_number_long={!r}, " + "datetime_representation={!r}, " + "strict_uuid={!r}, json_mode={!r}, {}".format( self.strict_number_long, self.datetime_representation, self.strict_uuid, self.json_mode, - super(JSONOptions, self)._arguments_repr(), + super()._arguments_repr(), ) ) def _options_dict(self) -> Dict[Any, Any]: # TODO: PYTHON-2442 use _asdict() instead - options_dict = super(JSONOptions, self)._options_dict() + options_dict = super()._options_dict() options_dict.update( { "strict_number_long": self.strict_number_long, @@ -492,7 +491,7 @@ def _json_convert(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> if hasattr(obj, "items"): return SON(((k, _json_convert(v, json_options)) for k, v in obj.items())) elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)): - return list((_json_convert(v, json_options) for v in obj)) + return [_json_convert(v, json_options) for v in obj] try: return default(obj, json_options) except TypeError: @@ -568,9 +567,9 @@ def _parse_legacy_regex(doc: Any) -> Any: def _parse_legacy_uuid(doc: Any, json_options: JSONOptions) -> Union[Binary, uuid.UUID]: """Decode a JSON legacy $uuid to Python UUID.""" if len(doc) != 1: - raise TypeError("Bad $uuid, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $uuid, extra field(s): {doc}") if not isinstance(doc["$uuid"], str): - raise TypeError("$uuid must be a string: %s" % (doc,)) + raise TypeError(f"$uuid must be a string: {doc}") if json_options.uuid_representation == UuidRepresentation.UNSPECIFIED: return Binary.from_uuid(uuid.UUID(doc["$uuid"])) else: @@ -613,11 +612,11 @@ def _parse_canonical_binary(doc: Any, json_options: JSONOptions) -> Union[Binary b64 = binary["base64"] subtype = binary["subType"] if not isinstance(b64, str): - raise TypeError("$binary base64 must be a string: %s" % (doc,)) + raise TypeError(f"$binary base64 must be a string: {doc}") if not isinstance(subtype, str) or len(subtype) > 2: - raise TypeError("$binary subType must be a string at most 2 characters: %s" % (doc,)) + raise TypeError(f"$binary subType must be a string at most 2 characters: {doc}") if len(binary) != 2: - raise TypeError('$binary must include only "base64" and "subType" components: %s' % (doc,)) + raise TypeError(f'$binary must include only "base64" and "subType" components: {doc}') data = base64.b64decode(b64.encode()) return _binary_or_uuid(data, int(subtype, 16), json_options) @@ -629,7 +628,7 @@ def _parse_canonical_datetime( """Decode a JSON datetime to python datetime.datetime.""" dtm = doc["$date"] if len(doc) != 1: - raise TypeError("Bad $date, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $date, extra field(s): {doc}") # mongoexport 2.6 and newer if isinstance(dtm, str): # Parse offset @@ -692,7 +691,7 @@ def _parse_canonical_datetime( def _parse_canonical_oid(doc: Any) -> ObjectId: """Decode a JSON ObjectId to bson.objectid.ObjectId.""" if len(doc) != 1: - raise TypeError("Bad $oid, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $oid, extra field(s): {doc}") return ObjectId(doc["$oid"]) @@ -700,7 +699,7 @@ def _parse_canonical_symbol(doc: Any) -> str: """Decode a JSON symbol to Python string.""" symbol = doc["$symbol"] if len(doc) != 1: - raise TypeError("Bad $symbol, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $symbol, extra field(s): {doc}") return str(symbol) @@ -708,7 +707,7 @@ def _parse_canonical_code(doc: Any) -> Code: """Decode a JSON code to bson.code.Code.""" for key in doc: if key not in ("$code", "$scope"): - raise TypeError("Bad $code, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $code, extra field(s): {doc}") return Code(doc["$code"], scope=doc.get("$scope")) @@ -716,11 +715,11 @@ def _parse_canonical_regex(doc: Any) -> Regex: """Decode a JSON regex to bson.regex.Regex.""" regex = doc["$regularExpression"] if len(doc) != 1: - raise TypeError("Bad $regularExpression, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $regularExpression, extra field(s): {doc}") if len(regex) != 2: raise TypeError( 'Bad $regularExpression must include only "pattern"' - 'and "options" components: %s' % (doc,) + 'and "options" components: {}'.format(doc) ) opts = regex["options"] if not isinstance(opts, str): @@ -739,28 +738,28 @@ def _parse_canonical_dbpointer(doc: Any) -> Any: """Decode a JSON (deprecated) DBPointer to bson.dbref.DBRef.""" dbref = doc["$dbPointer"] if len(doc) != 1: - raise TypeError("Bad $dbPointer, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $dbPointer, extra field(s): {doc}") if isinstance(dbref, DBRef): dbref_doc = dbref.as_doc() # DBPointer must not contain $db in its value. if dbref.database is not None: - raise TypeError("Bad $dbPointer, extra field $db: %s" % (dbref_doc,)) + raise TypeError(f"Bad $dbPointer, extra field $db: {dbref_doc}") if not isinstance(dbref.id, ObjectId): - raise TypeError("Bad $dbPointer, $id must be an ObjectId: %s" % (dbref_doc,)) + raise TypeError(f"Bad $dbPointer, $id must be an ObjectId: {dbref_doc}") if len(dbref_doc) != 2: - raise TypeError("Bad $dbPointer, extra field(s) in DBRef: %s" % (dbref_doc,)) + raise TypeError(f"Bad $dbPointer, extra field(s) in DBRef: {dbref_doc}") return dbref else: - raise TypeError("Bad $dbPointer, expected a DBRef: %s" % (doc,)) + raise TypeError(f"Bad $dbPointer, expected a DBRef: {doc}") def _parse_canonical_int32(doc: Any) -> int: """Decode a JSON int32 to python int.""" i_str = doc["$numberInt"] if len(doc) != 1: - raise TypeError("Bad $numberInt, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $numberInt, extra field(s): {doc}") if not isinstance(i_str, str): - raise TypeError("$numberInt must be string: %s" % (doc,)) + raise TypeError(f"$numberInt must be string: {doc}") return int(i_str) @@ -768,7 +767,7 @@ def _parse_canonical_int64(doc: Any) -> Int64: """Decode a JSON int64 to bson.int64.Int64.""" l_str = doc["$numberLong"] if len(doc) != 1: - raise TypeError("Bad $numberLong, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $numberLong, extra field(s): {doc}") return Int64(l_str) @@ -776,9 +775,9 @@ def _parse_canonical_double(doc: Any) -> float: """Decode a JSON double to python float.""" d_str = doc["$numberDouble"] if len(doc) != 1: - raise TypeError("Bad $numberDouble, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $numberDouble, extra field(s): {doc}") if not isinstance(d_str, str): - raise TypeError("$numberDouble must be string: %s" % (doc,)) + raise TypeError(f"$numberDouble must be string: {doc}") return float(d_str) @@ -786,18 +785,18 @@ def _parse_canonical_decimal128(doc: Any) -> Decimal128: """Decode a JSON decimal128 to bson.decimal128.Decimal128.""" d_str = doc["$numberDecimal"] if len(doc) != 1: - raise TypeError("Bad $numberDecimal, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $numberDecimal, extra field(s): {doc}") if not isinstance(d_str, str): - raise TypeError("$numberDecimal must be string: %s" % (doc,)) + raise TypeError(f"$numberDecimal must be string: {doc}") return Decimal128(d_str) def _parse_canonical_minkey(doc: Any) -> MinKey: """Decode a JSON MinKey to bson.min_key.MinKey.""" if type(doc["$minKey"]) is not int or doc["$minKey"] != 1: - raise TypeError("$minKey value must be 1: %s" % (doc,)) + raise TypeError(f"$minKey value must be 1: {doc}") if len(doc) != 1: - raise TypeError("Bad $minKey, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $minKey, extra field(s): {doc}") return MinKey() @@ -806,7 +805,7 @@ def _parse_canonical_maxkey(doc: Any) -> MaxKey: if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1: raise TypeError("$maxKey value must be 1: %s", (doc,)) if len(doc) != 1: - raise TypeError("Bad $minKey, extra field(s): %s" % (doc,)) + raise TypeError(f"Bad $minKey, extra field(s): {doc}") return MaxKey() @@ -839,7 +838,7 @@ def default(obj: Any, json_options: JSONOptions = DEFAULT_JSON_OPTIONS) -> Any: millis = int(obj.microsecond / 1000) fracsecs = ".%03d" % (millis,) if millis else "" return { - "$date": "%s%s%s" % (obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string) + "$date": "{}{}{}".format(obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string) } millis = _datetime_to_millis(obj) diff --git a/bson/max_key.py b/bson/max_key.py index b4f38d072e..eb5705d378 100644 --- a/bson/max_key.py +++ b/bson/max_key.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Representation for the MongoDB internal MaxKey type. -""" +"""Representation for the MongoDB internal MaxKey type.""" from typing import Any -class MaxKey(object): +class MaxKey: """MongoDB internal MaxKey type.""" __slots__ = () diff --git a/bson/min_key.py b/bson/min_key.py index babc655e43..2c8f73d560 100644 --- a/bson/min_key.py +++ b/bson/min_key.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Representation for the MongoDB internal MinKey type. -""" +"""Representation for the MongoDB internal MinKey type.""" from typing import Any -class MinKey(object): +class MinKey: """MongoDB internal MinKey type.""" __slots__ = () diff --git a/bson/objectid.py b/bson/objectid.py index 1fab986b8b..b045e93d04 100644 --- a/bson/objectid.py +++ b/bson/objectid.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tools for working with MongoDB ObjectIds. -""" +"""Tools for working with MongoDB ObjectIds.""" import binascii import calendar @@ -43,7 +42,7 @@ def _random_bytes() -> bytes: return os.urandom(5) -class ObjectId(object): +class ObjectId: """A MongoDB ObjectId.""" _pid = os.getpid() @@ -166,7 +165,6 @@ def _random(cls) -> bytes: def __generate(self) -> None: """Generate a new value for this ObjectId.""" - # 4 bytes current time oid = struct.pack(">I", int(time.time())) @@ -202,9 +200,7 @@ def __validate(self, oid: Any) -> None: else: _raise_invalid_id(oid) else: - raise TypeError( - "id must be an instance of (bytes, str, ObjectId), not %s" % (type(oid),) - ) + raise TypeError(f"id must be an instance of (bytes, str, ObjectId), not {type(oid)}") @property def binary(self) -> bytes: @@ -224,13 +220,13 @@ def generation_time(self) -> datetime.datetime: return datetime.datetime.fromtimestamp(timestamp, utc) def __getstate__(self) -> bytes: - """return value of object for pickling. + """Return value of object for pickling. needed explicitly because __slots__() defined. """ return self.__id def __setstate__(self, value: Any) -> None: - """explicit state set from pickling""" + """Explicit state set from pickling""" # Provide backwards compatibility with OIDs # pickled with pymongo-1.9 or older. if isinstance(value, dict): @@ -249,7 +245,7 @@ def __str__(self) -> str: return binascii.hexlify(self.__id).decode() def __repr__(self): - return "ObjectId('%s')" % (str(self),) + return f"ObjectId('{str(self)}')" def __eq__(self, other: Any) -> bool: if isinstance(other, ObjectId): diff --git a/bson/raw_bson.py b/bson/raw_bson.py index 2c2b3c97ca..bb1dbd22a5 100644 --- a/bson/raw_bson.py +++ b/bson/raw_bson.py @@ -131,7 +131,7 @@ class from the standard library so it can be used like a read-only elif not issubclass(codec_options.document_class, RawBSONDocument): raise TypeError( "RawBSONDocument cannot use CodecOptions with document " - "class %s" % (codec_options.document_class,) + "class {}".format(codec_options.document_class) ) self.__codec_options = codec_options # Validate the bson object size. @@ -174,7 +174,7 @@ def __eq__(self, other: Any) -> bool: return NotImplemented def __repr__(self): - return "%s(%r, codec_options=%r)" % ( + return "{}({!r}, codec_options={!r})".format( self.__class__.__name__, self.raw, self.__codec_options, diff --git a/bson/regex.py b/bson/regex.py index 3e98477198..c06e493f38 100644 --- a/bson/regex.py +++ b/bson/regex.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tools for representing MongoDB regular expressions. -""" +"""Tools for representing MongoDB regular expressions.""" import re from typing import Any, Generic, Pattern, Type, TypeVar, Union @@ -117,7 +116,7 @@ def __ne__(self, other: Any) -> bool: return not self == other def __repr__(self): - return "Regex(%r, %r)" % (self.pattern, self.flags) + return f"Regex({self.pattern!r}, {self.flags!r})" def try_compile(self) -> "Pattern[_T]": """Compile this :class:`Regex` as a Python regular expression. diff --git a/bson/son.py b/bson/son.py index bba108aa80..482e8d2584 100644 --- a/bson/son.py +++ b/bson/son.py @@ -16,7 +16,8 @@ Regular dictionaries can be used instead of SON objects, but not when the order of keys is important. A SON object can be used just like a normal Python -dictionary.""" +dictionary. +""" import copy import re @@ -58,7 +59,7 @@ class SON(Dict[_Key, _Value]): def __init__( self, data: Optional[Union[Mapping[_Key, _Value], Iterable[Tuple[_Key, _Value]]]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: self.__keys = [] dict.__init__(self) @@ -66,14 +67,14 @@ def __init__( self.update(kwargs) def __new__(cls: Type["SON[_Key, _Value]"], *args: Any, **kwargs: Any) -> "SON[_Key, _Value]": - instance = super(SON, cls).__new__(cls, *args, **kwargs) # type: ignore[type-var] + instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var] instance.__keys = [] return instance def __repr__(self): result = [] for key in self.__keys: - result.append("(%r, %r)" % (key, self[key])) + result.append(f"({key!r}, {self[key]!r})") return "SON([%s])" % ", ".join(result) def __setitem__(self, key: _Key, value: _Value) -> None: @@ -94,8 +95,7 @@ def copy(self) -> "SON[_Key, _Value]": # efficient. # second level definitions support higher levels def __iter__(self) -> Iterator[_Key]: - for k in self.__keys: - yield k + yield from self.__keys def has_key(self, key: _Key) -> bool: return key in self.__keys @@ -113,7 +113,7 @@ def values(self) -> List[_Value]: # type: ignore[override] def clear(self) -> None: self.__keys = [] - super(SON, self).clear() + super().clear() def setdefault(self, key: _Key, default: _Value) -> _Value: try: @@ -189,7 +189,7 @@ def transform_value(value: Any) -> Any: if isinstance(value, list): return [transform_value(v) for v in value] elif isinstance(value, _Mapping): - return dict([(k, transform_value(v)) for k, v in value.items()]) + return {k: transform_value(v) for k, v in value.items()} else: return value diff --git a/bson/timestamp.py b/bson/timestamp.py index a333b9fa3e..5591b60e41 100644 --- a/bson/timestamp.py +++ b/bson/timestamp.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tools for representing MongoDB internal Timestamps. -""" +"""Tools for representing MongoDB internal Timestamps.""" import calendar import datetime @@ -25,7 +24,7 @@ UPPERBOUND = 4294967296 -class Timestamp(object): +class Timestamp: """MongoDB internal timestamps used in the opLog.""" __slots__ = ("__time", "__inc") @@ -113,7 +112,7 @@ def __ge__(self, other: Any) -> bool: return NotImplemented def __repr__(self): - return "Timestamp(%s, %s)" % (self.__time, self.__inc) + return f"Timestamp({self.__time}, {self.__inc})" def as_datetime(self) -> datetime.datetime: """Return a :class:`~datetime.datetime` instance corresponding diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 692567b2de..9a4cda5527 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -53,7 +53,7 @@ ] -class GridFS(object): +class GridFS: """An instance of GridFS on top of a single Database.""" def __init__(self, database: Database, collection: str = "fs"): @@ -141,7 +141,6 @@ def put(self, data: Any, **kwargs: Any) -> Any: .. versionchanged:: 3.0 w=0 writes to GridFS are now prohibited. """ - with GridIn(self.__collection, **kwargs) as grid_file: grid_file.write(data) return grid_file._id @@ -449,7 +448,7 @@ def exists( return f is not None -class GridFSBucket(object): +class GridFSBucket: """An instance of GridFS on top of a single Database.""" def __init__( diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index 5ec6352684..fd260963d7 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -76,7 +76,7 @@ def setter(self: Any, value: Any) -> Any: if read_only: docstring += "\n\nThis attribute is read-only." elif closed_only: - docstring = "%s\n\n%s" % ( + docstring = "{}\n\n{}".format( docstring, "This attribute is read-only and " "can only be read after :meth:`close` " @@ -114,7 +114,7 @@ def _disallow_transactions(session: Optional[ClientSession]) -> None: raise InvalidOperation("GridFS does not support multi-document transactions") -class GridIn(object): +class GridIn: """Class to write data to GridFS.""" def __init__( @@ -497,7 +497,7 @@ def _ensure_file(self) -> None: self._file = self.__files.find_one({"_id": self.__file_id}, session=self._session) if not self._file: raise NoFile( - "no file in gridfs collection %r with _id %r" % (self.__files, self.__file_id) + f"no file in gridfs collection {self.__files!r} with _id {self.__file_id!r}" ) def __getattr__(self, name: str) -> Any: @@ -640,10 +640,10 @@ def seek(self, pos: int, whence: int = _SEEK_SET) -> int: elif whence == _SEEK_END: new_pos = int(self.length) + pos else: - raise IOError(22, "Invalid value for `whence`") + raise OSError(22, "Invalid value for `whence`") if new_pos < 0: - raise IOError(22, "Invalid value for `pos` - must be positive") + raise OSError(22, "Invalid value for `pos` - must be positive") # Optimization, continue using the same buffer and chunk iterator. if new_pos == self.__position: @@ -732,7 +732,7 @@ def __del__(self) -> None: pass -class _GridOutChunkIterator(object): +class _GridOutChunkIterator: """Iterates over a file's chunks using a single cursor. Raises CorruptGridFile when encountering any truncated, missing, or extra @@ -832,7 +832,7 @@ def close(self) -> None: self._cursor = None -class GridOutIterator(object): +class GridOutIterator: def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession): self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0) @@ -878,7 +878,7 @@ def __init__( # Hold on to the base "fs" collection to create GridOut objects later. self.__root_collection = collection - super(GridOutCursor, self).__init__( + super().__init__( collection.files, filter, skip=skip, @@ -892,7 +892,7 @@ def __init__( def next(self) -> GridOut: """Get next GridOut object from cursor.""" _disallow_transactions(self.session) - next_file = super(GridOutCursor, self).next() + next_file = super().next() return GridOut(self.__root_collection, file_document=next_file, session=self.session) __next__ = next diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 8a4617ecaf..7a5a8a7302 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -57,7 +57,7 @@ def clamp_remaining(max_timeout: float) -> float: return min(timeout, max_timeout) -class _TimeoutContext(object): +class _TimeoutContext: """Internal timeout context manager. Use :func:`pymongo.timeout` instead:: diff --git a/pymongo/aggregation.py b/pymongo/aggregation.py index a13f164f53..a97455cb29 100644 --- a/pymongo/aggregation.py +++ b/pymongo/aggregation.py @@ -21,7 +21,7 @@ from pymongo.read_preferences import ReadPreference, _AggWritePref -class _AggregationCommand(object): +class _AggregationCommand: """The internal abstract base class for aggregation cursors. Should not be called directly by application developers. Use @@ -202,7 +202,7 @@ def _database(self): class _CollectionRawAggregationCommand(_CollectionAggregationCommand): def __init__(self, *args, **kwargs): - super(_CollectionRawAggregationCommand, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # For raw-batches, we set the initial batchSize for the cursor to 0. if not self._performs_write: @@ -216,7 +216,7 @@ def _aggregation_target(self): @property def _cursor_namespace(self): - return "%s.$cmd.aggregate" % (self._target.name,) + return f"{self._target.name}.$cmd.aggregate" @property def _database(self): diff --git a/pymongo/auth.py b/pymongo/auth.py index 4bc31ee97b..ac7cb254e9 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -61,7 +61,7 @@ """The authentication mechanisms supported by PyMongo.""" -class _Cache(object): +class _Cache: __slots__ = ("data",) _hash_val = hash("_Cache") @@ -104,7 +104,7 @@ def __hash__(self): def _build_credentials_tuple(mech, source, user, passwd, extra, database): """Build and return a mechanism specific credentials tuple.""" if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: - raise ConfigurationError("%s requires a username." % (mech,)) + raise ConfigurationError(f"{mech} requires a username.") if mech == "GSSAPI": if source is not None and source != "$external": raise ValueError("authentication source must be $external or None for GSSAPI") @@ -297,7 +297,7 @@ def _password_digest(username, password): raise TypeError("username must be an instance of str") md5hash = hashlib.md5() - data = "%s:mongo:%s" % (username, password) + data = f"{username}:mongo:{password}" md5hash.update(data.encode("utf-8")) return md5hash.hexdigest() @@ -306,7 +306,7 @@ def _auth_key(nonce, username, password): """Get an auth key to use for authentication.""" digest = _password_digest(username, password) md5hash = hashlib.md5() - data = "%s%s%s" % (nonce, username, digest) + data = f"{nonce}{username}{digest}" md5hash.update(data.encode("utf-8")) return md5hash.hexdigest() @@ -448,7 +448,7 @@ def _authenticate_plain(credentials, sock_info): source = credentials.source username = credentials.username password = credentials.password - payload = ("\x00%s\x00%s" % (username, password)).encode("utf-8") + payload = (f"\x00{username}\x00{password}").encode() cmd = SON( [ ("saslStart", 1), @@ -518,7 +518,7 @@ def _authenticate_default(credentials, sock_info): } -class _AuthContext(object): +class _AuthContext: def __init__(self, credentials, address): self.credentials = credentials self.speculative_authenticate = None @@ -543,7 +543,7 @@ def speculate_succeeded(self): class _ScramContext(_AuthContext): def __init__(self, credentials, address, mechanism): - super(_ScramContext, self).__init__(credentials, address) + super().__init__(credentials, address) self.scram_data = None self.mechanism = mechanism @@ -569,7 +569,7 @@ def speculate_command(self): authenticator = _get_authenticator(self.credentials, self.address) cmd = authenticator.auth_start_cmd(False) if cmd is None: - return + return None cmd["db"] = self.credentials.source return cmd diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index e84465ea66..bfa4c731d3 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -21,7 +21,7 @@ _HAVE_MONGODB_AWS = True except ImportError: - class AwsSaslContext(object): # type: ignore + class AwsSaslContext: # type: ignore def __init__(self, credentials): pass @@ -102,9 +102,7 @@ def _authenticate_aws(credentials, sock_info): # Clear the cached credentials if we hit a failure in auth. set_cached_credentials(None) # Convert to OperationFailure and include pymongo-auth-aws version. - raise OperationFailure( - "%s (pymongo-auth-aws version %s)" % (exc, pymongo_auth_aws.__version__) - ) + raise OperationFailure(f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})") except Exception: # Clear the cached credentials if we hit a failure in auth. set_cached_credentials(None) diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 530b1bb068..543dc0200d 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -131,11 +131,11 @@ def get_current_token(self, use_callbacks=True): refresh_token = self.idp_resp and self.idp_resp.get("refresh_token") refresh_token = refresh_token or "" - context = dict( - timeout_seconds=timeout, - version=CALLBACK_VERSION, - refresh_token=refresh_token, - ) + context = { + "timeout_seconds": timeout, + "version": CALLBACK_VERSION, + "refresh_token": refresh_token, + } if self.idp_resp is None or refresh_cb is None: self.idp_resp = request_cb(self.idp_info, context) @@ -181,7 +181,7 @@ def auth_start_cmd(self, use_callbacks=True): aws_identity_file = os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] with open(aws_identity_file) as fid: token = fid.read().strip() - payload = dict(jwt=token) + payload = {"jwt": token} cmd = SON( [ ("saslStart", 1), @@ -203,7 +203,7 @@ def auth_start_cmd(self, use_callbacks=True): if self.idp_info is None: # Send the SASL start with the optional principal name. - payload = dict() + payload = {} if principal_name: payload["n"] = principal_name @@ -221,7 +221,7 @@ def auth_start_cmd(self, use_callbacks=True): token = self.get_current_token(use_callbacks) if not token: return None - bin_payload = Binary(bson.encode(dict(jwt=token))) + bin_payload = Binary(bson.encode({"jwt": token})) return SON( [ ("saslStart", 1), @@ -268,7 +268,7 @@ def authenticate(self, sock_info, reauthenticate=False): if resp["done"]: sock_info.oidc_token_gen_id = self.token_gen_id - return + return None server_resp: Dict = bson.decode(resp["payload"]) if "issuer" in server_resp: @@ -278,7 +278,7 @@ def authenticate(self, sock_info, reauthenticate=False): conversation_id = resp["conversationId"] token = self.get_current_token() sock_info.oidc_token_gen_id = self.token_gen_id - bin_payload = Binary(bson.encode(dict(jwt=token))) + bin_payload = Binary(bson.encode({"jwt": token})) cmd = SON( [ ("saslContinue", 1), diff --git a/pymongo/bulk.py b/pymongo/bulk.py index b21b576aa5..b0f61b9f9f 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -60,7 +60,7 @@ _COMMANDS = ("insert", "update", "delete") -class _Run(object): +class _Run: """Represents a batch of write operations.""" def __init__(self, op_type): @@ -136,7 +136,7 @@ def _raise_bulk_write_error(full_result: Any) -> NoReturn: raise BulkWriteError(full_result) -class _Bulk(object): +class _Bulk: """The private guts of the bulk write API.""" def __init__(self, collection, ordered, bypass_document_validation, comment=None, let=None): @@ -509,5 +509,6 @@ def execute(self, write_concern, session): if not write_concern.acknowledged: with client._socket_for_writes(session) as sock_info: self.execute_no_results(sock_info, generator, write_concern) + return None else: return self.execute_command(generator, write_concern, session) diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 775f93c79a..c53f981188 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -156,7 +156,8 @@ def _aggregation_command_class(self): @property def _client(self): """The client against which the aggregation commands for - this ChangeStream will be run.""" + this ChangeStream will be run. + """ raise NotImplementedError def _change_stream_options(self): @@ -221,7 +222,7 @@ def _process_result(self, result, sock_info): if self._start_at_operation_time is None: raise OperationFailure( "Expected field 'operationTime' missing from command " - "response : %r" % (result,) + "response : {!r}".format(result) ) def _run_aggregation_cmd(self, session, explicit_session): @@ -473,6 +474,6 @@ class ClusterChangeStream(DatabaseChangeStream, Generic[_DocumentType]): """ def _change_stream_options(self): - options = super(ClusterChangeStream, self)._change_stream_options() + options = super()._change_stream_options() options["allChangesForCluster"] = True return options diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 882474e258..c9f63dc95a 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -167,7 +167,7 @@ def _parse_pool_options(username, password, database, options): ) -class ClientOptions(object): +class ClientOptions: """Read only configuration options for a MongoClient. Should not be instantiated directly by application developers. Access diff --git a/pymongo/client_session.py b/pymongo/client_session.py index 1ec0b16476..dbc5f3aa8d 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -169,7 +169,7 @@ from pymongo.write_concern import WriteConcern -class SessionOptions(object): +class SessionOptions: """Options for a new :class:`ClientSession`. :Parameters: @@ -203,8 +203,9 @@ def __init__( if not isinstance(default_transaction_options, TransactionOptions): raise TypeError( "default_transaction_options must be an instance of " - "pymongo.client_session.TransactionOptions, not: %r" - % (default_transaction_options,) + "pymongo.client_session.TransactionOptions, not: {!r}".format( + default_transaction_options + ) ) self._default_transaction_options = default_transaction_options self._snapshot = snapshot @@ -232,7 +233,7 @@ def snapshot(self) -> Optional[bool]: return self._snapshot -class TransactionOptions(object): +class TransactionOptions: """Options for :meth:`ClientSession.start_transaction`. :Parameters: @@ -275,25 +276,25 @@ def __init__( if not isinstance(read_concern, ReadConcern): raise TypeError( "read_concern must be an instance of " - "pymongo.read_concern.ReadConcern, not: %r" % (read_concern,) + "pymongo.read_concern.ReadConcern, not: {!r}".format(read_concern) ) if write_concern is not None: if not isinstance(write_concern, WriteConcern): raise TypeError( "write_concern must be an instance of " - "pymongo.write_concern.WriteConcern, not: %r" % (write_concern,) + "pymongo.write_concern.WriteConcern, not: {!r}".format(write_concern) ) if not write_concern.acknowledged: raise ConfigurationError( "transactions do not support unacknowledged write concern" - ": %r" % (write_concern,) + ": {!r}".format(write_concern) ) if read_preference is not None: if not isinstance(read_preference, _ServerMode): raise TypeError( - "%r is not valid for read_preference. See " + "{!r} is not valid for read_preference. See " "pymongo.read_preferences for valid " - "options." % (read_preference,) + "options.".format(read_preference) ) if max_commit_time_ms is not None: if not isinstance(max_commit_time_ms, int): @@ -340,12 +341,12 @@ def _validate_session_write_concern(session, write_concern): else: raise ConfigurationError( "Explicit sessions are incompatible with " - "unacknowledged write concern: %r" % (write_concern,) + "unacknowledged write concern: {!r}".format(write_concern) ) return session -class _TransactionContext(object): +class _TransactionContext: """Internal transaction context manager for start_transaction.""" def __init__(self, session): @@ -362,7 +363,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.__session.abort_transaction() -class _TxnState(object): +class _TxnState: NONE = 1 STARTING = 2 IN_PROGRESS = 3 @@ -371,7 +372,7 @@ class _TxnState(object): ABORTED = 6 -class _Transaction(object): +class _Transaction: """Internal class to hold transaction information in a ClientSession.""" def __init__(self, opts, client): @@ -973,7 +974,7 @@ def _apply_to(self, command, is_retryable, read_preference, sock_info): if read_preference != ReadPreference.PRIMARY: raise InvalidOperation( "read preference in a transaction must be primary, not: " - "%r" % (read_preference,) + "{!r}".format(read_preference) ) if self._transaction.state == _TxnState.STARTING: @@ -1023,7 +1024,7 @@ def inc_transaction_id(self): self.started_retryable_write = True -class _ServerSession(object): +class _ServerSession: def __init__(self, generation): # Ensure id is type 4, regardless of CodecOptions.uuid_representation. self.session_id = {"id": Binary(uuid.uuid4().bytes, 4)} @@ -1062,7 +1063,7 @@ class _ServerSessionPool(collections.deque): """ def __init__(self, *args, **kwargs): - super(_ServerSessionPool, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.generation = 0 def reset(self): diff --git a/pymongo/collation.py b/pymongo/collation.py index 3d8503f7d5..bdc996be1b 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -21,7 +21,7 @@ from pymongo import common -class CollationStrength(object): +class CollationStrength: """ An enum that defines values for `strength` on a :class:`~pymongo.collation.Collation`. @@ -43,7 +43,7 @@ class CollationStrength(object): """Differentiate unicode code point (characters are exactly identical).""" -class CollationAlternate(object): +class CollationAlternate: """ An enum that defines values for `alternate` on a :class:`~pymongo.collation.Collation`. @@ -62,7 +62,7 @@ class CollationAlternate(object): """ -class CollationMaxVariable(object): +class CollationMaxVariable: """ An enum that defines values for `max_variable` on a :class:`~pymongo.collation.Collation`. @@ -75,7 +75,7 @@ class CollationMaxVariable(object): """Spaces alone are ignored.""" -class CollationCaseFirst(object): +class CollationCaseFirst: """ An enum that defines values for `case_first` on a :class:`~pymongo.collation.Collation`. @@ -91,7 +91,7 @@ class CollationCaseFirst(object): """Default for locale or collation strength.""" -class Collation(object): +class Collation: """Collation :Parameters: @@ -163,7 +163,7 @@ def __init__( maxVariable: Optional[str] = None, normalization: Optional[bool] = None, backwards: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> None: locale = common.validate_string("locale", locale) self.__document: Dict[str, Any] = {"locale": locale} @@ -201,7 +201,7 @@ def document(self) -> Dict[str, Any]: def __repr__(self): document = self.document - return "Collation(%s)" % (", ".join("%s=%r" % (key, document[key]) for key in document),) + return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) def __eq__(self, other: Any) -> bool: if isinstance(other, Collation): diff --git a/pymongo/collection.py b/pymongo/collection.py index 91b4013ee8..3b9001240e 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -88,7 +88,7 @@ ] -class ReturnDocument(object): +class ReturnDocument: """An enum used with :meth:`~pymongo.collection.Collection.find_one_and_replace` and :meth:`~pymongo.collection.Collection.find_one_and_update`. @@ -201,7 +201,7 @@ def __init__( .. seealso:: The MongoDB documentation on `collections `_. """ - super(Collection, self).__init__( + super().__init__( codec_options or database.codec_options, read_preference or database.read_preference, write_concern or database.write_concern, @@ -212,7 +212,7 @@ def __init__( if not name or ".." in name: raise InvalidName("collection names cannot be empty") - if "$" in name and not (name.startswith("oplog.$main") or name.startswith("$cmd")): + if "$" in name and not (name.startswith(("oplog.$main", "$cmd"))): raise InvalidName("collection names must not contain '$': %r" % name) if name[0] == "." or name[-1] == ".": raise InvalidName("collection names must not start or end with '.': %r" % name) @@ -222,7 +222,7 @@ def __init__( self.__database: Database[_DocumentType] = database self.__name = name - self.__full_name = "%s.%s" % (self.__database.name, self.__name) + self.__full_name = f"{self.__database.name}.{self.__name}" self.__write_response_codec_options = self.codec_options._replace( unicode_decode_error_handler="replace", document_class=dict ) @@ -344,17 +344,17 @@ def __getattr__(self, name: str) -> "Collection[_DocumentType]": - `name`: the name of the collection to get """ if name.startswith("_"): - full_name = "%s.%s" % (self.__name, name) + full_name = f"{self.__name}.{name}" raise AttributeError( - "Collection has no attribute %r. To access the %s" - " collection, use database['%s']." % (name, full_name, full_name) + "Collection has no attribute {!r}. To access the {}" + " collection, use database['{}'].".format(name, full_name, full_name) ) return self.__getitem__(name) def __getitem__(self, name: str) -> "Collection[_DocumentType]": return Collection( self.__database, - "%s.%s" % (self.__name, name), + f"{self.__name}.{name}", False, self.codec_options, self.read_preference, @@ -363,7 +363,7 @@ def __getitem__(self, name: str) -> "Collection[_DocumentType]": ) def __repr__(self): - return "Collection(%r, %r)" % (self.__database, self.__name) + return f"Collection({self.__database!r}, {self.__name!r})" def __eq__(self, other: Any) -> bool: if isinstance(other, Collection): @@ -541,7 +541,7 @@ def bulk_write( try: request._add_to_bulk(blk) except AttributeError: - raise TypeError("%r is not a valid request" % (request,)) + raise TypeError(f"{request!r} is not a valid request") write_concern = self._write_concern_for(session) bulk_api_result = blk.execute(write_concern, session) @@ -579,6 +579,7 @@ def _insert_command(session, sock_info, retryable_write): if not isinstance(doc, RawBSONDocument): return doc.get("_id") + return None def insert_one( self, @@ -719,7 +720,7 @@ def gen(): write_concern = self._write_concern_for(session) blk = _Bulk(self, ordered, bypass_document_validation, comment=comment) - blk.ops = [doc for doc in gen()] + blk.ops = list(gen()) blk.execute(write_concern, session=session) return InsertManyResult(inserted_ids, write_concern.acknowledged) @@ -1924,7 +1925,7 @@ def gen_indexes(): for index in indexes: if not isinstance(index, IndexModel): raise TypeError( - "%r is not an instance of pymongo.operations.IndexModel" % (index,) + f"{index!r} is not an instance of pymongo.operations.IndexModel" ) document = index.document names.append(document["name"]) @@ -2442,7 +2443,6 @@ def aggregate( .. _aggregate command: https://mongodb.com/docs/manual/reference/command/aggregate """ - with self.__database.client._tmp_session(session, close=False) as s: return self._aggregate( _CollectionAggregationCommand, @@ -2687,7 +2687,7 @@ def rename( if "$" in new_name and not new_name.startswith("oplog.$main"): raise InvalidName("collection names must not contain '$'") - new_name = "%s.%s" % (self.__database.name, new_name) + new_name = f"{self.__database.name}.{new_name}" cmd = SON([("renameCollection", self.__full_name), ("to", new_name)]) cmd.update(kwargs) if comment is not None: @@ -2794,7 +2794,6 @@ def __find_and_modify( **kwargs, ): """Internal findAndModify helper.""" - common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): raise ValueError( diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 6f3f244419..d57b45154d 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -132,13 +132,15 @@ def batch_size(self, batch_size: int) -> "CommandCursor[_DocumentType]": def _has_next(self): """Returns `True` if the cursor has documents remaining from the - previous batch.""" + previous batch. + """ return len(self.__data) > 0 @property def _post_batch_resume_token(self): """Retrieve the postBatchResumeToken from the response to a - changeStream aggregate or getMore.""" + changeStream aggregate or getMore. + """ return self.__postbatchresumetoken def _maybe_pin_connection(self, sock_info): @@ -328,7 +330,7 @@ def __init__( .. seealso:: The MongoDB documentation on `cursors `_. """ assert not cursor_info.get("firstBatch") - super(RawBatchCommandCursor, self).__init__( + super().__init__( collection, cursor_info, address, diff --git a/pymongo/common.py b/pymongo/common.py index 4e39c8e514..82c773695a 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -157,7 +157,7 @@ def clean_node(node: str) -> Tuple[str, int]: def raise_config_error(key: str, dummy: Any) -> NoReturn: """Raise ConfigurationError with the given key name.""" - raise ConfigurationError("Unknown option %s" % (key,)) + raise ConfigurationError(f"Unknown option {key}") # Mapping of URI uuid representation options to valid subtypes. @@ -174,14 +174,14 @@ def validate_boolean(option: str, value: Any) -> bool: """Validates that 'value' is True or False.""" if isinstance(value, bool): return value - raise TypeError("%s must be True or False" % (option,)) + raise TypeError(f"{option} must be True or False") def validate_boolean_or_string(option: str, value: Any) -> bool: """Validates that value is True, False, 'true', or 'false'.""" if isinstance(value, str): if value not in ("true", "false"): - raise ValueError("The value of %s must be 'true' or 'false'" % (option,)) + raise ValueError(f"The value of {option} must be 'true' or 'false'") return value == "true" return validate_boolean(option, value) @@ -194,15 +194,15 @@ def validate_integer(option: str, value: Any) -> int: try: return int(value) except ValueError: - raise ValueError("The value of %s must be an integer" % (option,)) - raise TypeError("Wrong type for %s, value must be an integer" % (option,)) + raise ValueError(f"The value of {option} must be an integer") + raise TypeError(f"Wrong type for {option}, value must be an integer") def validate_positive_integer(option: str, value: Any) -> int: """Validate that 'value' is a positive integer, which does not include 0.""" val = validate_integer(option, value) if val <= 0: - raise ValueError("The value of %s must be a positive integer" % (option,)) + raise ValueError(f"The value of {option} must be a positive integer") return val @@ -210,7 +210,7 @@ def validate_non_negative_integer(option: str, value: Any) -> int: """Validate that 'value' is a positive integer or 0.""" val = validate_integer(option, value) if val < 0: - raise ValueError("The value of %s must be a non negative integer" % (option,)) + raise ValueError(f"The value of {option} must be a non negative integer") return val @@ -221,7 +221,7 @@ def validate_readable(option: str, value: Any) -> Optional[str]: # First make sure its a string py3.3 open(True, 'r') succeeds # Used in ssl cert checking due to poor ssl module error reporting value = validate_string(option, value) - open(value, "r").close() + open(value).close() return value @@ -243,7 +243,7 @@ def validate_string(option: str, value: Any) -> str: """Validates that 'value' is an instance of `str`.""" if isinstance(value, str): return value - raise TypeError("Wrong type for %s, value must be an instance of str" % (option,)) + raise TypeError(f"Wrong type for {option}, value must be an instance of str") def validate_string_or_none(option: str, value: Any) -> Optional[str]: @@ -262,7 +262,7 @@ def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: return int(value) except ValueError: return value - raise TypeError("Wrong type for %s, value must be an integer or a string" % (option,)) + raise TypeError(f"Wrong type for {option}, value must be an integer or a string") def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: @@ -275,16 +275,14 @@ def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[in except ValueError: return value return validate_non_negative_integer(option, val) - raise TypeError( - "Wrong type for %s, value must be an non negative integer or a string" % (option,) - ) + raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string") def validate_positive_float(option: str, value: Any) -> float: """Validates that 'value' is a float, or can be converted to one, and is positive. """ - errmsg = "%s must be an integer or float" % (option,) + errmsg = f"{option} must be an integer or float" try: value = float(value) except ValueError: @@ -295,7 +293,7 @@ def validate_positive_float(option: str, value: Any) -> float: # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at # one billion - this is a reasonable approximation for infinity if not 0 < value < 1e9: - raise ValueError("%s must be greater than 0 and less than one billion" % (option,)) + raise ValueError(f"{option} must be greater than 0 and less than one billion") return value @@ -324,7 +322,7 @@ def validate_timeout_or_zero(option: str, value: Any) -> float: config error. """ if value is None: - raise ConfigurationError("%s cannot be None" % (option,)) + raise ConfigurationError(f"{option} cannot be None") if value == 0 or value == "0": return 0 return validate_positive_float(option, value) / 1000.0 @@ -360,7 +358,7 @@ def validate_max_staleness(option: str, value: Any) -> int: def validate_read_preference(dummy: Any, value: Any) -> _ServerMode: """Validate a read preference.""" if not isinstance(value, _ServerMode): - raise TypeError("%r is not a read preference." % (value,)) + raise TypeError(f"{value!r} is not a read preference.") return value @@ -372,14 +370,14 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: mode. """ if value not in _MONGOS_MODES: - raise ValueError("%s is not a valid read preference" % (value,)) + raise ValueError(f"{value} is not a valid read preference") return value def validate_auth_mechanism(option: str, value: Any) -> str: """Validate the authMechanism URI option.""" if value not in MECHANISMS: - raise ValueError("%s must be in %s" % (option, tuple(MECHANISMS))) + raise ValueError(f"{option} must be in {tuple(MECHANISMS)}") return value @@ -389,9 +387,9 @@ def validate_uuid_representation(dummy: Any, value: Any) -> int: return _UUID_REPRESENTATIONS[value] except KeyError: raise ValueError( - "%s is an invalid UUID representation. " + "{} is an invalid UUID representation. " "Must be one of " - "%s" % (value, tuple(_UUID_REPRESENTATIONS)) + "{}".format(value, tuple(_UUID_REPRESENTATIONS)) ) @@ -412,7 +410,7 @@ def validate_read_preference_tags(name: str, value: Any) -> List[Dict[str, str]] tags[unquote_plus(key)] = unquote_plus(val) tag_sets.append(tags) except Exception: - raise ValueError("%r not a valid value for %s" % (tag_set, name)) + raise ValueError(f"{tag_set!r} not a valid value for {name}") return tag_sets @@ -472,13 +470,13 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> Dict[str, Uni raise ValueError( "auth mechanism properties must be " "key:value pairs like SERVICE_NAME:" - "mongodb, not %s." % (opt,) + "mongodb, not {}.".format(opt) ) if key not in _MECHANISM_PROPS: raise ValueError( - "%s is not a supported auth " + "{} is not a supported auth " "mechanism property. Must be one of " - "%s." % (key, tuple(_MECHANISM_PROPS)) + "{}.".format(key, tuple(_MECHANISM_PROPS)) ) if key == "CANONICALIZE_HOST_NAME": props[key] = validate_boolean_or_string(key, val) @@ -502,9 +500,9 @@ def validate_document_class( is_mapping = issubclass(value.__origin__, abc.MutableMapping) if not is_mapping and not issubclass(value, RawBSONDocument): raise TypeError( - "%s must be dict, bson.son.SON, " + "{} must be dict, bson.son.SON, " "bson.raw_bson.RawBSONDocument, or a " - "subclass of collections.MutableMapping" % (option,) + "subclass of collections.MutableMapping".format(option) ) return value @@ -512,14 +510,14 @@ def validate_document_class( def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: """Validate the type_registry option.""" if value is not None and not isinstance(value, TypeRegistry): - raise TypeError("%s must be an instance of %s" % (option, TypeRegistry)) + raise TypeError(f"{option} must be an instance of {TypeRegistry}") return value def validate_list(option: str, value: Any) -> List: """Validates that 'value' is a list.""" if not isinstance(value, list): - raise TypeError("%s must be a list" % (option,)) + raise TypeError(f"{option} must be a list") return value @@ -534,9 +532,9 @@ def validate_list_or_mapping(option: Any, value: Any) -> None: """Validates that 'value' is a list or a document.""" if not isinstance(value, (abc.Mapping, list)): raise TypeError( - "%s must either be a list or an instance of dict, " + "{} must either be a list or an instance of dict, " "bson.son.SON, or any other type that inherits from " - "collections.Mapping" % (option,) + "collections.Mapping".format(option) ) @@ -544,9 +542,9 @@ def validate_is_mapping(option: str, value: Any) -> None: """Validate the type of method arguments that expect a document.""" if not isinstance(value, abc.Mapping): raise TypeError( - "%s must be an instance of dict, bson.son.SON, or " + "{} must be an instance of dict, bson.son.SON, or " "any other type that inherits from " - "collections.Mapping" % (option,) + "collections.Mapping".format(option) ) @@ -554,10 +552,10 @@ def validate_is_document_type(option: str, value: Any) -> None: """Validate the type of method arguments that expect a MongoDB document.""" if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): raise TypeError( - "%s must be an instance of dict, bson.son.SON, " + "{} must be an instance of dict, bson.son.SON, " "bson.raw_bson.RawBSONDocument, or " "a type that inherits from " - "collections.MutableMapping" % (option,) + "collections.MutableMapping".format(option) ) @@ -568,7 +566,7 @@ def validate_appname_or_none(option: str, value: Any) -> Optional[str]: validate_string(option, value) # We need length in bytes, so encode utf8 first. if len(value.encode("utf-8")) > 128: - raise ValueError("%s must be <= 128 bytes" % (option,)) + raise ValueError(f"{option} must be <= 128 bytes") return value @@ -577,7 +575,7 @@ def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]: if value is None: return value if not isinstance(value, DriverInfo): - raise TypeError("%s must be an instance of DriverInfo" % (option,)) + raise TypeError(f"{option} must be an instance of DriverInfo") return value @@ -586,7 +584,7 @@ def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: if value is None: return value if not isinstance(value, ServerApi): - raise TypeError("%s must be an instance of ServerApi" % (option,)) + raise TypeError(f"{option} must be an instance of ServerApi") return value @@ -595,7 +593,7 @@ def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: if value is None: return value if not callable(value): - raise ValueError("%s must be a callable" % (option,)) + raise ValueError(f"{option} must be a callable") return value @@ -629,9 +627,9 @@ def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str: """Validate the Unicode decode error handler option of CodecOptions.""" if value not in _UNICODE_DECODE_ERROR_HANDLERS: raise ValueError( - "%s is an invalid Unicode decode error handler. " + "{} is an invalid Unicode decode error handler. " "Must be one of " - "%s" % (value, tuple(_UNICODE_DECODE_ERROR_HANDLERS)) + "{}".format(value, tuple(_UNICODE_DECODE_ERROR_HANDLERS)) ) return value @@ -650,7 +648,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A from pymongo.encryption_options import AutoEncryptionOpts if not isinstance(value, AutoEncryptionOpts): - raise TypeError("%s must be an instance of AutoEncryptionOpts" % (option,)) + raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") return value @@ -667,7 +665,7 @@ def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeCo elif isinstance(value, int): return DatetimeConversion(value) - raise TypeError("%s must be a str or int representing DatetimeConversion" % (option,)) + raise TypeError(f"{option} must be a str or int representing DatetimeConversion") # Dictionary where keys are the names of public URI options, and values @@ -805,7 +803,7 @@ def validate_auth_option(option: str, value: Any) -> Tuple[str, Any]: """Validate optional authentication parameters.""" lower, value = validate(option, value) if lower not in _AUTH_OPTIONS: - raise ConfigurationError("Unknown authentication option: %s" % (option,)) + raise ConfigurationError(f"Unknown authentication option: {option}") return option, value @@ -866,7 +864,7 @@ def _ecoc_coll_name(encrypted_fields, name): WRITE_CONCERN_OPTIONS = frozenset(["w", "wtimeout", "wtimeoutms", "fsync", "j", "journal"]) -class BaseObject(object): +class BaseObject: """A base class that provides attributes and methods common to multiple pymongo classes. @@ -886,9 +884,9 @@ def __init__( if not isinstance(read_preference, _ServerMode): raise TypeError( - "%r is not valid for read_preference. See " + "{!r} is not valid for read_preference. See " "pymongo.read_preferences for valid " - "options." % (read_preference,) + "options.".format(read_preference) ) self.__read_preference = read_preference diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index c9632a43d3..40bad403f3 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -40,8 +40,8 @@ from pymongo.hello import HelloCompat from pymongo.monitoring import _SENSITIVE_COMMANDS -_SUPPORTED_COMPRESSORS = set(["snappy", "zlib", "zstd"]) -_NO_COMPRESSION = set([HelloCompat.CMD, HelloCompat.LEGACY_CMD]) +_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} +_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} _NO_COMPRESSION.update(_SENSITIVE_COMMANDS) @@ -56,7 +56,7 @@ def validate_compressors(dummy, value): for compressor in compressors[:]: if compressor not in _SUPPORTED_COMPRESSORS: compressors.remove(compressor) - warnings.warn("Unsupported compressor: %s" % (compressor,)) + warnings.warn(f"Unsupported compressor: {compressor}") elif compressor == "snappy" and not _HAVE_SNAPPY: compressors.remove(compressor) warnings.warn( @@ -82,13 +82,13 @@ def validate_zlib_compression_level(option, value): try: level = int(value) except Exception: - raise TypeError("%s must be an integer, not %r." % (option, value)) + raise TypeError(f"{option} must be an integer, not {value!r}.") if level < -1 or level > 9: raise ValueError("%s must be between -1 and 9, not %d." % (option, level)) return level -class CompressionSettings(object): +class CompressionSettings: def __init__(self, compressors, zlib_compression_level): self.compressors = compressors self.zlib_compression_level = zlib_compression_level @@ -102,9 +102,11 @@ def get_compression_context(self, compressors): return ZlibContext(self.zlib_compression_level) elif chosen == "zstd": return ZstdContext() + return None + return None -class SnappyContext(object): +class SnappyContext: compressor_id = 1 @staticmethod @@ -112,7 +114,7 @@ def compress(data): return snappy.compress(data) -class ZlibContext(object): +class ZlibContext: compressor_id = 2 def __init__(self, level): @@ -122,7 +124,7 @@ def compress(self, data: bytes) -> bytes: return zlib.compress(data, self.level) -class ZstdContext(object): +class ZstdContext: compressor_id = 3 @staticmethod diff --git a/pymongo/cursor.py b/pymongo/cursor.py index ccf0bfd71b..cc4e1a1146 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -97,7 +97,7 @@ } -class CursorType(object): +class CursorType: NON_TAILABLE = 0 """The standard cursor type.""" @@ -126,7 +126,7 @@ class CursorType(object): """ -class _SocketManager(object): +class _SocketManager: """Used with exhaust cursors to ensure the socket is returned.""" def __init__(self, sock, more_to_come): @@ -387,11 +387,11 @@ def _clone(self, deepcopy=True, base=None): "exhaust", "has_filter", ) - data = dict( - (k, v) + data = { + k: v for k, v in self.__dict__.items() if k.startswith("_Cursor__") and k[9:] in values_to_clone - ) + } if deepcopy: data = self._deepcopy(data) base.__dict__.update(data) @@ -412,7 +412,7 @@ def __die(self, synchronous=False): self.__killed = True if self.__id and not already_killed: cursor_id = self.__id - address = _CursorAddress(self.__address, "%s.%s" % (self.__dbname, self.__collname)) + address = _CursorAddress(self.__address, f"{self.__dbname}.{self.__collname}") else: # Skip killCursors. cursor_id = 0 @@ -1322,7 +1322,7 @@ def __init__(self, collection: "Collection[_DocumentType]", *args: Any, **kwargs .. seealso:: The MongoDB documentation on `cursors `_. """ - super(RawBatchCursor, self).__init__(collection, *args, **kwargs) + super().__init__(collection, *args, **kwargs) def _unpack_response( self, response, cursor_id, codec_options, user_fields=None, legacy_response=False diff --git a/pymongo/database.py b/pymongo/database.py index 1e19d860e3..66cfce2090 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -125,7 +125,7 @@ def __init__( db.__my_collection__ """ - super(Database, self).__init__( + super().__init__( codec_options or client.codec_options, read_preference or client.read_preference, write_concern or client.write_concern, @@ -211,7 +211,7 @@ def __hash__(self) -> int: return hash((self.__client, self.__name)) def __repr__(self): - return "Database(%r, %r)" % (self.__client, self.__name) + return f"Database({self.__client!r}, {self.__name!r})" def __getattr__(self, name: str) -> Collection[_DocumentType]: """Get a collection of this database by name. @@ -223,8 +223,8 @@ def __getattr__(self, name: str) -> Collection[_DocumentType]: """ if name.startswith("_"): raise AttributeError( - "Database has no attribute %r. To access the %s" - " collection, use database[%r]." % (name, name, name) + "Database has no attribute {!r}. To access the {}" + " collection, use database[{!r}].".format(name, name, name) ) return self.__getitem__(name) @@ -415,9 +415,9 @@ def create_collection( { // key pattern must be {_id: 1} key: , // required - unique: , // required, must be ‘true’ + unique: , // required, must be `true` name: , // optional, otherwise automatically generated - v: , // optional, must be ‘2’ if provided + v: , // optional, must be `2` if provided } - ``changeStreamPreAndPostImages`` (dict): a document with a boolean field ``enabled`` for enabling pre- and post-images. @@ -863,7 +863,6 @@ def _cmd(session, server, sock_info, read_preference): def _list_collections(self, sock_info, session, read_preference, **kwargs): """Internal listCollections helper.""" - coll = self.get_collection("$cmd", read_preference=read_preference) cmd = SON([("listCollections", 1), ("cursor", {})]) cmd.update(kwargs) @@ -1128,14 +1127,14 @@ def validate_collection( if "result" in result: info = result["result"] if info.find("exception") != -1 or info.find("corrupt") != -1: - raise CollectionInvalid("%s invalid: %s" % (name, info)) + raise CollectionInvalid(f"{name} invalid: {info}") # Sharded results elif "raw" in result: for _, res in result["raw"].items(): if "result" in res: info = res["result"] if info.find("exception") != -1 or info.find("corrupt") != -1: - raise CollectionInvalid("%s invalid: %s" % (name, info)) + raise CollectionInvalid(f"{name} invalid: {info}") elif not res.get("valid", False): valid = False break @@ -1144,7 +1143,7 @@ def validate_collection( valid = False if not valid: - raise CollectionInvalid("%s invalid: %r" % (name, result)) + raise CollectionInvalid(f"{name} invalid: {result!r}") return result @@ -1200,7 +1199,7 @@ def dereference( if dbref.database is not None and dbref.database != self.__name: raise ValueError( "trying to dereference a DBRef that points to " - "another database (%r not %r)" % (dbref.database, self.__name) + "another database ({!r} not {!r})".format(dbref.database, self.__name) ) return self[dbref.collection].find_one( {"_id": dbref.id}, session=session, comment=comment, **kwargs diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py index 53fbfd3428..86ddfcfb3e 100644 --- a/pymongo/driver_info.py +++ b/pymongo/driver_info.py @@ -31,12 +31,12 @@ class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])): def __new__( cls, name: str, version: Optional[str] = None, platform: Optional[str] = None ) -> "DriverInfo": - self = super(DriverInfo, cls).__new__(cls, name, version, platform) + self = super().__new__(cls, name, version, platform) for key, value in self._asdict().items(): if value is not None and not isinstance(value, str): raise TypeError( - "Wrong type for DriverInfo %s option, value " - "must be an instance of str" % (key,) + "Wrong type for DriverInfo {} option, value " + "must be an instance of str".format(key) ) return self diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 4c46bf56ae..f2eb71ce71 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -177,6 +177,7 @@ def collection_info(self, database, filter): with self.client_ref()[database].list_collections(filter=RawBSONDocument(filter)) as cursor: for doc in cursor: return _dict_to_bson(doc, False, _DATA_KEY_OPTS) + return None def spawn(self): """Spawn mongocryptd. @@ -272,7 +273,7 @@ def close(self): self.mongocryptd_client = None -class RewrapManyDataKeyResult(object): +class RewrapManyDataKeyResult: """Result object returned by a :meth:`~ClientEncryption.rewrap_many_data_key` operation. .. versionadded:: 4.2 @@ -292,11 +293,12 @@ def bulk_write_result(self) -> Optional[BulkWriteResult]: return self._bulk_write_result -class _Encrypter(object): +class _Encrypter: """Encrypts and decrypts MongoDB commands. This class is used to support automatic encryption and decryption of - MongoDB commands.""" + MongoDB commands. + """ def __init__(self, client, opts): """Create a _Encrypter for a client. diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index 0cb96d7dad..e87d96b31a 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -31,7 +31,7 @@ from pymongo.mongo_client import MongoClient -class AutoEncryptionOpts(object): +class AutoEncryptionOpts: """Options to configure automatic client-side field level encryption.""" def __init__( diff --git a/pymongo/errors.py b/pymongo/errors.py index 192eec99d9..36f97f4b5a 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -33,7 +33,7 @@ class PyMongoError(Exception): """Base class for all PyMongo exceptions.""" def __init__(self, message: str = "", error_labels: Optional[Iterable[str]] = None) -> None: - super(PyMongoError, self).__init__(message) + super().__init__(message) self._message = message self._error_labels = set(error_labels or []) @@ -105,7 +105,7 @@ def __init__( if errors is not None: if isinstance(errors, dict): error_labels = errors.get("errorLabels") - super(AutoReconnect, self).__init__(message, error_labels) + super().__init__(message, error_labels) self.errors = self.details = errors or [] @@ -125,7 +125,7 @@ def timeout(self) -> bool: def _format_detailed_error(message, details): if details is not None: - message = "%s, full error: %s" % (message, details) + message = f"{message}, full error: {details}" return message @@ -148,9 +148,7 @@ class NotPrimaryError(AutoReconnect): def __init__( self, message: str = "", errors: Optional[Union[Mapping[str, Any], List]] = None ) -> None: - super(NotPrimaryError, self).__init__( - _format_detailed_error(message, errors), errors=errors - ) + super().__init__(_format_detailed_error(message, errors), errors=errors) class ServerSelectionTimeoutError(AutoReconnect): @@ -191,9 +189,7 @@ def __init__( error_labels = None if details is not None: error_labels = details.get("errorLabels") - super(OperationFailure, self).__init__( - _format_detailed_error(error, details), error_labels=error_labels - ) + super().__init__(_format_detailed_error(error, details), error_labels=error_labels) self.__code = code self.__details = details self.__max_wire_version = max_wire_version @@ -293,7 +289,7 @@ class BulkWriteError(OperationFailure): details: Mapping[str, Any] def __init__(self, results: Mapping[str, Any]) -> None: - super(BulkWriteError, self).__init__("batch op errors occurred", 65, results) + super().__init__("batch op errors occurred", 65, results) def __reduce__(self) -> Tuple[Any, Any]: return self.__class__, (self.details,) @@ -331,8 +327,6 @@ class InvalidURI(ConfigurationError): class DocumentTooLarge(InvalidDocument): """Raised when an encoded document is too large for the connected server.""" - pass - class EncryptionError(PyMongoError): """Raised when encryption or decryption fails. @@ -344,7 +338,7 @@ class EncryptionError(PyMongoError): """ def __init__(self, cause: Exception) -> None: - super(EncryptionError, self).__init__(str(cause)) + super().__init__(str(cause)) self.__cause = cause @property @@ -369,7 +363,7 @@ class EncryptedCollectionError(EncryptionError): """ def __init__(self, cause: Exception, encrypted_fields: Mapping[str, Any]) -> None: - super(EncryptedCollectionError, self).__init__(cause) + super().__init__(cause) self.__encrypted_fields = encrypted_fields @property @@ -386,5 +380,3 @@ def encrypted_fields(self) -> Mapping[str, Any]: class _OperationCancelled(AutoReconnect): """Internal error raised when a socket operation is cancelled.""" - - pass diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 1a753c66f4..f4582854dc 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -74,7 +74,7 @@ def _gen_index_name(keys): """Generate an index name from the set of fields it is over.""" - return "_".join(["%s_%s" % item for item in keys]) + return "_".join(["{}_{}".format(*item) for item in keys]) def _index_list(key_or_list, direction=None): @@ -248,12 +248,10 @@ def _fields_list_to_dict(fields, option_name): if isinstance(fields, (abc.Sequence, abc.Set)): if not all(isinstance(field, str) for field in fields): - raise TypeError( - "%s must be a list of key names, each an instance of str" % (option_name,) - ) + raise TypeError(f"{option_name} must be a list of key names, each an instance of str") return dict.fromkeys(fields, 1) - raise TypeError("%s must be a mapping or list of key names" % (option_name,)) + raise TypeError(f"{option_name} must be a mapping or list of key names") def _handle_exception(): @@ -266,7 +264,7 @@ def _handle_exception(): einfo = sys.exc_info() try: traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) - except IOError: + except OSError: pass finally: del einfo diff --git a/pymongo/message.py b/pymongo/message.py index 3510d210a5..34f6e6235d 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -115,7 +115,6 @@ def _convert_exception(exception): def _convert_write_result(operation, command, result): """Convert a legacy write result to write command format.""" - # Based on _merge_legacy from bulk.py affected = result.get("n", 0) res = {"ok": 1, "n": affected} @@ -240,7 +239,7 @@ def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, commen return cmd -class _Query(object): +class _Query: """A query operation.""" __slots__ = ( @@ -310,7 +309,7 @@ def reset(self): self._as_command = None def namespace(self): - return "%s.%s" % (self.db, self.coll) + return f"{self.db}.{self.coll}" def use_command(self, sock_info): use_find_cmd = False @@ -421,7 +420,7 @@ def get_message(self, read_preference, sock_info, use_cmd=False): ) -class _GetMore(object): +class _GetMore: """A getmore operation.""" __slots__ = ( @@ -475,7 +474,7 @@ def reset(self): self._as_command = None def namespace(self): - return "%s.%s" % (self.db, self.coll) + return f"{self.db}.{self.coll}" def use_command(self, sock_info): use_cmd = False @@ -518,7 +517,6 @@ def as_command(self, sock_info, apply_timeout=False): def get_message(self, dummy0, sock_info, use_cmd=False): """Get a getmore message.""" - ns = self.namespace() ctx = sock_info.compression_context @@ -539,7 +537,7 @@ def get_message(self, dummy0, sock_info, use_cmd=False): class _RawBatchQuery(_Query): def use_command(self, sock_info): # Compatibility checks. - super(_RawBatchQuery, self).use_command(sock_info) + super().use_command(sock_info) if sock_info.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True @@ -551,7 +549,7 @@ def use_command(self, sock_info): class _RawBatchGetMore(_GetMore): def use_command(self, sock_info): # Compatibility checks. - super(_RawBatchGetMore, self).use_command(sock_info) + super().use_command(sock_info) if sock_info.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True @@ -578,7 +576,7 @@ def namespace(self): def __hash__(self): # Two _CursorAddress instances with different namespaces # must not hash the same. - return (self + (self.__namespace,)).__hash__() + return ((*self, self.__namespace)).__hash__() def __eq__(self, other): if isinstance(other, _CursorAddress): @@ -648,7 +646,7 @@ def _op_msg_no_header(flags, command, identifier, docs, opts): encoded_size = _pack_int(size) total_size += size max_doc_size = max(len(doc) for doc in encoded_docs) - data = [flags_type, encoded, type_one, encoded_size, cstring] + encoded_docs + data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] else: data = [flags_type, encoded] return b"".join(data), total_size, max_doc_size @@ -795,7 +793,7 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None): return _get_more_uncompressed(collection_name, num_to_return, cursor_id) -class _BulkWriteContext(object): +class _BulkWriteContext: """A wrapper around SocketInfo for use with write splitting functions.""" __slots__ = ( @@ -1033,7 +1031,7 @@ def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> N else: # There's nothing intelligent we can say # about size for update and delete - raise DocumentTooLarge("%r command document too large" % (operation,)) + raise DocumentTooLarge(f"{operation!r} command document too large") # OP_MSG ------------------------------------------------------------- @@ -1253,7 +1251,7 @@ def _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, return to_send, length -class _OpReply(object): +class _OpReply: """A MongoDB OP_REPLY response message.""" __slots__ = ("flags", "cursor_id", "number_returned", "documents") @@ -1363,7 +1361,7 @@ def unpack(cls, msg): return cls(flags, cursor_id, number_returned, documents) -class _OpMsg(object): +class _OpMsg: """A MongoDB OP_MSG response message.""" __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") @@ -1427,12 +1425,12 @@ def unpack(cls, msg): flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) if flags != 0: if flags & cls.CHECKSUM_PRESENT: - raise ProtocolError("Unsupported OP_MSG flag checksumPresent: 0x%x" % (flags,)) + raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") if flags ^ cls.MORE_TO_COME: - raise ProtocolError("Unsupported OP_MSG flags: 0x%x" % (flags,)) + raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") if first_payload_type != 0: - raise ProtocolError("Unsupported OP_MSG payload type: 0x%x" % (first_payload_type,)) + raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") if len(msg) != first_payload_size + 5: raise ProtocolError("Unsupported OP_MSG reply: >1 section") diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index ca60affdf5..ccfaaa31c1 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -805,7 +805,7 @@ def __init__( self.__kill_cursors_queue: List = [] self._event_listeners = options.pool_options._event_listeners - super(MongoClient, self).__init__( + super().__init__( options.codec_options, options.read_preference, options.write_concern, @@ -1509,11 +1509,11 @@ def option_repr(option, value): if value is dict: return "document_class=dict" else: - return "document_class=%s.%s" % (value.__module__, value.__name__) + return f"document_class={value.__module__}.{value.__name__}" if option in common.TIMEOUT_OPTIONS and value is not None: - return "%s=%s" % (option, int(value * 1000)) + return f"{option}={int(value * 1000)}" - return "%s=%r" % (option, value) + return f"{option}={value!r}" # Host first... options = [ @@ -1536,7 +1536,7 @@ def option_repr(option, value): return ", ".join(options) def __repr__(self): - return "MongoClient(%s)" % (self._repr_helper(),) + return f"MongoClient({self._repr_helper()})" def __getattr__(self, name: str) -> database.Database[_DocumentType]: """Get a database by name. @@ -1549,8 +1549,8 @@ def __getattr__(self, name: str) -> database.Database[_DocumentType]: """ if name.startswith("_"): raise AttributeError( - "MongoClient has no attribute %r. To access the %s" - " database, use client[%r]." % (name, name, name) + "MongoClient has no attribute {!r}. To access the {}" + " database, use client[{!r}].".format(name, name, name) ) return self.__getitem__(name) @@ -1685,7 +1685,8 @@ def _process_kill_cursors(self): # This method is run periodically by a background thread. def _process_periodic_tasks(self): """Process any pending kill cursors requests and - maintain connection pool parameters.""" + maintain connection pool parameters. + """ try: self._process_kill_cursors() self._topology.update_pool() @@ -1742,7 +1743,7 @@ def _get_server_session(self): def _return_server_session(self, server_session, lock): """Internal: return a _ServerSession to the pool.""" if isinstance(server_session, _EmptyServerSession): - return + return None return self._topology.return_server_session(server_session, lock) def _ensure_session(self, session=None): @@ -2121,7 +2122,7 @@ def _add_retryable_write_error(exc, max_wire_version): exc._add_error_label("RetryableWriteError") -class _MongoClientErrorHandler(object): +class _MongoClientErrorHandler: """Handle errors raised when executing an operation.""" __slots__ = ( diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 9031d4b785..2fc0bf8bab 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -37,7 +37,7 @@ def _sanitize(error): error.__cause__ = None -class MonitorBase(object): +class MonitorBase: def __init__(self, topology, name, interval, min_interval): """Base class to do periodic work on a background thread. @@ -108,7 +108,7 @@ def __init__(self, server_description, topology, pool, topology_settings): The Topology is weakly referenced. The Pool must be exclusive to this Monitor. """ - super(Monitor, self).__init__( + super().__init__( topology, "pymongo_server_monitor_thread", topology_settings.heartbeat_frequency, @@ -290,7 +290,7 @@ def __init__(self, topology, topology_settings): The Topology is weakly referenced. """ - super(SrvMonitor, self).__init__( + super().__init__( topology, "pymongo_srv_polling_thread", common.MIN_SRV_RESCAN_INTERVAL, @@ -343,7 +343,7 @@ def __init__(self, topology, topology_settings, pool): The Topology is weakly referenced. """ - super(_RttMonitor, self).__init__( + super().__init__( topology, "pymongo_server_rtt_thread", topology_settings.heartbeat_frequency, diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 5b729652ad..391ca13540 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -211,7 +211,7 @@ def connection_checked_in(self, event): _LISTENERS = _Listeners([], [], [], [], []) -class _EventListener(object): +class _EventListener: """Abstract base class for all event listeners.""" @@ -486,14 +486,14 @@ def _to_micros(dur): def _validate_event_listeners(option, listeners): """Validate event listeners""" if not isinstance(listeners, abc.Sequence): - raise TypeError("%s must be a list or tuple" % (option,)) + raise TypeError(f"{option} must be a list or tuple") for listener in listeners: if not isinstance(listener, _EventListener): raise TypeError( - "Listeners for %s must be either a " + "Listeners for {} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener." % (option,) + "ConnectionPoolListener.".format(option) ) return listeners @@ -508,10 +508,10 @@ def register(listener: _EventListener) -> None: """ if not isinstance(listener, _EventListener): raise TypeError( - "Listeners for %s must be either a " + "Listeners for {} must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " - "ConnectionPoolListener." % (listener,) + "ConnectionPoolListener.".format(listener) ) if isinstance(listener, CommandListener): _LISTENERS.command_listeners.append(listener) @@ -528,19 +528,17 @@ def register(listener: _EventListener) -> None: # Note - to avoid bugs from forgetting which if these is all lowercase and # which are camelCase, and at the same time avoid having to add a test for # every command, use all lowercase here and test against command_name.lower(). -_SENSITIVE_COMMANDS: set = set( - [ - "authenticate", - "saslstart", - "saslcontinue", - "getnonce", - "createuser", - "updateuser", - "copydbgetnonce", - "copydbsaslstart", - "copydb", - ] -) +_SENSITIVE_COMMANDS: set = { + "authenticate", + "saslstart", + "saslcontinue", + "getnonce", + "createuser", + "updateuser", + "copydbgetnonce", + "copydbsaslstart", + "copydb", +} # The "hello" command is also deemed sensitive when attempting speculative @@ -554,7 +552,7 @@ def _is_speculative_authenticate(command_name, doc): return False -class _CommandEvent(object): +class _CommandEvent: """Base class for command events.""" __slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id", "__service_id") @@ -627,10 +625,10 @@ def __init__( service_id: Optional[ObjectId] = None, ) -> None: if not command: - raise ValueError("%r is not a valid command" % (command,)) + raise ValueError(f"{command!r} is not a valid command") # Command name must be first key. command_name = next(iter(command)) - super(CommandStartedEvent, self).__init__( + super().__init__( command_name, request_id, connection_id, operation_id, service_id=service_id ) cmd_name = command_name.lower() @@ -651,7 +649,7 @@ def database_name(self) -> str: return self.__db def __repr__(self): - return ("<%s %s db: %r, command: %r, operation_id: %s, service_id: %s>") % ( + return ("<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}>").format( self.__class__.__name__, self.connection_id, self.database_name, @@ -687,7 +685,7 @@ def __init__( operation_id: Optional[int], service_id: Optional[ObjectId] = None, ) -> None: - super(CommandSucceededEvent, self).__init__( + super().__init__( command_name, request_id, connection_id, operation_id, service_id=service_id ) self.__duration_micros = _to_micros(duration) @@ -708,7 +706,9 @@ def reply(self) -> _DocumentOut: return self.__reply def __repr__(self): - return ("<%s %s command: %r, operation_id: %s, duration_micros: %s, service_id: %s>") % ( + return ( + "<{} {} command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}>" + ).format( self.__class__.__name__, self.connection_id, self.command_name, @@ -744,7 +744,7 @@ def __init__( operation_id: Optional[int], service_id: Optional[ObjectId] = None, ) -> None: - super(CommandFailedEvent, self).__init__( + super().__init__( command_name, request_id, connection_id, operation_id, service_id=service_id ) self.__duration_micros = _to_micros(duration) @@ -762,9 +762,9 @@ def failure(self) -> _DocumentOut: def __repr__(self): return ( - "<%s %s command: %r, operation_id: %s, duration_micros: %s, " - "failure: %r, service_id: %s>" - ) % ( + "<{} {} command: {!r}, operation_id: {}, duration_micros: {}, " + "failure: {!r}, service_id: {}>" + ).format( self.__class__.__name__, self.connection_id, self.command_name, @@ -775,7 +775,7 @@ def __repr__(self): ) -class _PoolEvent(object): +class _PoolEvent: """Base class for pool events.""" __slots__ = ("__address",) @@ -791,7 +791,7 @@ def address(self) -> _Address: return self.__address def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self.__address) + return f"{self.__class__.__name__}({self.__address!r})" class PoolCreatedEvent(_PoolEvent): @@ -807,7 +807,7 @@ class PoolCreatedEvent(_PoolEvent): __slots__ = ("__options",) def __init__(self, address: _Address, options: Dict[str, Any]) -> None: - super(PoolCreatedEvent, self).__init__(address) + super().__init__(address) self.__options = options @property @@ -816,7 +816,7 @@ def options(self) -> Dict[str, Any]: return self.__options def __repr__(self): - return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__options) + return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" class PoolReadyEvent(_PoolEvent): @@ -846,7 +846,7 @@ class PoolClearedEvent(_PoolEvent): __slots__ = ("__service_id",) def __init__(self, address: _Address, service_id: Optional[ObjectId] = None) -> None: - super(PoolClearedEvent, self).__init__(address) + super().__init__(address) self.__service_id = service_id @property @@ -860,7 +860,7 @@ def service_id(self) -> Optional[ObjectId]: return self.__service_id def __repr__(self): - return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__service_id) + return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r})" class PoolClosedEvent(_PoolEvent): @@ -876,7 +876,7 @@ class PoolClosedEvent(_PoolEvent): __slots__ = () -class ConnectionClosedReason(object): +class ConnectionClosedReason: """An enum that defines values for `reason` on a :class:`ConnectionClosedEvent`. @@ -897,7 +897,7 @@ class ConnectionClosedReason(object): """The pool was closed, making the connection no longer valid.""" -class ConnectionCheckOutFailedReason(object): +class ConnectionCheckOutFailedReason: """An enum that defines values for `reason` on a :class:`ConnectionCheckOutFailedEvent`. @@ -916,7 +916,7 @@ class ConnectionCheckOutFailedReason(object): """ -class _ConnectionEvent(object): +class _ConnectionEvent: """Private base class for connection events.""" __slots__ = ("__address",) @@ -932,7 +932,7 @@ def address(self) -> _Address: return self.__address def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self.__address) + return f"{self.__class__.__name__}({self.__address!r})" class _ConnectionIdEvent(_ConnectionEvent): @@ -950,7 +950,7 @@ def connection_id(self) -> int: return self.__connection_id def __repr__(self): - return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__connection_id) + return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" class ConnectionCreatedEvent(_ConnectionIdEvent): @@ -999,7 +999,7 @@ class ConnectionClosedEvent(_ConnectionIdEvent): __slots__ = ("__reason",) def __init__(self, address, connection_id, reason): - super(ConnectionClosedEvent, self).__init__(address, connection_id) + super().__init__(address, connection_id) self.__reason = reason @property @@ -1012,7 +1012,7 @@ def reason(self): return self.__reason def __repr__(self): - return "%s(%r, %r, %r)" % ( + return "{}({!r}, {!r}, {!r})".format( self.__class__.__name__, self.address, self.connection_id, @@ -1060,7 +1060,7 @@ def reason(self) -> str: return self.__reason def __repr__(self): - return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__reason) + return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r})" class ConnectionCheckedOutEvent(_ConnectionIdEvent): @@ -1091,7 +1091,7 @@ class ConnectionCheckedInEvent(_ConnectionIdEvent): __slots__ = () -class _ServerEvent(object): +class _ServerEvent: """Base class for server events.""" __slots__ = ("__server_address", "__topology_id") @@ -1111,7 +1111,7 @@ def topology_id(self) -> ObjectId: return self.__topology_id def __repr__(self): - return "<%s %s topology_id: %s>" % ( + return "<{} {} topology_id: {}>".format( self.__class__.__name__, self.server_address, self.topology_id, @@ -1130,26 +1130,28 @@ def __init__( self, previous_description: "ServerDescription", new_description: "ServerDescription", - *args: Any + *args: Any, ) -> None: - super(ServerDescriptionChangedEvent, self).__init__(*args) + super().__init__(*args) self.__previous_description = previous_description self.__new_description = new_description @property def previous_description(self) -> "ServerDescription": """The previous - :class:`~pymongo.server_description.ServerDescription`.""" + :class:`~pymongo.server_description.ServerDescription`. + """ return self.__previous_description @property def new_description(self) -> "ServerDescription": """The new - :class:`~pymongo.server_description.ServerDescription`.""" + :class:`~pymongo.server_description.ServerDescription`. + """ return self.__new_description def __repr__(self): - return "<%s %s changed from: %s, to: %s>" % ( + return "<{} {} changed from: {}, to: {}>".format( self.__class__.__name__, self.server_address, self.previous_description, @@ -1175,7 +1177,7 @@ class ServerClosedEvent(_ServerEvent): __slots__ = () -class TopologyEvent(object): +class TopologyEvent: """Base class for topology description events.""" __slots__ = "__topology_id" @@ -1189,7 +1191,7 @@ def topology_id(self) -> ObjectId: return self.__topology_id def __repr__(self): - return "<%s topology_id: %s>" % (self.__class__.__name__, self.topology_id) + return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" class TopologyDescriptionChangedEvent(TopologyEvent): @@ -1204,26 +1206,28 @@ def __init__( self, previous_description: "TopologyDescription", new_description: "TopologyDescription", - *args: Any + *args: Any, ) -> None: - super(TopologyDescriptionChangedEvent, self).__init__(*args) + super().__init__(*args) self.__previous_description = previous_description self.__new_description = new_description @property def previous_description(self) -> "TopologyDescription": """The previous - :class:`~pymongo.topology_description.TopologyDescription`.""" + :class:`~pymongo.topology_description.TopologyDescription`. + """ return self.__previous_description @property def new_description(self) -> "TopologyDescription": """The new - :class:`~pymongo.topology_description.TopologyDescription`.""" + :class:`~pymongo.topology_description.TopologyDescription`. + """ return self.__new_description def __repr__(self): - return "<%s topology_id: %s changed from: %s, to: %s>" % ( + return "<{} topology_id: {} changed from: {}, to: {}>".format( self.__class__.__name__, self.topology_id, self.previous_description, @@ -1249,7 +1253,7 @@ class TopologyClosedEvent(TopologyEvent): __slots__ = () -class _ServerHeartbeatEvent(object): +class _ServerHeartbeatEvent: """Base class for server heartbeat events.""" __slots__ = "__connection_id" @@ -1260,11 +1264,12 @@ def __init__(self, connection_id: _Address) -> None: @property def connection_id(self) -> _Address: """The address (host, port) of the server this heartbeat was sent - to.""" + to. + """ return self.__connection_id def __repr__(self): - return "<%s %s>" % (self.__class__.__name__, self.connection_id) + return f"<{self.__class__.__name__} {self.connection_id}>" class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): @@ -1287,7 +1292,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): def __init__( self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False ) -> None: - super(ServerHeartbeatSucceededEvent, self).__init__(connection_id) + super().__init__(connection_id) self.__duration = duration self.__reply = reply self.__awaited = awaited @@ -1313,7 +1318,7 @@ def awaited(self) -> bool: return self.__awaited def __repr__(self): - return "<%s %s duration: %s, awaited: %s, reply: %s>" % ( + return "<{} {} duration: {}, awaited: {}, reply: {}>".format( self.__class__.__name__, self.connection_id, self.duration, @@ -1334,7 +1339,7 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): def __init__( self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False ) -> None: - super(ServerHeartbeatFailedEvent, self).__init__(connection_id) + super().__init__(connection_id) self.__duration = duration self.__reply = reply self.__awaited = awaited @@ -1360,7 +1365,7 @@ def awaited(self) -> bool: return self.__awaited def __repr__(self): - return "<%s %s duration: %s, awaited: %s, reply: %r>" % ( + return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( self.__class__.__name__, self.connection_id, self.duration, @@ -1369,7 +1374,7 @@ def __repr__(self): ) -class _EventListeners(object): +class _EventListeners: """Configure event listeners for a client instance. Any event listeners registered globally are included by default. diff --git a/pymongo/network.py b/pymongo/network.py index a5c5459e14..d105c8b8b5 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -219,15 +219,15 @@ def receive_message(sock_info, request_id, max_message_size=MAX_MESSAGE_SIZE): # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: - raise ProtocolError("Got response id %r but expected %r" % (response_to, request_id)) + raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}") if length <= 16: raise ProtocolError( - "Message length (%r) not longer than standard message header size (16)" % (length,) + f"Message length ({length!r}) not longer than standard message header size (16)" ) if length > max_message_size: raise ProtocolError( - "Message length (%r) is larger than server max " - "message size (%r)" % (length, max_message_size) + "Message length ({!r}) is larger than server max " + "message size ({!r})".format(length, max_message_size) ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( @@ -240,7 +240,7 @@ def receive_message(sock_info, request_id, max_message_size=MAX_MESSAGE_SIZE): try: unpack_reply = _UNPACK_REPLY[op_code] except KeyError: - raise ProtocolError("Got opcode %r but expected %r" % (op_code, _UNPACK_REPLY.keys())) + raise ProtocolError(f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}") return unpack_reply(data) @@ -281,7 +281,7 @@ def wait_for_read(sock_info, deadline): # Errors raised by sockets (and TLS sockets) when in non-blocking mode. -BLOCKING_IO_ERRORS = (BlockingIOError,) + ssl_support.BLOCKING_IO_ERRORS +BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS) def _receive_data_on_socket(sock_info, length, deadline): @@ -299,7 +299,7 @@ def _receive_data_on_socket(sock_info, length, deadline): chunk_length = sock_info.sock.recv_into(mv[bytes_read:]) except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") - except (IOError, OSError) as exc: # noqa: B014 + except OSError as exc: # noqa: B014 if _errno_from_exception(exc) == errno.EINTR: continue raise diff --git a/pymongo/ocsp_cache.py b/pymongo/ocsp_cache.py index 389ee09ce7..0c50902167 100644 --- a/pymongo/ocsp_cache.py +++ b/pymongo/ocsp_cache.py @@ -20,7 +20,7 @@ from pymongo.lock import _create_lock -class _OCSPCache(object): +class _OCSPCache: """A cache for OCSP responses.""" CACHE_KEY_TYPE = namedtuple( # type: ignore diff --git a/pymongo/operations.py b/pymongo/operations.py index ad119f2ecc..3ff4ed57a3 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -48,7 +48,7 @@ def _add_to_bulk(self, bulkobj): bulkobj.add_insert(self._doc) def __repr__(self): - return "InsertOne(%r)" % (self._doc,) + return f"InsertOne({self._doc!r})" def __eq__(self, other: Any) -> bool: if type(other) == type(self): @@ -59,7 +59,7 @@ def __ne__(self, other: Any) -> bool: return not self == other -class DeleteOne(object): +class DeleteOne: """Represents a delete_one operation.""" __slots__ = ("_filter", "_collation", "_hint") @@ -104,7 +104,7 @@ def _add_to_bulk(self, bulkobj): bulkobj.add_delete(self._filter, 1, collation=self._collation, hint=self._hint) def __repr__(self): - return "DeleteOne(%r, %r)" % (self._filter, self._collation) + return f"DeleteOne({self._filter!r}, {self._collation!r})" def __eq__(self, other: Any) -> bool: if type(other) == type(self): @@ -115,7 +115,7 @@ def __ne__(self, other: Any) -> bool: return not self == other -class DeleteMany(object): +class DeleteMany: """Represents a delete_many operation.""" __slots__ = ("_filter", "_collation", "_hint") @@ -160,7 +160,7 @@ def _add_to_bulk(self, bulkobj): bulkobj.add_delete(self._filter, 0, collation=self._collation, hint=self._hint) def __repr__(self): - return "DeleteMany(%r, %r)" % (self._filter, self._collation) + return f"DeleteMany({self._filter!r}, {self._collation!r})" def __eq__(self, other: Any) -> bool: if type(other) == type(self): @@ -242,7 +242,7 @@ def __ne__(self, other: Any) -> bool: return not self == other def __repr__(self): - return "%s(%r, %r, %r, %r, %r)" % ( + return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( self.__class__.__name__, self._filter, self._doc, @@ -252,7 +252,7 @@ def __repr__(self): ) -class _UpdateOp(object): +class _UpdateOp: """Private base class for update operations.""" __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") @@ -298,7 +298,7 @@ def __ne__(self, other): return not self == other def __repr__(self): - return "%s(%r, %r, %r, %r, %r, %r)" % ( + return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( self.__class__.__name__, self._filter, self._doc, @@ -352,7 +352,7 @@ def __init__( .. versionchanged:: 3.5 Added the `collation` option. """ - super(UpdateOne, self).__init__(filter, update, upsert, collation, array_filters, hint) + super().__init__(filter, update, upsert, collation, array_filters, hint) def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" @@ -410,7 +410,7 @@ def __init__( .. versionchanged:: 3.5 Added the `collation` option. """ - super(UpdateMany, self).__init__(filter, update, upsert, collation, array_filters, hint) + super().__init__(filter, update, upsert, collation, array_filters, hint) def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" @@ -425,7 +425,7 @@ def _add_to_bulk(self, bulkobj): ) -class IndexModel(object): +class IndexModel: """Represents an index to create.""" __slots__ = ("__document",) diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 95e7830674..24090e0160 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -22,7 +22,7 @@ from pymongo.lock import _create_lock -class PeriodicExecutor(object): +class PeriodicExecutor: def __init__(self, interval, min_interval, target, name=None): """ "Run a target function periodically on a background thread. @@ -51,7 +51,7 @@ def __init__(self, interval, min_interval, target, name=None): self._lock = _create_lock() def __repr__(self): - return "<%s(name=%s) object at 0x%x>" % (self.__class__.__name__, self._name, id(self)) + return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" def open(self) -> None: """Start. Multiple calls have no effect. diff --git a/pymongo/pool.py b/pymongo/pool.py index 6ba1554231..5bae8ce878 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -81,7 +81,6 @@ def _set_non_inheritable_non_atomic(fd): # everything we need from fcntl, etc. def _set_non_inheritable_non_atomic(fd): """Dummy function for platforms that don't provide fcntl.""" - pass _MAX_TCP_KEEPIDLE = 120 @@ -134,7 +133,7 @@ def _set_tcp_option(sock, tcp_option, max_value): default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) if default > max_value: sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) - except socket.error: + except OSError: pass def _set_keepalive_times(sock): @@ -351,7 +350,7 @@ def _raise_connection_failure( if port is not None: msg = "%s:%d: %s" % (host, port, error) else: - msg = "%s: %s" % (host, error) + msg = f"{host}: {error}" if msg_prefix: msg = msg_prefix + msg if isinstance(error, socket.timeout): @@ -371,7 +370,7 @@ def _cond_wait(condition, deadline): return condition.wait(timeout) -class PoolOptions(object): +class PoolOptions: """Read only connection pool options for a MongoClient. Should not be instantiated directly by application developers. Access @@ -456,17 +455,17 @@ def __init__( # } if driver: if driver.name: - self.__metadata["driver"]["name"] = "%s|%s" % ( + self.__metadata["driver"]["name"] = "{}|{}".format( _METADATA["driver"]["name"], driver.name, ) if driver.version: - self.__metadata["driver"]["version"] = "%s|%s" % ( + self.__metadata["driver"]["version"] = "{}|{}".format( _METADATA["driver"]["version"], driver.version, ) if driver.platform: - self.__metadata["platform"] = "%s|%s" % (_METADATA["platform"], driver.platform) + self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) env = _metadata_env() if env: @@ -601,7 +600,7 @@ def load_balanced(self): return self.__load_balanced -class _CancellationContext(object): +class _CancellationContext: def __init__(self): self._cancelled = False @@ -615,7 +614,7 @@ def cancelled(self): return self._cancelled -class SocketInfo(object): +class SocketInfo: """Store a socket with some metadata. :Parameters: @@ -1080,7 +1079,7 @@ def __hash__(self): return hash(self.sock) def __repr__(self): - return "SocketInfo(%s)%s at %s" % ( + return "SocketInfo({}){} at {}".format( repr(self.sock), self.closed and " CLOSED" or "", id(self), @@ -1106,7 +1105,7 @@ def _create_connection(address, options): try: sock.connect(host) return sock - except socket.error: + except OSError: sock.close() raise @@ -1125,7 +1124,7 @@ def _create_connection(address, options): # all file descriptors are created non-inheritable. See PEP 446. try: sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) - except socket.error: + except OSError: # Can SOCK_CLOEXEC be defined even if the kernel doesn't support # it? sock = socket.socket(af, socktype, proto) @@ -1144,7 +1143,7 @@ def _create_connection(address, options): _set_keepalive_times(sock) sock.connect(sa) return sock - except socket.error as e: + except OSError as e: err = e sock.close() @@ -1155,7 +1154,7 @@ def _create_connection(address, options): # host with an OS/kernel or Python interpreter that doesn't # support IPv6. The test case is Jython2.5.1 which doesn't # support IPv6 at all. - raise socket.error("getaddrinfo failed") + raise OSError("getaddrinfo failed") def _configured_socket(address, options): @@ -1182,7 +1181,7 @@ def _configured_socket(address, options): # Raise _CertificateError directly like we do after match_hostname # below. raise - except (IOError, OSError, SSLError) as exc: # noqa: B014 + except (OSError, SSLError) as exc: # noqa: B014 sock.close() # We raise AutoReconnect for transient and permanent SSL handshake # failures alike. Permanent handshake failures, like protocol @@ -1208,10 +1207,8 @@ class _PoolClosedError(PyMongoError): closed pool. """ - pass - -class _PoolGeneration(object): +class _PoolGeneration: def __init__(self): # Maps service_id to generation. self._generations = collections.defaultdict(int) @@ -1242,7 +1239,7 @@ def stale(self, gen, service_id): return gen != self.get(service_id) -class PoolState(object): +class PoolState: PAUSED = 1 READY = 2 CLOSED = 3 @@ -1753,10 +1750,9 @@ def _raise_wait_queue_timeout(self) -> NoReturn: other_ops = self.active_sockets - self.ncursors - self.ntxns raise WaitQueueTimeoutError( "Timeout waiting for connection from the connection pool. " - "maxPoolSize: %s, connections in use by cursors: %s, " - "connections in use by transactions: %s, connections in use " - "by other operations: %s, timeout: %s" - % ( + "maxPoolSize: {}, connections in use by cursors: {}, " + "connections in use by transactions: {}, connections in use " + "by other operations: {}, timeout: {}".format( self.opts.max_pool_size, self.ncursors, self.ntxns, @@ -1766,7 +1762,7 @@ def _raise_wait_queue_timeout(self) -> NoReturn: ) raise WaitQueueTimeoutError( "Timed out while checking out a connection from connection pool. " - "maxPoolSize: %s, timeout: %s" % (self.opts.max_pool_size, timeout) + "maxPoolSize: {}, timeout: {}".format(self.opts.max_pool_size, timeout) ) def __del__(self): diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 2d9c904bb3..bfc52df671 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -67,7 +67,7 @@ _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT, } -_REVERSE_VERIFY_MAP = dict((value, key) for key, value in _VERIFY_MAP.items()) +_REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()} # For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are @@ -97,7 +97,7 @@ class _sslConn(_SSL.Connection): def __init__(self, ctx, sock, suppress_ragged_eofs): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs - super(_sslConn, self).__init__(ctx, sock) + super().__init__(ctx, sock) def _call(self, call, *args, **kwargs): timeout = self.gettimeout() @@ -122,11 +122,11 @@ def _call(self, call, *args, **kwargs): continue def do_handshake(self, *args, **kwargs): - return self._call(super(_sslConn, self).do_handshake, *args, **kwargs) + return self._call(super().do_handshake, *args, **kwargs) def recv(self, *args, **kwargs): try: - return self._call(super(_sslConn, self).recv, *args, **kwargs) + return self._call(super().recv, *args, **kwargs) except _SSL.SysCallError as exc: # Suppress ragged EOFs to match the stdlib. if self.suppress_ragged_eofs and _ragged_eof(exc): @@ -135,7 +135,7 @@ def recv(self, *args, **kwargs): def recv_into(self, *args, **kwargs): try: - return self._call(super(_sslConn, self).recv_into, *args, **kwargs) + return self._call(super().recv_into, *args, **kwargs) except _SSL.SysCallError as exc: # Suppress ragged EOFs to match the stdlib. if self.suppress_ragged_eofs and _ragged_eof(exc): @@ -148,11 +148,11 @@ def sendall(self, buf, flags=0): total_sent = 0 while total_sent < total_length: try: - sent = self._call(super(_sslConn, self).send, view[total_sent:], flags) + sent = self._call(super().send, view[total_sent:], flags) # XXX: It's not clear if this can actually happen. PyOpenSSL # doesn't appear to have any interrupt handling, nor any interrupt # errors for OpenSSL connections. - except (IOError, OSError) as exc: # noqa: B014 + except OSError as exc: # noqa: B014 if _errno_from_exception(exc) == _EINTR: continue raise @@ -163,7 +163,7 @@ def sendall(self, buf, flags=0): total_sent += sent -class _CallbackData(object): +class _CallbackData: """Data class which is passed to the OCSP callback.""" def __init__(self): @@ -172,7 +172,7 @@ def __init__(self): self.ocsp_response_cache = _OCSPCache() -class SSLContext(object): +class SSLContext: """A CPython compatible SSLContext implementation wrapping PyOpenSSL's context. """ @@ -328,7 +328,8 @@ def load_default_certs(self): def set_default_verify_paths(self): """Specify that the platform provided CA certificates are to be used - for verification purposes.""" + for verification purposes. + """ # Note: See PyOpenSSL's docs for limitations, which are similar # but not that same as CPython's. self._ctx.set_default_verify_paths() diff --git a/pymongo/read_concern.py b/pymongo/read_concern.py index dfb3930ab0..c673c44780 100644 --- a/pymongo/read_concern.py +++ b/pymongo/read_concern.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Optional -class ReadConcern(object): +class ReadConcern: """ReadConcern :Parameters: @@ -45,7 +45,8 @@ def level(self) -> Optional[str]: @property def ok_for_legacy(self) -> bool: """Return ``True`` if this read concern is compatible with - old wire protocol versions.""" + old wire protocol versions. + """ return self.level is None or self.level == "local" @property diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index 46f029ed31..f3aa003a1c 100644 --- a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -46,18 +46,18 @@ def _validate_tag_sets(tag_sets): return tag_sets if not isinstance(tag_sets, (list, tuple)): - raise TypeError(("Tag sets %r invalid, must be a sequence") % (tag_sets,)) + raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence") if len(tag_sets) == 0: raise ValueError( - ("Tag sets %r invalid, must be None or contain at least one set of tags") % (tag_sets,) + f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags" ) for tags in tag_sets: if not isinstance(tags, abc.Mapping): raise TypeError( - "Tag set %r invalid, must be an instance of dict, " + "Tag set {!r} invalid, must be an instance of dict, " "bson.son.SON or other type that inherits from " - "collection.Mapping" % (tags,) + "collection.Mapping".format(tags) ) return list(tag_sets) @@ -88,7 +88,7 @@ def _validate_hedge(hedge): return None if not isinstance(hedge, dict): - raise TypeError("hedge must be a dictionary, not %r" % (hedge,)) + raise TypeError(f"hedge must be a dictionary, not {hedge!r}") return hedge @@ -97,7 +97,7 @@ def _validate_hedge(hedge): _TagSets = Sequence[Mapping[str, Any]] -class _ServerMode(object): +class _ServerMode: """Base class for all read preferences.""" __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") @@ -168,7 +168,8 @@ def tag_sets(self) -> _TagSets: def max_staleness(self) -> int: """The maximum estimated length of time (in seconds) a replica set secondary can fall behind the primary in replication before it will - no longer be selected for operations, or -1 for no maximum.""" + no longer be selected for operations, or -1 for no maximum. + """ return self.__max_staleness @property @@ -209,7 +210,7 @@ def min_wire_version(self) -> int: return 0 if self.__max_staleness == -1 else 5 def __repr__(self): - return "%s(tag_sets=%r, max_staleness=%r, hedge=%r)" % ( + return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format( self.name, self.__tag_sets, self.__max_staleness, @@ -263,7 +264,7 @@ class Primary(_ServerMode): __slots__ = () def __init__(self) -> None: - super(Primary, self).__init__(_PRIMARY) + super().__init__(_PRIMARY) def __call__(self, selection: Any) -> Any: """Apply this read preference to a Selection.""" @@ -314,7 +315,7 @@ def __init__( max_staleness: int = -1, hedge: Optional[_Hedge] = None, ) -> None: - super(PrimaryPreferred, self).__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) + super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" @@ -357,7 +358,7 @@ def __init__( max_staleness: int = -1, hedge: Optional[_Hedge] = None, ) -> None: - super(Secondary, self).__init__(_SECONDARY, tag_sets, max_staleness, hedge) + super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" @@ -401,9 +402,7 @@ def __init__( max_staleness: int = -1, hedge: Optional[_Hedge] = None, ) -> None: - super(SecondaryPreferred, self).__init__( - _SECONDARY_PREFERRED, tag_sets, max_staleness, hedge - ) + super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" @@ -448,7 +447,7 @@ def __init__( max_staleness: int = -1, hedge: Optional[_Hedge] = None, ) -> None: - super(Nearest, self).__init__(_NEAREST, tag_sets, max_staleness, hedge) + super().__init__(_NEAREST, tag_sets, max_staleness, hedge) def __call__(self, selection: Any) -> Any: """Apply this read preference to Selection.""" @@ -490,7 +489,7 @@ def __call__(self, selection): return self.effective_pref(selection) def __repr__(self): - return "_AggWritePref(pref=%r)" % (self.pref,) + return f"_AggWritePref(pref={self.pref!r})" # Proxy other calls to the effective_pref so that _AggWritePref can be # used in place of an actual read preference. @@ -524,7 +523,7 @@ def make_read_preference( ) -class ReadPreference(object): +class ReadPreference: """An enum that defines some commonly used read preference modes. Apps can also create a custom read preference, for example:: @@ -591,7 +590,7 @@ def read_pref_mode_from_name(name: str) -> int: return _MONGOS_MODES.index(name) -class MovingAverage(object): +class MovingAverage: """Tracks an exponentially-weighted moving average.""" average: Optional[float] diff --git a/pymongo/response.py b/pymongo/response.py index 1369eac4e0..fc01b0f1bf 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -15,7 +15,7 @@ """Represent a response from the server.""" -class Response(object): +class Response: __slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs") def __init__(self, data, address, request_id, duration, from_command, docs): @@ -86,9 +86,7 @@ def __init__( - `more_to_come`: Bool indicating whether cursor is ready to be exhausted. """ - super(PinnedResponse, self).__init__( - data, address, request_id, duration, from_command, docs - ) + super().__init__(data, address, request_id, duration, from_command, docs) self._socket_info = socket_info self._more_to_come = more_to_come @@ -105,5 +103,6 @@ def socket_info(self): @property def more_to_come(self): """If true, server is ready to send batches on the socket until the - result set is exhausted or there is an error.""" + result set is exhausted or there is an error. + """ return self._more_to_come diff --git a/pymongo/results.py b/pymongo/results.py index b072979499..3bd9e82069 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -18,7 +18,7 @@ from pymongo.errors import InvalidOperation -class _WriteResult(object): +class _WriteResult: """Base class for write result classes.""" __slots__ = ("__acknowledged",) @@ -30,10 +30,10 @@ def _raise_if_unacknowledged(self, property_name): """Raise an exception on property access if unacknowledged.""" if not self.__acknowledged: raise InvalidOperation( - "A value for %s is not available when " + "A value for {} is not available when " "the write is unacknowledged. Check the " "acknowledged attribute to avoid this " - "error." % (property_name,) + "error.".format(property_name) ) @property @@ -63,7 +63,7 @@ class InsertOneResult(_WriteResult): def __init__(self, inserted_id: Any, acknowledged: bool) -> None: self.__inserted_id = inserted_id - super(InsertOneResult, self).__init__(acknowledged) + super().__init__(acknowledged) @property def inserted_id(self) -> Any: @@ -78,7 +78,7 @@ class InsertManyResult(_WriteResult): def __init__(self, inserted_ids: List[Any], acknowledged: bool) -> None: self.__inserted_ids = inserted_ids - super(InsertManyResult, self).__init__(acknowledged) + super().__init__(acknowledged) @property def inserted_ids(self) -> List: @@ -102,7 +102,7 @@ class UpdateResult(_WriteResult): def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None: self.__raw_result = raw_result - super(UpdateResult, self).__init__(acknowledged) + super().__init__(acknowledged) @property def raw_result(self) -> Dict[str, Any]: @@ -134,13 +134,14 @@ def upserted_id(self) -> Any: class DeleteResult(_WriteResult): """The return type for :meth:`~pymongo.collection.Collection.delete_one` - and :meth:`~pymongo.collection.Collection.delete_many`""" + and :meth:`~pymongo.collection.Collection.delete_many` + """ __slots__ = ("__raw_result",) def __init__(self, raw_result: Dict[str, Any], acknowledged: bool) -> None: self.__raw_result = raw_result - super(DeleteResult, self).__init__(acknowledged) + super().__init__(acknowledged) @property def raw_result(self) -> Dict[str, Any]: @@ -169,7 +170,7 @@ def __init__(self, bulk_api_result: Dict[str, Any], acknowledged: bool) -> None: :exc:`~pymongo.errors.InvalidOperation`. """ self.__bulk_api_result = bulk_api_result - super(BulkWriteResult, self).__init__(acknowledged) + super().__init__(acknowledged) @property def bulk_api_result(self) -> Dict[str, Any]: @@ -211,7 +212,5 @@ def upserted_ids(self) -> Optional[Dict[int, Any]]: """A map of operation index to the _id of the upserted document.""" self._raise_if_unacknowledged("upserted_ids") if self.__bulk_api_result: - return dict( - (upsert["index"], upsert["_id"]) for upsert in self.bulk_api_result["upserted"] - ) + return {upsert["index"]: upsert["_id"] for upsert in self.bulk_api_result["upserted"]} return None diff --git a/pymongo/saslprep.py b/pymongo/saslprep.py index b96d6fcb56..34c0182a53 100644 --- a/pymongo/saslprep.py +++ b/pymongo/saslprep.py @@ -71,7 +71,7 @@ def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) return data if prohibit_unassigned_code_points: - prohibited = _PROHIBITED + (stringprep.in_table_a1,) + prohibited = (*_PROHIBITED, stringprep.in_table_a1) else: prohibited = _PROHIBITED @@ -98,12 +98,12 @@ def saslprep(data: Any, prohibit_unassigned_code_points: Optional[bool] = True) raise ValueError("SASLprep: failed bidirectional check") # RFC3454, Section 6, #2. If a string contains any RandALCat # character, it MUST NOT contain any LCat character. - prohibited = prohibited + (stringprep.in_table_d2,) + prohibited = (*prohibited, stringprep.in_table_d2) else: # RFC3454, Section 6, #3. Following the logic of #3, if # the first character is not a RandALCat, no other character # can be either. - prohibited = prohibited + (in_table_d1,) + prohibited = (*prohibited, in_table_d1) # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi for char in data: diff --git a/pymongo/server.py b/pymongo/server.py index 16c905abb7..2eb91c5b5d 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -25,7 +25,7 @@ _CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} -class Server(object): +class Server: def __init__( self, server_description, pool, monitor, topology_id=None, listeners=None, events=None ): @@ -245,4 +245,4 @@ def _split_message(self, message): return request_id, data, 0 def __repr__(self): - return "<%s %r>" % (self.__class__.__name__, self._description) + return f"<{self.__class__.__name__} {self._description!r}>" diff --git a/pymongo/server_api.py b/pymongo/server_api.py index e92d6e6179..2393615032 100644 --- a/pymongo/server_api.py +++ b/pymongo/server_api.py @@ -95,7 +95,7 @@ class ServerApiVersion: """Server API version "1".""" -class ServerApi(object): +class ServerApi: """MongoDB Stable API.""" def __init__(self, version, strict=None, deprecation_errors=None): @@ -113,16 +113,16 @@ def __init__(self, version, strict=None, deprecation_errors=None): .. versionadded:: 3.12 """ if version != ServerApiVersion.V1: - raise ValueError("Unknown ServerApi version: %s" % (version,)) + raise ValueError(f"Unknown ServerApi version: {version}") if strict is not None and not isinstance(strict, bool): raise TypeError( "Wrong type for ServerApi strict, value must be an instance " - "of bool, not %s" % (type(strict),) + "of bool, not {}".format(type(strict)) ) if deprecation_errors is not None and not isinstance(deprecation_errors, bool): raise TypeError( "Wrong type for ServerApi deprecation_errors, value must be " - "an instance of bool, not %s" % (type(deprecation_errors),) + "an instance of bool, not {}".format(type(deprecation_errors)) ) self._version = version self._strict = strict diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 46517ee95e..4bca3390ae 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -25,7 +25,7 @@ from pymongo.typings import _Address -class ServerDescription(object): +class ServerDescription: """Immutable representation of one server. :Parameters: @@ -287,8 +287,8 @@ def __ne__(self, other: Any) -> bool: def __repr__(self): errmsg = "" if self.error: - errmsg = ", error=%r" % (self.error,) - return "<%s %s server_type: %s, rtt: %s%s>" % ( + errmsg = f", error={self.error!r}" + return "<{} {} server_type: {}, rtt: {}{}>".format( self.__class__.__name__, self.address, self.server_type_name, diff --git a/pymongo/server_selectors.py b/pymongo/server_selectors.py index 313566cb83..aa9d26b5fb 100644 --- a/pymongo/server_selectors.py +++ b/pymongo/server_selectors.py @@ -17,7 +17,7 @@ from pymongo.server_type import SERVER_TYPE -class Selection(object): +class Selection: """Input or output of a server selector function.""" @classmethod @@ -51,6 +51,7 @@ def secondary_with_max_last_write_date(self): secondaries = secondary_server_selector(self) if secondaries.server_descriptions: return max(secondaries.server_descriptions, key=lambda sd: sd.last_write_date) + return None @property def primary_selection(self): diff --git a/pymongo/settings.py b/pymongo/settings.py index 2bd2527cdf..5d6ddefd36 100644 --- a/pymongo/settings.py +++ b/pymongo/settings.py @@ -26,7 +26,7 @@ from pymongo.topology_description import TOPOLOGY_TYPE -class TopologySettings(object): +class TopologySettings: def __init__( self, seeds=None, @@ -156,4 +156,4 @@ def get_topology_type(self): def get_server_descriptions(self): """Initial dict of (address, ServerDescription) for all seeds.""" - return dict([(address, ServerDescription(address)) for address in self.seeds]) + return {address: ServerDescription(address) for address in self.seeds} diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index 420953db2e..a278898952 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -33,7 +33,7 @@ def _errno_from_exception(exc): return None -class SocketChecker(object): +class SocketChecker: def __init__(self) -> None: self._poller: Optional[select.poll] if _HAVE_POLL: @@ -78,7 +78,7 @@ def select( # ready: subsets of the first three arguments. Return # True if any of the lists are not empty. return any(res) - except (_SelectError, IOError) as exc: # type: ignore + except (_SelectError, OSError) as exc: # type: ignore if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN): continue raise diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index fe2dd49aa0..583de818b0 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -51,7 +51,7 @@ def _resolve(*args, **kwargs): ) -class _SrvResolver(object): +class _SrvResolver: def __init__(self, fqdn, connect_timeout, srv_service_name, srv_max_hosts=0): self.__fqdn = fqdn self.__srv = srv_service_name @@ -110,9 +110,9 @@ def _get_srv_response_and_hosts(self, encapsulate_errors): try: nlist = node[0].split(".")[1:][-self.__slen :] except Exception: - raise ConfigurationError("Invalid SRV host: %s" % (node[0],)) + raise ConfigurationError(f"Invalid SRV host: {node[0]}") if self.__plist != nlist: - raise ConfigurationError("Invalid SRV host: %s" % (node[0],)) + raise ConfigurationError(f"Invalid SRV host: {node[0]}") if self.__srv_max_hosts: nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) return results, nodes diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 13c5315eee..3af535ee4b 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -71,7 +71,7 @@ def get_ssl_context( try: ctx.load_cert_chain(certfile, None, passphrase) except _ssl.SSLError as exc: - raise ConfigurationError("Private key doesn't match certificate: %s" % (exc,)) + raise ConfigurationError(f"Private key doesn't match certificate: {exc}") if crlfile is not None: if _ssl.IS_PYOPENSSL: raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL") diff --git a/pymongo/topology.py b/pymongo/topology.py index 904f6b1836..9759b39f9f 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -75,7 +75,7 @@ def process_events_queue(queue_ref): return True # Continue PeriodicExecutor. -class Topology(object): +class Topology: """Monitor a topology of one or more servers.""" def __init__(self, topology_settings): @@ -236,8 +236,7 @@ def _select_servers_loop(self, selector, timeout, address): # No suitable servers. if timeout == 0 or now > end_time: raise ServerSelectionTimeoutError( - "%s, Timeout: %ss, Topology Description: %r" - % (self._error_message(selector), timeout, self.description) + f"{self._error_message(selector)}, Timeout: {timeout}s, Topology Description: {self.description!r}" ) self._ensure_opened() @@ -431,7 +430,7 @@ def _get_replica_set_members(self, selector): ): return set() - return set([sd.address for sd in selector(self._new_selection())]) + return {sd.address for sd in selector(self._new_selection())} def get_secondaries(self): """Return set of secondary addresses.""" @@ -499,7 +498,8 @@ def update_pool(self): def close(self): """Clear pools and terminate monitors. Topology does not reopen on demand. Any further operations will raise - :exc:`~.errors.InvalidOperation`.""" + :exc:`~.errors.InvalidOperation`. + """ with self._lock: for server in self._servers.values(): server.close() @@ -807,14 +807,14 @@ def _error_message(self, selector): else: return "No %s available for writes" % server_plural else: - return 'No %s match selector "%s"' % (server_plural, selector) + return f'No {server_plural} match selector "{selector}"' else: addresses = list(self._description.server_descriptions()) servers = list(self._description.server_descriptions().values()) if not servers: if is_replica_set: # We removed all servers because of the wrong setName? - return 'No %s available for replica set name "%s"' % ( + return 'No {} available for replica set name "{}"'.format( server_plural, self._settings.replica_set_name, ) @@ -844,7 +844,7 @@ def __repr__(self): msg = "" if not self._opened: msg = "CLOSED " - return "<%s %s%r>" % (self.__class__.__name__, msg, self._description) + return f"<{self.__class__.__name__} {msg}{self._description!r}>" def eq_props(self): """The properties to use for MongoClient/Topology equality checks.""" @@ -860,7 +860,7 @@ def __hash__(self): return hash(self.eq_props()) -class _ErrorContext(object): +class _ErrorContext: """An error with context for SDAM error handling.""" def __init__(self, error, max_wire_version, sock_generation, completed_handshake, service_id): diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index 7503a72704..7079b324b2 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -47,7 +47,7 @@ class _TopologyType(NamedTuple): _ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] -class TopologyDescription(object): +class TopologyDescription: def __init__( self, topology_type: int, @@ -171,7 +171,7 @@ def reset(self) -> "TopologyDescription": topology_type = self._topology_type # The default ServerDescription's type is Unknown. - sds = dict((address, ServerDescription(address)) for address in self._server_descriptions) + sds = {address: ServerDescription(address) for address in self._server_descriptions} return TopologyDescription( topology_type, @@ -184,7 +184,8 @@ def reset(self) -> "TopologyDescription": def server_descriptions(self) -> Dict[_Address, ServerDescription]: """Dict of (address, - :class:`~pymongo.server_description.ServerDescription`).""" + :class:`~pymongo.server_description.ServerDescription`). + """ return self._server_descriptions.copy() @property @@ -346,7 +347,7 @@ def has_writable_server(self) -> bool: def __repr__(self): # Sort the servers by address. servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) - return "<%s id: %s, topology_type: %s, servers: %r>" % ( + return "<{} id: {}, topology_type: {}, servers: {!r}>".format( self.__class__.__name__, self._topology_settings._topology_id, self.topology_type_name, @@ -400,8 +401,9 @@ def updated_topology_description( if set_name is not None and set_name != server_description.replica_set_name: error = ConfigurationError( "client is configured to connect to a replica set named " - "'%s' but this node belongs to a set named '%s'" - % (set_name, server_description.replica_set_name) + "'{}' but this node belongs to a set named '{}'".format( + set_name, server_description.replica_set_name + ) ) sds[address] = server_description.to_unknown(error=error) # Single type never changes. diff --git a/pymongo/typings.py b/pymongo/typings.py index 32cd980c97..ef82114f15 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -29,7 +29,8 @@ def strip_optional(elem): """This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T - while inside a list comprehension.""" + while inside a list comprehension. + """ assert elem is not None return elem diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index e3aeee399e..0772b39c80 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -134,7 +134,7 @@ def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Addr host, port = host.split(":", 1) if isinstance(port, str): if not port.isdigit() or int(port) > 65535 or int(port) <= 0: - raise ValueError("Port must be an integer between 0 and 65535: %r" % (port,)) + raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}") port = int(port) # Normalize hostname to lowercase, since DNS is case-insensitive: @@ -155,7 +155,8 @@ def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Addr def _parse_options(opts, delim): """Helper method for split_options which creates the options dict. Also handles the creation of a list for the URI tag_sets/ - readpreferencetags portion, and the use of a unicode options string.""" + readpreferencetags portion, and the use of a unicode options string. + """ options = _CaseInsensitiveDictionary() for uriopt in opts.split(delim): key, value = uriopt.split("=") @@ -163,7 +164,7 @@ def _parse_options(opts, delim): options.setdefault(key, []).append(value) else: if key in options: - warnings.warn("Duplicate URI option '%s'." % (key,)) + warnings.warn(f"Duplicate URI option '{key}'.") if key.lower() == "authmechanismproperties": val = value else: @@ -475,9 +476,7 @@ def parse_uri( is_srv = True scheme_free = uri[SRV_SCHEME_LEN:] else: - raise InvalidURI( - "Invalid URI scheme: URI must begin with '%s' or '%s'" % (SCHEME, SRV_SCHEME) - ) + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") if not scheme_free: raise InvalidURI("Must provide at least one hostname or IP.") @@ -525,15 +524,13 @@ def parse_uri( srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") if is_srv: if options.get("directConnection"): - raise ConfigurationError( - "Cannot specify directConnection=true with %s URIs" % (SRV_SCHEME,) - ) + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") nodes = split_hosts(hosts, default_port=None) if len(nodes) != 1: - raise InvalidURI("%s URIs must include one, and only one, hostname" % (SRV_SCHEME,)) + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") fqdn, port = nodes[0] if port is not None: - raise InvalidURI("%s URIs must not include a port number" % (SRV_SCHEME,)) + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. diff --git a/pymongo/write_concern.py b/pymongo/write_concern.py index ced71d0488..25f87954b5 100644 --- a/pymongo/write_concern.py +++ b/pymongo/write_concern.py @@ -19,7 +19,7 @@ from pymongo.errors import ConfigurationError -class WriteConcern(object): +class WriteConcern: """WriteConcern :Parameters: @@ -113,7 +113,9 @@ def acknowledged(self) -> bool: return self.__acknowledged def __repr__(self): - return "WriteConcern(%s)" % (", ".join("%s=%s" % kvt for kvt in self.__document.items()),) + return "WriteConcern({})".format( + ", ".join("{}={}".format(*kvt) for kvt in self.__document.items()) + ) def __eq__(self, other: Any) -> bool: if isinstance(other, WriteConcern): diff --git a/test/__init__.py b/test/__init__.py index dc324c6911..c80b4e95c8 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test suite for pymongo, bson, and gridfs. -""" +"""Test suite for pymongo, bson, and gridfs.""" import base64 import gc @@ -92,7 +91,7 @@ CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem")) CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem")) -TLS_OPTIONS: Dict = dict(tls=True) +TLS_OPTIONS: Dict = {"tls": True} if CLIENT_PEM: TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM if CA_PEM: @@ -149,7 +148,7 @@ def is_server_resolvable(): try: socket.gethostbyname("server") return True - except socket.error: + except OSError: return False finally: socket.setdefaulttimeout(socket_timeout) @@ -165,7 +164,7 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs): return authdb.command(cmd) -class client_knobs(object): +class client_knobs: def __init__( self, heartbeat_frequency=None, @@ -234,10 +233,9 @@ def wrap(*args, **kwargs): def __del__(self): if self._enabled: msg = ( - "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, " - "MIN_HEARTBEAT_INTERVAL=%s, KILL_CURSOR_FREQUENCY=%s, " - "EVENTS_QUEUE_FREQUENCY=%s, stack:\n%s" - % ( + "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, " + "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, " + "EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format( common.HEARTBEAT_FREQUENCY, common.MIN_HEARTBEAT_INTERVAL, common.KILL_CURSOR_FREQUENCY, @@ -250,10 +248,10 @@ def __del__(self): def _all_users(db): - return set(u["user"] for u in db.command("usersInfo").get("users", [])) + return {u["user"] for u in db.command("usersInfo").get("users", [])} -class ClientContext(object): +class ClientContext: client: MongoClient MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI @@ -339,14 +337,14 @@ def _connect(self, host, port, **kwargs): except pymongo.errors.OperationFailure as exc: # SERVER-32063 self.connection_attempts.append( - "connected client %r, but legacy hello failed: %s" % (client, exc) + f"connected client {client!r}, but legacy hello failed: {exc}" ) else: - self.connection_attempts.append("successfully connected client %r" % (client,)) + self.connection_attempts.append(f"successfully connected client {client!r}") # If connected, then return client with default timeout return pymongo.MongoClient(host, port, **kwargs) except pymongo.errors.ConnectionFailure as exc: - self.connection_attempts.append("failed to connect client %r: %s" % (client, exc)) + self.connection_attempts.append(f"failed to connect client {client!r}: {exc}") return None finally: client.close() @@ -447,7 +445,7 @@ def _init_client(self): nodes.extend([partition_node(node.lower()) for node in hello.get("arbiters", [])]) self.nodes = set(nodes) else: - self.nodes = set([(host, port)]) + self.nodes = {(host, port)} self.w = len(hello.get("hosts", [])) or 1 self.version = Version.from_client(self.client) @@ -587,7 +585,7 @@ def _server_started_with_ipv6(self): for info in socket.getaddrinfo(self.host, self.port): if info[0] == socket.AF_INET6: return True - except socket.error: + except OSError: pass return False @@ -599,7 +597,7 @@ def wrap(*args, **kwargs): self.init() # Always raise SkipTest if we can't connect to MongoDB if not self.connected: - raise SkipTest("Cannot connect to MongoDB on %s" % (self.pair,)) + raise SkipTest(f"Cannot connect to MongoDB on {self.pair}") if condition(): return f(*args, **kwargs) raise SkipTest(msg) @@ -625,7 +623,7 @@ def require_connection(self, func): """Run a test only if we can connect to MongoDB.""" return self._require( lambda: True, # _require checks if we're connected - "Cannot connect to MongoDB on %s" % (self.pair,), + f"Cannot connect to MongoDB on {self.pair}", func=func, ) @@ -633,14 +631,15 @@ def require_data_lake(self, func): """Run a test only if we are connected to Atlas Data Lake.""" return self._require( lambda: self.is_data_lake, - "Not connected to Atlas Data Lake on %s" % (self.pair,), + f"Not connected to Atlas Data Lake on {self.pair}", func=func, ) def require_no_mmap(self, func): """Run a test only if the server is not using the MMAPv1 storage engine. Only works for standalone and replica sets; tests are - run regardless of storage engine on sharded clusters.""" + run regardless of storage engine on sharded clusters. + """ def is_not_mmap(): if self.is_mongos: @@ -734,7 +733,8 @@ def require_mongos(self, func): def require_multiple_mongoses(self, func): """Run a test only if the client is connected to a sharded cluster - that has 2 mongos nodes.""" + that has 2 mongos nodes. + """ return self._require( lambda: len(self.mongoses) > 1, "Must have multiple mongoses available", func=func ) @@ -786,7 +786,7 @@ def is_topology_type(self, topologies): "load-balanced", } if unknown: - raise AssertionError("Unknown topologies: %r" % (unknown,)) + raise AssertionError(f"Unknown topologies: {unknown!r}") if self.load_balancer: if "load-balanced" in topologies: return True @@ -812,7 +812,8 @@ def is_topology_type(self, topologies): def require_cluster_type(self, topologies=[]): # noqa """Run a test only if the client is connected to a cluster that conforms to one of the specified topologies. Acceptable topologies - are 'single', 'replicaset', and 'sharded'.""" + are 'single', 'replicaset', and 'sharded'. + """ def _is_valid_topology(): return self.is_topology_type(topologies) @@ -827,7 +828,8 @@ def require_test_commands(self, func): def require_failCommand_fail_point(self, func): """Run a test only if the server supports the failCommand fail - point.""" + point. + """ return self._require( lambda: self.supports_failCommand_fail_point, "failCommand fail point must be supported", @@ -930,7 +932,7 @@ def require_no_api_version(self, func): ) def mongos_seeds(self): - return ",".join("%s:%s" % address for address in self.mongoses) + return ",".join("{}:{}".format(*address) for address in self.mongoses) @property def supports_failCommand_fail_point(self): @@ -1139,7 +1141,7 @@ def setUpClass(cls): pass def setUp(self): - super(MockClientTest, self).setUp() + super().setUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) @@ -1147,7 +1149,7 @@ def setUp(self): def tearDown(self): self.client_knobs.disable() - super(MockClientTest, self).tearDown() + super().tearDown() # Global knobs to speed up the test suite. @@ -1181,9 +1183,9 @@ def print_running_topology(topology): if running: print( "WARNING: found Topology with running threads:\n" - " Threads: %s\n" - " Topology: %s\n" - " Creation traceback:\n%s" % (running, topology, topology._settings._stack) + " Threads: {}\n" + " Topology: {}\n" + " Creation traceback:\n{}".format(running, topology, topology._settings._stack) ) @@ -1215,11 +1217,11 @@ def teardown(): global_knobs.disable() garbage = [] for g in gc.garbage: - garbage.append("GARBAGE: %r" % (g,)) - garbage.append(" gc.get_referents: %r" % (gc.get_referents(g),)) - garbage.append(" gc.get_referrers: %r" % (gc.get_referrers(g),)) + garbage.append(f"GARBAGE: {g!r}") + garbage.append(f" gc.get_referents: {gc.get_referents(g)!r}") + garbage.append(f" gc.get_referrers: {gc.get_referrers(g)!r}") if garbage: - assert False, "\n".join(garbage) + raise AssertionError("\n".join(garbage)) c = client_context.client if c: if not client_context.is_data_lake: @@ -1237,7 +1239,7 @@ def teardown(): class PymongoTestRunner(unittest.TextTestRunner): def run(self, test): setup() - result = super(PymongoTestRunner, self).run(test) + result = super().run(test) teardown() return result @@ -1247,7 +1249,7 @@ def run(self, test): class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc] def run(self, test): setup() - result = super(PymongoXMLTestRunner, self).run(test) + result = super().run(test) teardown() return result @@ -1260,8 +1262,7 @@ def test_cases(suite): yield suite_or_case else: # unittest.TestSuite - for case in test_cases(suite_or_case): - yield case + yield from test_cases(suite_or_case) # Helper method to workaround https://bugs.python.org/issue21724 @@ -1272,7 +1273,7 @@ def clear_warning_registry(): setattr(module, "__warningregistry__", {}) # noqa -class SystemCertsPatcher(object): +class SystemCertsPatcher: def __init__(self, ca_certs): if ( ssl.OPENSSL_VERSION.lower().startswith("libressl") diff --git a/test/atlas/test_connection.py b/test/atlas/test_connection.py index 39d817140e..036e4772ff 100644 --- a/test/atlas/test_connection.py +++ b/test/atlas/test_connection.py @@ -102,7 +102,7 @@ def test_uniqueness(self): duplicates = [names for names in uri_to_names.values() if len(names) > 1] self.assertFalse( duplicates, - "Error: the following env variables have duplicate values: %s" % (duplicates,), + f"Error: the following env variables have duplicate values: {duplicates}", ) diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index e0329a783e..e180d8b064 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -39,7 +39,7 @@ def test_should_fail_without_credentials(self): if "@" not in self.uri: self.skipTest("MONGODB_URI already has no credentials") - hosts = ["%s:%s" % addr for addr in parse_uri(self.uri)["nodelist"]] + hosts = ["{}:{}".format(*addr) for addr in parse_uri(self.uri)["nodelist"]] self.assertTrue(hosts) with MongoClient(hosts) as client: with self.assertRaises(OperationFailure): @@ -115,7 +115,7 @@ def test_poisoned_cache(self): def test_environment_variables_ignored(self): creds = self.setup_cache() self.assertIsNotNone(creds) - prev = os.environ.copy() + os.environ.copy() client = MongoClient(self.uri) self.addCleanup(client.close) @@ -124,9 +124,11 @@ def test_environment_variables_ignored(self): self.assertIsNotNone(auth.get_cached_credentials()) - mock_env = dict( - AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar", AWS_SESSION_TOKEN="baz" - ) + mock_env = { + "AWS_ACCESS_KEY_ID": "foo", + "AWS_SECRET_ACCESS_KEY": "bar", + "AWS_SESSION_TOKEN": "baz", + } with patch.dict("os.environ", mock_env): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") @@ -147,7 +149,7 @@ def test_no_cache_environment_variables(self): self.assertIsNotNone(creds) auth.set_cached_credentials(None) - mock_env = dict(AWS_ACCESS_KEY_ID=creds.username, AWS_SECRET_ACCESS_KEY=creds.password) + mock_env = {"AWS_ACCESS_KEY_ID": creds.username, "AWS_SECRET_ACCESS_KEY": creds.password} if creds.token: mock_env["AWS_SESSION_TOKEN"] = creds.token diff --git a/test/auth_aws/test_auth_oidc.py b/test/auth_aws/test_auth_oidc.py index 470e4581c2..26e71573d4 100644 --- a/test/auth_aws/test_auth_oidc.py +++ b/test/auth_aws/test_auth_oidc.py @@ -65,7 +65,7 @@ def request_token(server_info, context): self.assertEqual(timeout_seconds, 60 * 5) with open(token_file) as fid: token = fid.read() - resp = dict(access_token=token) + resp = {"access_token": token} time.sleep(sleep) @@ -94,7 +94,7 @@ def refresh_token(server_info, context): # Validate the timeout. self.assertEqual(context["timeout_seconds"], 60 * 5) - resp = dict(access_token=token) + resp = {"access_token": token} if expires_in_seconds is not None: resp["expires_in_seconds"] = expires_in_seconds self.refresh_called += 1 @@ -115,21 +115,21 @@ def fail_point(self, command_args): def test_connect_callbacks_single_implicit_username(self): request_token = self.create_request_cb() - props: Dict = dict(request_token_callback=request_token) + props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_callbacks_single_explicit_username(self): request_token = self.create_request_cb() - props: Dict = dict(request_token_callback=request_token) + props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_single, username="test_user1", authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_callbacks_multiple_principal_user1(self): request_token = self.create_request_cb() - props: Dict = dict(request_token_callback=request_token) + props: Dict = {"request_token_callback": request_token} client = MongoClient( self.uri_multiple, username="test_user1", authmechanismproperties=props ) @@ -138,7 +138,7 @@ def test_connect_callbacks_multiple_principal_user1(self): def test_connect_callbacks_multiple_principal_user2(self): request_token = self.create_request_cb("test_user2") - props: Dict = dict(request_token_callback=request_token) + props: Dict = {"request_token_callback": request_token} client = MongoClient( self.uri_multiple, username="test_user2", authmechanismproperties=props ) @@ -147,7 +147,7 @@ def test_connect_callbacks_multiple_principal_user2(self): def test_connect_callbacks_multiple_no_username(self): request_token = self.create_request_cb() - props: Dict = dict(request_token_callback=request_token) + props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_multiple, authmechanismproperties=props) with self.assertRaises(OperationFailure): client.test.test.find_one() @@ -155,13 +155,13 @@ def test_connect_callbacks_multiple_no_username(self): def test_allowed_hosts_blocked(self): request_token = self.create_request_cb() - props: Dict = dict(request_token_callback=request_token, allowed_hosts=[]) + props: Dict = {"request_token_callback": request_token, "allowed_hosts": []} client = MongoClient(self.uri_single, authmechanismproperties=props) with self.assertRaises(ConfigurationError): client.test.test.find_one() client.close() - props: Dict = dict(request_token_callback=request_token, allowed_hosts=["example.com"]) + props: Dict = {"request_token_callback": request_token, "allowed_hosts": ["example.com"]} client = MongoClient( self.uri_single + "&ignored=example.com", authmechanismproperties=props, connect=False ) @@ -170,26 +170,26 @@ def test_allowed_hosts_blocked(self): client.close() def test_connect_aws_single_principal(self): - props = dict(PROVIDER_NAME="aws") + props = {"PROVIDER_NAME": "aws"} client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_aws_multiple_principal_user1(self): - props = dict(PROVIDER_NAME="aws") + props = {"PROVIDER_NAME": "aws"} client = MongoClient(self.uri_multiple, authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_aws_multiple_principal_user2(self): os.environ["AWS_WEB_IDENTITY_TOKEN_FILE"] = os.path.join(self.token_dir, "test_user2") - props = dict(PROVIDER_NAME="aws") + props = {"PROVIDER_NAME": "aws"} client = MongoClient(self.uri_multiple, authmechanismproperties=props) client.test.test.find_one() client.close() def test_connect_aws_allowed_hosts_ignored(self): - props = dict(PROVIDER_NAME="aws", allowed_hosts=[]) + props = {"PROVIDER_NAME": "aws", "allowed_hosts": []} client = MongoClient(self.uri_multiple, authmechanismproperties=props) client.test.test.find_one() client.close() @@ -198,10 +198,10 @@ def test_valid_callbacks(self): request_cb = self.create_request_cb(expires_in_seconds=60) refresh_cb = self.create_refresh_cb() - props: Dict = dict( - request_token_callback=request_cb, - refresh_token_callback=refresh_cb, - ) + props: Dict = { + "request_token_callback": request_cb, + "refresh_token_callback": refresh_cb, + } client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() @@ -214,7 +214,7 @@ def test_lock_avoids_extra_callbacks(self): request_cb = self.create_request_cb(sleep=0.5) refresh_cb = self.create_refresh_cb() - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} def run_test(): client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -239,7 +239,7 @@ def test_request_callback_returns_null(self): def request_token_null(a, b): return None - props: Dict = dict(request_token_callback=request_token_null) + props: Dict = {"request_token_callback": request_token_null} client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -251,9 +251,10 @@ def test_refresh_callback_returns_null(self): def refresh_token_null(a, b): return None - props: Dict = dict( - request_token_callback=request_cb, refresh_token_callback=refresh_token_null - ) + props: Dict = { + "request_token_callback": request_cb, + "refresh_token_callback": refresh_token_null, + } client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() @@ -265,9 +266,9 @@ def refresh_token_null(a, b): def test_request_callback_invalid_result(self): def request_token_invalid(a, b): - return dict() + return {} - props: Dict = dict(request_token_callback=request_token_invalid) + props: Dict = {"request_token_callback": request_token_invalid} client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -278,7 +279,7 @@ def request_cb_extra_value(server_info, context): result["foo"] = "bar" return result - props: Dict = dict(request_token_callback=request_cb_extra_value) + props: Dict = {"request_token_callback": request_cb_extra_value} client = MongoClient(self.uri_single, authMechanismProperties=props) with self.assertRaises(ValueError): client.test.test.find_one() @@ -288,11 +289,12 @@ def test_refresh_callback_missing_data(self): request_cb = self.create_request_cb(expires_in_seconds=60) def refresh_cb_no_token(a, b): - return dict() + return {} - props: Dict = dict( - request_token_callback=request_cb, refresh_token_callback=refresh_cb_no_token - ) + props: Dict = { + "request_token_callback": request_cb, + "refresh_token_callback": refresh_cb_no_token, + } client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() @@ -310,9 +312,10 @@ def refresh_cb_extra_value(server_info, context): result["foo"] = "bar" return result - props: Dict = dict( - request_token_callback=request_cb, refresh_token_callback=refresh_cb_extra_value - ) + props: Dict = { + "request_token_callback": request_cb, + "refresh_token_callback": refresh_cb_extra_value, + } client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() @@ -329,7 +332,7 @@ def test_cache_with_refresh(self): request_cb = self.create_request_cb(expires_in_seconds=60) refresh_cb = self.create_refresh_cb() - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} # Ensure that a ``find`` operation adds credentials to the cache. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -352,7 +355,7 @@ def test_cache_with_no_refresh(self): # Give a callback response with a valid accessToken and an expiresInSeconds that is within one minute. request_cb = self.create_request_cb() - props = dict(request_token_callback=request_cb) + props = {"request_token_callback": request_cb} client = MongoClient(self.uri_single, authMechanismProperties=props) # Ensure that a ``find`` operation adds credentials to the cache. @@ -373,7 +376,7 @@ def test_cache_with_no_refresh(self): def test_cache_key_includes_callback(self): request_cb = self.create_request_cb() - props: Dict = dict(request_token_callback=request_cb) + props: Dict = {"request_token_callback": request_cb} # Ensure that a ``find`` operation adds a new entry to the cache. client = MongoClient(self.uri_single, authMechanismProperties=props) @@ -397,10 +400,10 @@ def test_cache_clears_on_error(self): # Create a new client with a valid request callback that gives credentials that expire within 5 minutes and a refresh callback that gives invalid credentials. def refresh_cb(a, b): - return dict(access_token="bad") + return {"access_token": "bad"} # Add a token to the cache that will expire soon. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient(self.uri_single, authMechanismProperties=props) client.test.test.find_one() client.close() @@ -421,7 +424,7 @@ def refresh_cb(a, b): def test_cache_is_not_used_in_aws_automatic_workflow(self): # Create a new client using the AWS device workflow. # Ensure that a ``find`` operation does not add credentials to the cache. - props = dict(PROVIDER_NAME="aws") + props = {"PROVIDER_NAME": "aws"} client = MongoClient(self.uri_single, authmechanismproperties=props) client.test.test.find_one() client.close() @@ -438,11 +441,11 @@ def test_speculative_auth_success(self): def request_token(a, b): with open(token_file) as fid: token = fid.read() - return dict(access_token=token, expires_in_seconds=1000) + return {"access_token": token, "expires_in_seconds": 1000} # Create a client with a request callback that returns a valid token # that will not expire soon. - props: Dict = dict(request_token_callback=request_token) + props: Dict = {"request_token_callback": request_token} client = MongoClient(self.uri_single, authmechanismproperties=props) # Set a fail point for saslStart commands. @@ -483,7 +486,7 @@ def test_reauthenticate_succeeds(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -536,7 +539,7 @@ def test_reauthenticate_succeeds_bulk_write(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform a find operation. @@ -563,7 +566,7 @@ def test_reauthenticate_succeeds_bulk_read(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform a find operation. @@ -594,7 +597,7 @@ def test_reauthenticate_succeeds_cursor(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform an insert operation. @@ -622,7 +625,7 @@ def test_reauthenticate_succeeds_get_more(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform an insert operation. @@ -647,7 +650,7 @@ def test_reauthenticate_succeeds_get_more(self): def test_reauthenticate_succeeds_get_more_exhaust(self): # Ensure no mongos - props = dict(PROVIDER_NAME="aws") + props = {"PROVIDER_NAME": "aws"} client = MongoClient(self.uri_single, authmechanismproperties=props) hello = client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") != "isdbgrid": @@ -657,7 +660,7 @@ def test_reauthenticate_succeeds_get_more_exhaust(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient(self.uri_single, authmechanismproperties=props) # Perform an insert operation. @@ -685,7 +688,7 @@ def test_reauthenticate_succeeds_command(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} print("start of test") client = MongoClient(self.uri_single, authmechanismproperties=props) @@ -703,7 +706,7 @@ def test_reauthenticate_succeeds_command(self): } ): # Perform a count operation. - cursor = client.test.command(dict(count="test")) + cursor = client.test.command({"count": "test"}) self.assertGreaterEqual(len(list(cursor)), 1) @@ -720,7 +723,7 @@ def test_reauthenticate_retries_and_succeeds_with_cache(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -750,7 +753,7 @@ def test_reauthenticate_fails_with_no_cache(self): refresh_cb = self.create_refresh_cb() # Create a client with the callbacks. - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client = MongoClient( self.uri_single, event_listeners=[listener], authmechanismproperties=props ) @@ -778,7 +781,7 @@ def test_late_reauth_avoids_callback(self): request_cb = self.create_request_cb(expires_in_seconds=1e6) refresh_cb = self.create_refresh_cb(expires_in_seconds=1e6) - props: Dict = dict(request_token_callback=request_cb, refresh_token_callback=refresh_cb) + props: Dict = {"request_token_callback": request_cb, "refresh_token_callback": refresh_cb} client1 = MongoClient(self.uri_single, authMechanismProperties=props) client1.test.test.find_one() client2 = MongoClient(self.uri_single, authMechanismProperties=props) diff --git a/test/crud_v2_format.py b/test/crud_v2_format.py index 4118dfef9f..f711a125c2 100644 --- a/test/crud_v2_format.py +++ b/test/crud_v2_format.py @@ -27,7 +27,7 @@ class TestCrudV2(SpecRunner): def allowable_errors(self, op): """Override expected error classes.""" - errors = super(TestCrudV2, self).allowable_errors(op) + errors = super().allowable_errors(op) errors += (ValueError,) return errors @@ -51,4 +51,4 @@ def setup_scenario(self, scenario_def): """Allow specs to override a test's setup.""" # PYTHON-1935 Only create the collection if there is data to insert. if scenario_def["data"]: - super(TestCrudV2, self).setup_scenario(scenario_def) + super().setup_scenario(scenario_def) diff --git a/test/mockupdb/operations.py b/test/mockupdb/operations.py index 90d7f27c39..692f9aef04 100644 --- a/test/mockupdb/operations.py +++ b/test/mockupdb/operations.py @@ -112,7 +112,7 @@ ] -_ops_by_name = dict([(op.name, op) for op in operations]) +_ops_by_name = {op.name: op for op in operations} Upgrade = namedtuple("Upgrade", ["name", "function", "old", "new", "wire_version"]) diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 39188e8ad0..d3f8922c4c 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -247,6 +247,7 @@ def responder(request): } ) ) + return None else: return request.reply(**primary_response) diff --git a/test/mockupdb/test_mixed_version_sharded.py b/test/mockupdb/test_mixed_version_sharded.py index dc2cd57380..7813069c99 100644 --- a/test/mockupdb/test_mixed_version_sharded.py +++ b/test/mockupdb/test_mixed_version_sharded.py @@ -46,7 +46,7 @@ def setup_server(self, upgrade): "ismaster", ismaster=True, msg="isdbgrid", maxWireVersion=upgrade.wire_version ) - self.mongoses_uri = "mongodb://%s,%s" % ( + self.mongoses_uri = "mongodb://{},{}".format( self.mongos_old.address_string, self.mongos_new.address_string, ) diff --git a/test/mockupdb/test_mongos_command_read_mode.py b/test/mockupdb/test_mongos_command_read_mode.py index 997f5af118..62bd76cf0f 100644 --- a/test/mockupdb/test_mongos_command_read_mode.py +++ b/test/mockupdb/test_mongos_command_read_mode.py @@ -110,7 +110,7 @@ def generate_mongos_read_mode_tests(): # Skip something like command('foo', read_preference=SECONDARY). continue test = create_mongos_read_mode_test(mode, operation) - test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode) + test_name = "test_{}_with_mode_{}".format(operation.name.replace(" ", "_"), mode) test.__name__ = test_name setattr(TestMongosCommandReadMode, test_name, test) diff --git a/test/mockupdb/test_network_disconnect_primary.py b/test/mockupdb/test_network_disconnect_primary.py index ea13a3b042..dd14abf84f 100755 --- a/test/mockupdb/test_network_disconnect_primary.py +++ b/test/mockupdb/test_network_disconnect_primary.py @@ -26,7 +26,7 @@ def test_network_disconnect_primary(self): # Application operation fails against primary. Test that topology # type changes from ReplicaSetWithPrimary to ReplicaSetNoPrimary. # http://bit.ly/1B5ttuL - primary, secondary = servers = [MockupDB() for _ in range(2)] + primary, secondary = servers = (MockupDB() for _ in range(2)) for server in servers: server.run() self.addCleanup(server.stop) diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py index 22fe38fd02..e8542e2fe5 100755 --- a/test/mockupdb/test_op_msg.py +++ b/test/mockupdb/test_op_msg.py @@ -304,7 +304,7 @@ def test(self): def create_tests(ops): for op in ops: - test_name = "test_op_msg_%s" % (op.name,) + test_name = f"test_op_msg_{op.name}" setattr(TestOpMsg, test_name, operation_test(op)) diff --git a/test/mockupdb/test_op_msg_read_preference.py b/test/mockupdb/test_op_msg_read_preference.py index b377f4cf69..a3aef1541e 100644 --- a/test/mockupdb/test_op_msg_read_preference.py +++ b/test/mockupdb/test_op_msg_read_preference.py @@ -35,7 +35,7 @@ class OpMsgReadPrefBase(unittest.TestCase): @classmethod def setUpClass(cls): - super(OpMsgReadPrefBase, cls).setUpClass() + super().setUpClass() @classmethod def add_test(cls, mode, test_name, test): @@ -50,7 +50,7 @@ def setup_client(self, read_preference): class TestOpMsgMongos(OpMsgReadPrefBase): @classmethod def setUpClass(cls): - super(TestOpMsgMongos, cls).setUpClass() + super().setUpClass() auto_ismaster = { "ismaster": True, "msg": "isdbgrid", # Mongos. @@ -64,13 +64,13 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.primary.stop() - super(TestOpMsgMongos, cls).tearDownClass() + super().tearDownClass() class TestOpMsgReplicaSet(OpMsgReadPrefBase): @classmethod def setUpClass(cls): - super(TestOpMsgReplicaSet, cls).setUpClass() + super().setUpClass() cls.primary, cls.secondary = MockupDB(), MockupDB() for server in cls.primary, cls.secondary: server.run() @@ -94,7 +94,7 @@ def setUpClass(cls): def tearDownClass(cls): for server in cls.primary, cls.secondary: server.stop() - super(TestOpMsgReplicaSet, cls).tearDownClass() + super().tearDownClass() @classmethod def add_test(cls, mode, test_name, test): @@ -118,7 +118,7 @@ class TestOpMsgSingle(OpMsgReadPrefBase): @classmethod def setUpClass(cls): - super(TestOpMsgSingle, cls).setUpClass() + super().setUpClass() auto_ismaster = { "ismaster": True, "minWireVersion": 2, @@ -131,7 +131,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.primary.stop() - super(TestOpMsgSingle, cls).tearDownClass() + super().tearDownClass() def create_op_msg_read_mode_test(mode, operation): @@ -181,7 +181,7 @@ def generate_op_msg_read_mode_tests(): for entry in matrix: mode, operation = entry test = create_op_msg_read_mode_test(mode, operation) - test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode) + test_name = "test_{}_with_mode_{}".format(operation.name.replace(" ", "_"), mode) test.__name__ = test_name for cls in TestOpMsgMongos, TestOpMsgReplicaSet, TestOpMsgSingle: cls.add_test(mode, test_name, test) diff --git a/test/mockupdb/test_reset_and_request_check.py b/test/mockupdb/test_reset_and_request_check.py index 841cd41846..c554499379 100755 --- a/test/mockupdb/test_reset_and_request_check.py +++ b/test/mockupdb/test_reset_and_request_check.py @@ -26,7 +26,7 @@ class TestResetAndRequestCheck(unittest.TestCase): def __init__(self, *args, **kwargs): - super(TestResetAndRequestCheck, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.ismaster_time = 0.0 self.client = None self.server = None @@ -143,7 +143,7 @@ def generate_reset_tests(): for entry in matrix: operation, (test_method, name) = entry test = create_reset_test(operation, test_method) - test_name = "%s_%s" % (name, operation.name.replace(" ", "_")) + test_name = "{}_{}".format(name, operation.name.replace(" ", "_")) test.__name__ = test_name setattr(TestResetAndRequestCheck, test_name, test) diff --git a/test/mockupdb/test_slave_okay_sharded.py b/test/mockupdb/test_slave_okay_sharded.py index 18f2016126..5a590bcf15 100644 --- a/test/mockupdb/test_slave_okay_sharded.py +++ b/test/mockupdb/test_slave_okay_sharded.py @@ -43,7 +43,7 @@ def setup_server(self): "ismaster", minWireVersion=2, maxWireVersion=6, ismaster=True, msg="isdbgrid" ) - self.mongoses_uri = "mongodb://%s,%s" % ( + self.mongoses_uri = "mongodb://{},{}".format( self.mongos1.address_string, self.mongos2.address_string, ) @@ -59,7 +59,7 @@ def test(self): elif operation.op_type == "must-use-primary": slave_ok = False else: - assert False, "unrecognized op_type %r" % operation.op_type + raise AssertionError("unrecognized op_type %r" % operation.op_type) pref = make_read_preference(read_pref_mode_from_name(mode), tag_sets=None) @@ -84,7 +84,7 @@ def generate_slave_ok_sharded_tests(): for entry in matrix: mode, operation = entry test = create_slave_ok_sharded_test(mode, operation) - test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode) + test_name = "test_{}_with_mode_{}".format(operation.name.replace(" ", "_"), mode) test.__name__ = test_name setattr(TestSlaveOkaySharded, test_name, test) diff --git a/test/mockupdb/test_slave_okay_single.py b/test/mockupdb/test_slave_okay_single.py index 4b2846490f..90b99df496 100644 --- a/test/mockupdb/test_slave_okay_single.py +++ b/test/mockupdb/test_slave_okay_single.py @@ -78,7 +78,7 @@ def generate_slave_ok_single_tests(): mode, (server_type, ismaster), operation = entry test = create_slave_ok_single_test(mode, server_type, ismaster, operation) - test_name = "test_%s_%s_with_mode_%s" % ( + test_name = "test_{}_{}_with_mode_{}".format( operation.name.replace(" ", "_"), server_type, mode, diff --git a/test/mod_wsgi_test/test_client.py b/test/mod_wsgi_test/test_client.py index bfdae9e824..6d3b299700 100644 --- a/test/mod_wsgi_test/test_client.py +++ b/test/mod_wsgi_test/test_client.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test client for mod_wsgi application, see bug PYTHON-353. -""" +"""Test client for mod_wsgi application, see bug PYTHON-353.""" import _thread as thread import sys @@ -91,14 +90,14 @@ class URLGetterThread(threading.Thread): counter = 0 def __init__(self, options, url, nrequests_per_thread): - super(URLGetterThread, self).__init__() + super().__init__() self.options = options self.url = url self.nrequests_per_thread = nrequests_per_thread self.errors = 0 def run(self): - for i in range(self.nrequests_per_thread): + for _i in range(self.nrequests_per_thread): try: get(url) except Exception as e: @@ -128,9 +127,8 @@ def main(options, mode, url): if options.verbose: print( - "Getting %s %s times total in %s threads, " - "%s times per thread" - % ( + "Getting {} {} times total in {} threads, " + "{} times per thread".format( url, nrequests_per_thread * options.nthreads, options.nthreads, @@ -154,7 +152,7 @@ def main(options, mode, url): else: assert mode == "serial" if options.verbose: - print("Getting %s %s times in one thread" % (url, options.nrequests)) + print(f"Getting {url} {options.nrequests} times in one thread") for i in range(1, options.nrequests + 1): try: diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index a0770afefa..dc2650499f 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -40,7 +40,7 @@ def _connect(options): - uri = ("mongodb://localhost:27017/?serverSelectionTimeoutMS=%s&tlsCAFile=%s&%s") % ( + uri = ("mongodb://localhost:27017/?serverSelectionTimeoutMS={}&tlsCAFile={}&{}").format( TIMEOUT_MS, CA_FILE, options, diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index 3cb4b5d5d1..062058e09d 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -58,7 +58,7 @@ def tearDownModule(): print(output) -class Timer(object): +class Timer: def __enter__(self): self.start = time.monotonic() return self @@ -68,7 +68,7 @@ def __exit__(self, *args): self.interval = self.end - self.start -class PerformanceTest(object): +class PerformanceTest: dataset: Any data_size: Any do_task: Any @@ -85,7 +85,7 @@ def tearDown(self): name = self.__class__.__name__ median = self.percentile(50) bytes_per_sec = self.data_size / median - print("Running %s. MEDIAN=%s" % (self.__class__.__name__, self.percentile(50))) + print(f"Running {self.__class__.__name__}. MEDIAN={self.percentile(50)}") result_data.append( { "info": { @@ -113,6 +113,7 @@ def percentile(self, percentile): return sorted_results[percentile_index] else: self.fail("Test execution failed") + return None def runTest(self): results = [] @@ -202,7 +203,7 @@ class TestDocument(PerformanceTest): def setUp(self): # Location of test data. with open( - os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset)), "r" + os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset)) ) as data: self.document = json.loads(data.read()) @@ -210,7 +211,7 @@ def setUp(self): self.client.drop_database("perftest") def tearDown(self): - super(TestDocument, self).tearDown() + super().tearDown() self.client.drop_database("perftest") def before(self): @@ -225,7 +226,7 @@ class TestFindOneByID(TestDocument, unittest.TestCase): def setUp(self): self.dataset = "tweet.json" - super(TestFindOneByID, self).setUp() + super().setUp() documents = [self.document.copy() for _ in range(NUM_DOCS)] self.corpus = self.client.perftest.corpus @@ -249,7 +250,7 @@ class TestSmallDocInsertOne(TestDocument, unittest.TestCase): def setUp(self): self.dataset = "small_doc.json" - super(TestSmallDocInsertOne, self).setUp() + super().setUp() self.documents = [self.document.copy() for _ in range(NUM_DOCS)] @@ -264,7 +265,7 @@ class TestLargeDocInsertOne(TestDocument, unittest.TestCase): def setUp(self): self.dataset = "large_doc.json" - super(TestLargeDocInsertOne, self).setUp() + super().setUp() self.documents = [self.document.copy() for _ in range(10)] @@ -280,7 +281,7 @@ class TestFindManyAndEmptyCursor(TestDocument, unittest.TestCase): def setUp(self): self.dataset = "tweet.json" - super(TestFindManyAndEmptyCursor, self).setUp() + super().setUp() for _ in range(10): self.client.perftest.command("insert", "corpus", documents=[self.document] * 1000) @@ -301,7 +302,7 @@ class TestSmallDocBulkInsert(TestDocument, unittest.TestCase): def setUp(self): self.dataset = "small_doc.json" - super(TestSmallDocBulkInsert, self).setUp() + super().setUp() self.documents = [self.document.copy() for _ in range(NUM_DOCS)] def before(self): @@ -316,7 +317,7 @@ class TestLargeDocBulkInsert(TestDocument, unittest.TestCase): def setUp(self): self.dataset = "large_doc.json" - super(TestLargeDocBulkInsert, self).setUp() + super().setUp() self.documents = [self.document.copy() for _ in range(10)] def before(self): @@ -342,7 +343,7 @@ def setUp(self): self.bucket = GridFSBucket(self.client.perftest) def tearDown(self): - super(TestGridFsUpload, self).tearDown() + super().tearDown() self.client.drop_database("perftest") def before(self): @@ -368,7 +369,7 @@ def setUp(self): self.uploaded_id = self.bucket.upload_from_stream("gridfstest", gfile) def tearDown(self): - super(TestGridFsDownload, self).tearDown() + super().tearDown() self.client.drop_database("perftest") def do_task(self): @@ -392,14 +393,14 @@ def mp_map(map_func, files): def insert_json_file(filename): assert proc_client is not None - with open(filename, "r") as data: + with open(filename) as data: coll = proc_client.perftest.corpus coll.insert_many([json.loads(line) for line in data]) def insert_json_file_with_file_id(filename): documents = [] - with open(filename, "r") as data: + with open(filename) as data: for line in data: doc = json.loads(line) doc["file"] = filename @@ -461,7 +462,7 @@ def after(self): self.client.perftest.drop_collection("corpus") def tearDown(self): - super(TestJsonMultiImport, self).tearDown() + super().tearDown() self.client.drop_database("perftest") @@ -482,7 +483,7 @@ def do_task(self): mp_map(read_json_file, self.files) def tearDown(self): - super(TestJsonMultiExport, self).tearDown() + super().tearDown() self.client.drop_database("perftest") @@ -505,7 +506,7 @@ def do_task(self): mp_map(insert_gridfs_file, self.files) def tearDown(self): - super(TestGridFsMultiFileUpload, self).tearDown() + super().tearDown() self.client.drop_database("perftest") @@ -529,7 +530,7 @@ def do_task(self): mp_map(read_gridfs_file, self.files) def tearDown(self): - super(TestGridFsMultiFileDownload, self).tearDown() + super().tearDown() self.client.drop_database("perftest") diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 580c5da993..2e7fda21e0 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -40,7 +40,7 @@ def __init__(self, client, pair, *args, **kwargs): @contextlib.contextmanager def get_socket(self, handler=None): client = self.client - host_and_port = "%s:%s" % (self.mock_host, self.mock_port) + host_and_port = f"{self.mock_host}:{self.mock_port}" if host_and_port in client.mock_down_hosts: raise AutoReconnect("mock error") @@ -54,7 +54,7 @@ def get_socket(self, handler=None): yield sock_info -class DummyMonitor(object): +class DummyMonitor: def __init__(self, server_description, topology, pool, topology_settings): self._server_description = server_description self.opened = False @@ -99,7 +99,7 @@ def __init__( arbiters=None, down_hosts=None, *args, - **kwargs + **kwargs, ): """A MongoClient connected to the default server, with a mock topology. @@ -144,7 +144,7 @@ def __init__( client_options = client_context.default_client_options.copy() client_options.update(kwargs) - super(MockClient, self).__init__(*args, **client_options) + super().__init__(*args, **client_options) def kill_host(self, host): """Host is like 'a:1'.""" diff --git a/test/qcheck.py b/test/qcheck.py index 4cce7b5bc8..52e4c46b8b 100644 --- a/test/qcheck.py +++ b/test/qcheck.py @@ -116,7 +116,8 @@ def gen_regexp(gen_length): # TODO our patterns only consist of one letter. # this is because of a bug in CPython's regex equality testing, # which I haven't quite tracked down, so I'm just ignoring it... - pattern = lambda: "".join(gen_list(choose_lifted("a"), gen_length)()) + def pattern(): + return "".join(gen_list(choose_lifted("a"), gen_length)()) def gen_flags(): flags = 0 @@ -230,9 +231,9 @@ def check(predicate, generator): try: if not predicate(case): reduction = reduce(case, predicate) - counter_examples.append("after %s reductions: %r" % reduction) + counter_examples.append("after {} reductions: {!r}".format(*reduction)) except: - counter_examples.append("%r : %s" % (case, traceback.format_exc())) + counter_examples.append(f"{case!r} : {traceback.format_exc()}") return counter_examples diff --git a/test/sigstop_sigcont.py b/test/sigstop_sigcont.py index 87b4f62038..6f84b6a6a2 100644 --- a/test/sigstop_sigcont.py +++ b/test/sigstop_sigcont.py @@ -84,7 +84,7 @@ def main(uri: str) -> None: if len(sys.argv) != 2: print("unknown or missing options") print(f"usage: python3 {sys.argv[0]} 'mongodb://localhost'") - exit(1) + sys.exit(1) # Enable logs in this format: # 2022-03-30 12:40:55,582 INFO diff --git a/test/test_auth.py b/test/test_auth.py index 7db2247746..f9a9af4d5a 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -67,7 +67,7 @@ class AutoAuthenticateThread(threading.Thread): """ def __init__(self, collection): - super(AutoAuthenticateThread, self).__init__() + super().__init__() self.collection = collection self.success = False @@ -89,10 +89,10 @@ def setUpClass(cls): cls.service_realm_required = ( GSSAPI_SERVICE_REALM is not None and GSSAPI_SERVICE_REALM not in GSSAPI_PRINCIPAL ) - mech_properties = "SERVICE_NAME:%s" % (GSSAPI_SERVICE_NAME,) - mech_properties += ",CANONICALIZE_HOST_NAME:%s" % (GSSAPI_CANONICALIZE,) + mech_properties = f"SERVICE_NAME:{GSSAPI_SERVICE_NAME}" + mech_properties += f",CANONICALIZE_HOST_NAME:{GSSAPI_CANONICALIZE}" if GSSAPI_SERVICE_REALM is not None: - mech_properties += ",SERVICE_REALM:%s" % (GSSAPI_SERVICE_REALM,) + mech_properties += f",SERVICE_REALM:{GSSAPI_SERVICE_REALM}" cls.mech_properties = mech_properties def test_credentials_hashing(self): @@ -111,8 +111,8 @@ def test_credentials_hashing(self): "GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "B"}}, None ) - self.assertEqual(1, len(set([creds1, creds2]))) - self.assertEqual(3, len(set([creds0, creds1, creds2, creds3]))) + self.assertEqual(1, len({creds1, creds2})) + self.assertEqual(3, len({creds0, creds1, creds2, creds3})) @ignore_deprecations def test_gssapi_simple(self): @@ -160,7 +160,7 @@ def test_gssapi_simple(self): client[GSSAPI_DB].collection.find_one() # Log in using URI, with authMechanismProperties. - mech_uri = uri + "&authMechanismProperties=%s" % (self.mech_properties,) + mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" client = MongoClient(mech_uri) client[GSSAPI_DB].collection.find_one() @@ -179,7 +179,7 @@ def test_gssapi_simple(self): client[GSSAPI_DB].list_collection_names() - uri = uri + "&replicaSet=%s" % (str(set_name),) + uri = uri + f"&replicaSet={str(set_name)}" client = MongoClient(uri) client[GSSAPI_DB].list_collection_names() @@ -196,7 +196,7 @@ def test_gssapi_simple(self): client[GSSAPI_DB].list_collection_names() - mech_uri = mech_uri + "&replicaSet=%s" % (str(set_name),) + mech_uri = mech_uri + f"&replicaSet={str(set_name)}" client = MongoClient(mech_uri) client[GSSAPI_DB].list_collection_names() @@ -336,12 +336,12 @@ def auth_string(user, password): class TestSCRAMSHA1(IntegrationTest): @client_context.require_auth def setUp(self): - super(TestSCRAMSHA1, self).setUp() + super().setUp() client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"]) def tearDown(self): client_context.drop_user("pymongo_test", "user") - super(TestSCRAMSHA1, self).tearDown() + super().tearDown() def test_scram_sha1(self): host, port = client_context.host, client_context.port @@ -368,16 +368,16 @@ class TestSCRAM(IntegrationTest): @client_context.require_auth @client_context.require_version_min(3, 7, 2) def setUp(self): - super(TestSCRAM, self).setUp() + super().setUp() self._SENSITIVE_COMMANDS = monitoring._SENSITIVE_COMMANDS - monitoring._SENSITIVE_COMMANDS = set([]) + monitoring._SENSITIVE_COMMANDS = set() self.listener = AllowListEventListener("saslStart") def tearDown(self): monitoring._SENSITIVE_COMMANDS = self._SENSITIVE_COMMANDS client_context.client.testscram.command("dropAllUsersFromDatabase") client_context.client.drop_database("testscram") - super(TestSCRAM, self).tearDown() + super().tearDown() def test_scram_skip_empty_exchange(self): listener = AllowListEventListener("saslStart", "saslContinue") @@ -597,14 +597,14 @@ def test_scram_threaded(self): class TestAuthURIOptions(IntegrationTest): @client_context.require_auth def setUp(self): - super(TestAuthURIOptions, self).setUp() + super().setUp() client_context.create_user("admin", "admin", "pass") client_context.create_user("pymongo_test", "user", "pass", ["userAdmin", "readWrite"]) def tearDown(self): client_context.drop_user("pymongo_test", "user") client_context.drop_user("admin", "admin") - super(TestAuthURIOptions, self).tearDown() + super().tearDown() def test_uri_options(self): # Test default to admin diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 78f4d21929..ebcc4eeb7d 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -67,7 +67,7 @@ def run_test(self): expected = credential["mechanism_properties"] if expected is not None: actual = credentials.mechanism_properties - for key, val in expected.items(): + for key, _val in expected.items(): if "SERVICE_NAME" in expected: self.assertEqual(actual.service_name, expected["SERVICE_NAME"]) elif "CANONICALIZE_HOST_NAME" in expected: @@ -91,7 +91,7 @@ def run_test(self): actual.refresh_token_callback, expected["refresh_token_callback"] ) else: - self.fail("Unhandled property: %s" % (key,)) + self.fail(f"Unhandled property: {key}") else: if credential["mechanism"] == "MONGODB-AWS": self.assertIsNone(credentials.mechanism_properties.aws_session_token) @@ -111,7 +111,7 @@ def create_tests(): continue test_method = create_test(test_case) name = str(test_case["description"].lower().replace(" ", "_")) - setattr(TestAuthSpec, "test_%s_%s" % (test_suffix, name), test_method) + setattr(TestAuthSpec, f"test_{test_suffix}_{name}", test_method) create_tests() diff --git a/test/test_binary.py b/test/test_binary.py index 65abdca796..158a990290 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -122,15 +122,15 @@ def test_equality(self): def test_repr(self): one = Binary(b"hello world") - self.assertEqual(repr(one), "Binary(%s, 0)" % (repr(b"hello world"),)) + self.assertEqual(repr(one), "Binary({}, 0)".format(repr(b"hello world"))) two = Binary(b"hello world", 2) - self.assertEqual(repr(two), "Binary(%s, 2)" % (repr(b"hello world"),)) + self.assertEqual(repr(two), "Binary({}, 2)".format(repr(b"hello world"))) three = Binary(b"\x08\xFF") - self.assertEqual(repr(three), "Binary(%s, 0)" % (repr(b"\x08\xFF"),)) + self.assertEqual(repr(three), "Binary({}, 0)".format(repr(b"\x08\xFF"))) four = Binary(b"\x08\xFF", 2) - self.assertEqual(repr(four), "Binary(%s, 2)" % (repr(b"\x08\xFF"),)) + self.assertEqual(repr(four), "Binary({}, 2)".format(repr(b"\x08\xFF"))) five = Binary(b"test", 100) - self.assertEqual(repr(five), "Binary(%s, 100)" % (repr(b"test"),)) + self.assertEqual(repr(five), "Binary({}, 100)".format(repr(b"test"))) def test_hash(self): one = Binary(b"hello world") @@ -351,7 +351,7 @@ class TestUuidSpecExplicitCoding(unittest.TestCase): @classmethod def setUpClass(cls): - super(TestUuidSpecExplicitCoding, cls).setUpClass() + super().setUpClass() cls.uuid = uuid.UUID("00112233445566778899AABBCCDDEEFF") @staticmethod @@ -452,7 +452,7 @@ class TestUuidSpecImplicitCoding(IntegrationTest): @classmethod def setUpClass(cls): - super(TestUuidSpecImplicitCoding, cls).setUpClass() + super().setUpClass() cls.uuid = uuid.UUID("00112233445566778899AABBCCDDEEFF") @staticmethod diff --git a/test/test_bson.py b/test/test_bson.py index a8fd1fef45..a6e6352333 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright 2009-present MongoDB, Inc. # @@ -370,7 +369,7 @@ def test_invalid_decodes(self): ), ] for i, data in enumerate(bad_bsons): - msg = "bad_bson[{}]".format(i) + msg = f"bad_bson[{i}]" with self.assertRaises(InvalidBSON, msg=msg): decode_all(data) with self.assertRaises(InvalidBSON, msg=msg): @@ -491,7 +490,7 @@ def test_basic_encode(self): def test_unknown_type(self): # Repr value differs with major python version - part = "type %r for fieldname 'foo'" % (b"\x14",) + part = "type {!r} for fieldname 'foo'".format(b"\x14") docs = [ b"\x0e\x00\x00\x00\x14foo\x00\x01\x00\x00\x00\x00", (b"\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x140\x00\x01\x00\x00\x00\x00\x00"), @@ -648,7 +647,7 @@ def test_small_long_encode_decode(self): encoded1 = encode({"x": 256}) decoded1 = decode(encoded1)["x"] self.assertEqual(256, decoded1) - self.assertEqual(type(256), type(decoded1)) + self.assertEqual(int, type(decoded1)) encoded2 = encode({"x": Int64(256)}) decoded2 = decode(encoded2)["x"] @@ -925,7 +924,7 @@ def test_bad_id_keys(self): def test_bson_encode_thread_safe(self): def target(i): for j in range(1000): - my_int = type("MyInt_%s_%s" % (i, j), (int,), {}) + my_int = type(f"MyInt_{i}_{j}", (int,), {}) bson.encode({"my_int": my_int()}) threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)] @@ -939,7 +938,7 @@ def target(i): self.assertIsNone(t.exc) def test_raise_invalid_document(self): - class Wrapper(object): + class Wrapper: def __init__(self, val): self.val = val diff --git a/test/test_bulk.py b/test/test_bulk.py index ac7073c0ef..6a2af3143c 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -50,12 +50,12 @@ class BulkTestBase(IntegrationTest): @classmethod def setUpClass(cls): - super(BulkTestBase, cls).setUpClass() + super().setUpClass() cls.coll = cls.db.test cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) def setUp(self): - super(BulkTestBase, self).setUp() + super().setUp() self.coll.drop() def assertEqualResponse(self, expected, actual): @@ -93,7 +93,7 @@ def assertEqualResponse(self, expected, actual): self.assertEqual( actual.get(key), value, - "%r value of %r does not match expected %r" % (key, actual.get(key), value), + f"{key!r} value of {actual.get(key)!r} does not match expected {value!r}", ) def assertEqualUpsert(self, expected, actual): @@ -793,10 +793,10 @@ class BulkAuthorizationTestBase(BulkTestBase): @client_context.require_auth @client_context.require_no_api_version def setUpClass(cls): - super(BulkAuthorizationTestBase, cls).setUpClass() + super().setUpClass() def setUp(self): - super(BulkAuthorizationTestBase, self).setUp() + super().setUp() client_context.create_user(self.db.name, "readonly", "pw", ["read"]) self.db.command( "createRole", @@ -902,7 +902,7 @@ def test_no_remove(self): InsertOne({"x": 3}), # Never attempted. ] self.assertRaises(OperationFailure, coll.bulk_write, requests) - self.assertEqual(set([1, 2]), set(self.coll.distinct("x"))) + self.assertEqual({1, 2}, set(self.coll.distinct("x"))) class TestBulkWriteConcern(BulkTestBase): @@ -911,7 +911,7 @@ class TestBulkWriteConcern(BulkTestBase): @classmethod def setUpClass(cls): - super(TestBulkWriteConcern, cls).setUpClass() + super().setUpClass() cls.w = client_context.w cls.secondary = None if cls.w is not None and cls.w > 1: diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 2388a6e1f4..c9ddfcd137 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -104,7 +104,8 @@ def get_resume_token(self, invalidate=False): def get_start_at_operation_time(self): """Get an operationTime. Advances the operation clock beyond the most - recently returned timestamp.""" + recently returned timestamp. + """ optime = self.client.admin.command("ping")["operationTime"] return Timestamp(optime.time, optime.inc + 1) @@ -120,7 +121,7 @@ def kill_change_stream_cursor(self, change_stream): client._close_cursor_now(cursor.cursor_id, address) -class APITestsMixin(object): +class APITestsMixin: @no_type_check def test_watch(self): with self.change_stream( @@ -208,7 +209,7 @@ def test_try_next_runs_one_getmore(self): # Stream still works after a resume. coll.insert_one({"_id": 3}) wait_until(lambda: stream.try_next() is not None, "get change from try_next") - self.assertEqual(set(listener.started_command_names()), set(["getMore"])) + self.assertEqual(set(listener.started_command_names()), {"getMore"}) self.assertIsNone(stream.try_next()) @no_type_check @@ -249,7 +250,7 @@ def test_start_at_operation_time(self): coll.insert_many([{"data": i} for i in range(ndocs)]) with self.change_stream(start_at_operation_time=optime) as cs: - for i in range(ndocs): + for _i in range(ndocs): cs.next() @no_type_check @@ -443,7 +444,7 @@ def test_start_after_resume_process_without_changes(self): self.assertEqual(change["fullDocument"], {"_id": 2}) -class ProseSpecTestsMixin(object): +class ProseSpecTestsMixin: @no_type_check def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) @@ -461,7 +462,8 @@ def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3): def _get_expected_resume_token_legacy(self, stream, listener, previous_change=None): """Predicts what the resume token should currently be for server versions that don't support postBatchResumeToken. Assumes the stream - has never returned any changes if previous_change is None.""" + has never returned any changes if previous_change is None. + """ if previous_change is None: agg_cmd = listener.started_events[0] stage = agg_cmd.command["pipeline"][0]["$changeStream"] @@ -474,7 +476,8 @@ def _get_expected_resume_token(self, stream, listener, previous_change=None): versions that support postBatchResumeToken. Assumes the stream has never returned any changes if previous_change is None. Assumes listener is a AllowListEventListener that listens for aggregate and - getMore commands.""" + getMore commands. + """ if previous_change is None or stream._cursor._has_next(): token = self._get_expected_resume_token_legacy(stream, listener, previous_change) if token is not None: @@ -767,14 +770,14 @@ class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams def setUpClass(cls): - super(TestClusterChangeStream, cls).setUpClass() + super().setUpClass() cls.dbs = [cls.db, cls.client.pymongo_test_2] @classmethod def tearDownClass(cls): for db in cls.dbs: cls.client.drop_database(db) - super(TestClusterChangeStream, cls).tearDownClass() + super().tearDownClass() def change_stream_with_client(self, client, *args, **kwargs): return client.watch(*args, **kwargs) @@ -828,7 +831,7 @@ class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams def setUpClass(cls): - super(TestDatabaseChangeStream, cls).setUpClass() + super().setUpClass() def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].watch(*args, **kwargs) @@ -913,7 +916,7 @@ class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecT @classmethod @client_context.require_change_streams def setUpClass(cls): - super(TestCollectionChangeStream, cls).setUpClass() + super().setUpClass() def setUp(self): # Use a new collection for each test. @@ -1044,17 +1047,17 @@ class TestAllLegacyScenarios(IntegrationTest): @classmethod @client_context.require_connection def setUpClass(cls): - super(TestAllLegacyScenarios, cls).setUpClass() + super().setUpClass() cls.listener = AllowListEventListener("aggregate", "getMore") cls.client = rs_or_single_client(event_listeners=[cls.listener]) @classmethod def tearDownClass(cls): cls.client.close() - super(TestAllLegacyScenarios, cls).tearDownClass() + super().tearDownClass() def setUp(self): - super(TestAllLegacyScenarios, self).setUp() + super().setUp() self.listener.reset() def setUpCluster(self, scenario_dict): @@ -1088,7 +1091,8 @@ def setFailPoint(self, scenario_dict): def assert_list_contents_are_subset(self, superlist, sublist): """Check that each element in sublist is a subset of the corresponding - element in superlist.""" + element in superlist. + """ self.assertEqual(len(superlist), len(sublist)) for sup, sub in zip(superlist, sublist): if isinstance(sub, dict): @@ -1104,7 +1108,7 @@ def assert_dict_is_subset(self, superdict, subdict): exempt_fields = ["documentKey", "_id", "getMore"] for key, value in subdict.items(): if key not in superdict: - self.fail("Key %s not found in %s" % (key, superdict)) + self.fail(f"Key {key} not found in {superdict}") if isinstance(value, dict): self.assert_dict_is_subset(superdict[key], value) continue diff --git a/test/test_client.py b/test/test_client.py index 624c460c08..ec2b4bac97 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -325,7 +325,7 @@ def test_metadata(self): self.assertRaises(TypeError, MongoClient, driver=("Foo", "1", "a")) # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|FooDriver" - metadata["driver"]["version"] = "%s|1.2.3" % (_METADATA["driver"]["version"],) + metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) client = MongoClient( "foo", 27017, @@ -335,7 +335,7 @@ def test_metadata(self): ) options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) - metadata["platform"] = "%s|FooPlatform" % (_METADATA["platform"],) + metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) client = MongoClient( "foo", 27017, @@ -347,7 +347,7 @@ def test_metadata(self): self.assertEqual(options.pool_options.metadata, metadata) def test_kwargs_codec_options(self): - class MyFloatType(object): + class MyFloatType: def __init__(self, x): self.__x = x @@ -704,7 +704,7 @@ def test_init_disconnected_with_auth(self): self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) def test_equality(self): - seed = "%s:%s" % list(self.client._topology_settings.seeds)[0] + seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = rs_or_single_client(seed, connect=False) self.addCleanup(c.close) self.assertEqual(client_context.client, c) @@ -723,7 +723,7 @@ def test_equality(self): ) def test_hashable(self): - seed = "%s:%s" % list(self.client._topology_settings.seeds)[0] + seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = rs_or_single_client(seed, connect=False) self.addCleanup(c.close) self.assertIn(c, {client_context.client}) @@ -735,7 +735,7 @@ def test_host_w_port(self): with self.assertRaises(ValueError): connected( MongoClient( - "%s:1234567" % (client_context.host,), + f"{client_context.host}:1234567", connectTimeoutMS=1, serverSelectionTimeoutMS=10, ) @@ -1002,7 +1002,7 @@ def test_username_and_password(self): @client_context.require_auth def test_lazy_auth_raises_operation_failure(self): lazy_client = rs_or_single_client_noauth( - "mongodb://user:wrong@%s/pymongo_test" % (client_context.host,), connect=False + f"mongodb://user:wrong@{client_context.host}/pymongo_test", connect=False ) assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one) @@ -1160,7 +1160,7 @@ def test_ipv6(self): raise SkipTest("Need the ipaddress module to test with SSL") if client_context.auth_enabled: - auth_str = "%s:%s@" % (db_user, db_pwd) + auth_str = f"{db_user}:{db_pwd}@" else: auth_str = "" @@ -1533,7 +1533,7 @@ def test_reset_during_update_pool(self): # Continuously reset the pool. class ResetPoolThread(threading.Thread): def __init__(self, pool): - super(ResetPoolThread, self).__init__() + super().__init__() self.running = True self.pool = pool @@ -1657,7 +1657,7 @@ def test_network_error_message(self): {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} ): assert client.address is not None - expected = "%s:%s: " % client.address + expected = "{}:{}: ".format(*client.address) with self.assertRaisesRegex(AutoReconnect, expected): client.pymongo_test.test.find_one({}) @@ -1836,7 +1836,7 @@ class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" def setUp(self): - super(TestExhaustCursor, self).setUp() + super().setUp() if client_context.is_mongos: raise SkipTest("mongos doesn't support exhaust, SERVER-2627") @@ -2188,23 +2188,33 @@ def _test_network_error(self, operation_callback): self.assertEqual(7, sd_b.max_wire_version) def test_network_error_on_query(self): - callback = lambda client: client.db.collection.find_one() + def callback(client): + return client.db.collection.find_one() + self._test_network_error(callback) def test_network_error_on_insert(self): - callback = lambda client: client.db.collection.insert_one({}) + def callback(client): + return client.db.collection.insert_one({}) + self._test_network_error(callback) def test_network_error_on_update(self): - callback = lambda client: client.db.collection.update_one({}, {"$unset": "x"}) + def callback(client): + return client.db.collection.update_one({}, {"$unset": "x"}) + self._test_network_error(callback) def test_network_error_on_replace(self): - callback = lambda client: client.db.collection.replace_one({}, {}) + def callback(client): + return client.db.collection.replace_one({}, {}) + self._test_network_error(callback) def test_network_error_on_delete(self): - callback = lambda client: client.db.collection.delete_many({}) + def callback(client): + return client.db.collection.delete_many({}) + self._test_network_error(callback) @@ -2227,7 +2237,7 @@ def test_rs_client_does_not_maintain_pool_to_arbiters(self): wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(c.address, ("a", 1)) - self.assertEqual(c.arbiters, set([("c", 3)])) + self.assertEqual(c.arbiters, {("c", 3)}) # Assert that we create 2 and only 2 pooled connections. listener.wait_for_event(monitoring.ConnectionReadyEvent, 2) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2) diff --git a/test/test_client_context.py b/test/test_client_context.py index 9ee5b96d61..72da8dbc34 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -28,8 +28,9 @@ def test_must_connect(self): self.assertTrue( client_context.connected, "client context must be connected when " - "PYMONGO_MUST_CONNECT is set. Failed attempts:\n%s" - % (client_context.connection_attempt_info(),), + "PYMONGO_MUST_CONNECT is set. Failed attempts:\n{}".format( + client_context.connection_attempt_info() + ), ) def test_serverless(self): @@ -39,8 +40,9 @@ def test_serverless(self): self.assertTrue( client_context.connected and client_context.serverless, "client context must be connected to serverless when " - "TEST_SERVERLESS is set. Failed attempts:\n%s" - % (client_context.connection_attempt_info(),), + "TEST_SERVERLESS is set. Failed attempts:\n{}".format( + client_context.connection_attempt_info() + ), ) def test_enableTestCommands_is_disabled(self): diff --git a/test/test_cmap.py b/test/test_cmap.py index 360edef0e8..3b84524f44 100644 --- a/test/test_cmap.py +++ b/test/test_cmap.py @@ -116,7 +116,7 @@ def wait_for_event(self, op): timeout = op.get("timeout", 10000) / 1000.0 wait_until( lambda: self.listener.event_count(event) >= count, - "find %s %s event(s)" % (count, event), + f"find {count} {event} event(s)", timeout=timeout, ) @@ -191,11 +191,11 @@ def check_events(self, events, ignore): """Check the events of a test.""" actual_events = self.actual_events(ignore) for actual, expected in zip(actual_events, events): - self.logs.append("Checking event actual: %r vs expected: %r" % (actual, expected)) + self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}") self.check_event(actual, expected) if len(events) > len(actual_events): - self.fail("missing events: %r" % (events[len(actual_events) :],)) + self.fail(f"missing events: {events[len(actual_events) :]!r}") def check_error(self, actual, expected): message = expected.pop("message") @@ -260,9 +260,9 @@ def run_scenario(self, scenario_def, test): self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. - self.targets: dict = dict() + self.targets: dict = {} # Map of label names to Connection objects - self.labels: dict = dict() + self.labels: dict = {} def cleanup(): for t in self.targets.values(): @@ -285,7 +285,7 @@ def cleanup(): self.check_events(test["events"], test["ignore"]) except Exception: # Print the events after a test failure. - print("\nFailed test: %r" % (test["description"],)) + print("\nFailed test: {!r}".format(test["description"])) print("Operations:") for op in self._ops: print(op) @@ -332,8 +332,8 @@ def test_2_all_client_pools_have_same_options(self): self.assertEqual(pool.opts, pool_opts) def test_3_uri_connection_pool_options(self): - opts = "&".join(["%s=%s" % (k, v) for k, v in self.POOL_OPTIONS.items()]) - uri = "mongodb://%s/?%s" % (client_context.pair, opts) + opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) + uri = f"mongodb://{client_context.pair}/?{opts}" client = rs_or_single_client(uri) self.addCleanup(client.close) pool_opts = get_pool(client).opts diff --git a/test/test_code.py b/test/test_code.py index 9ff305e39a..9e44ca4962 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright 2009-present MongoDB, Inc. # @@ -67,7 +66,7 @@ def test_repr(self): c = Code("hello world", {"blah": 3}) self.assertEqual(repr(c), "Code('hello world', {'blah': 3})") c = Code("\x08\xFF") - self.assertEqual(repr(c), "Code(%s, None)" % (repr("\x08\xFF"),)) + self.assertEqual(repr(c), "Code({}, None)".format(repr("\x08\xFF"))) def test_equality(self): b = Code("hello") diff --git a/test/test_collation.py b/test/test_collation.py index 18f8bc78ac..7f4bbf4750 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -96,7 +96,7 @@ class TestCollation(IntegrationTest): @classmethod @client_context.require_connection def setUpClass(cls): - super(TestCollation, cls).setUpClass() + super().setUpClass() cls.listener = EventListener() cls.client = rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test @@ -110,11 +110,11 @@ def tearDownClass(cls): cls.warn_context.__exit__() cls.warn_context = None cls.client.close() - super(TestCollation, cls).tearDownClass() + super().tearDownClass() def tearDown(self): self.listener.reset() - super(TestCollation, self).tearDown() + super().tearDown() def last_command_started(self): return self.listener.started_events[-1].command diff --git a/test/test_collection.py b/test/test_collection.py index e36d6663f0..ca657f0099 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -151,7 +149,7 @@ class TestCollection(IntegrationTest): @classmethod def setUpClass(cls): - super(TestCollection, cls).setUpClass() + super().setUpClass() cls.w = client_context.w # type: ignore @classmethod @@ -373,7 +371,7 @@ def test_list_indexes(self): db.test.insert_one({}) # create collection def map_indexes(indexes): - return dict([(index["name"], index) for index in indexes]) + return {index["name"]: index for index in indexes} indexes = list(db.test.list_indexes()) self.assertEqual(len(indexes), 1) @@ -485,7 +483,7 @@ def test_index_2dsphere(self): db.test.drop_indexes() self.assertEqual("geo_2dsphere", db.test.create_index([("geo", GEOSPHERE)])) - for dummy, info in db.test.index_information().items(): + for _dummy, info in db.test.index_information().items(): field, idx_type = info["key"][0] if field == "geo" and idx_type == "2dsphere": break @@ -504,7 +502,7 @@ def test_index_hashed(self): db.test.drop_indexes() self.assertEqual("a_hashed", db.test.create_index([("a", HASHED)])) - for dummy, info in db.test.index_information().items(): + for _dummy, info in db.test.index_information().items(): field, idx_type = info["key"][0] if field == "a" and idx_type == "hashed": break @@ -1638,8 +1636,8 @@ def test_find_one(self): self.assertTrue("hello" in db.test.find_one(projection=("hello",))) self.assertTrue("hello" not in db.test.find_one(projection=("foo",))) - self.assertTrue("hello" in db.test.find_one(projection=set(["hello"]))) - self.assertTrue("hello" not in db.test.find_one(projection=set(["foo"]))) + self.assertTrue("hello" in db.test.find_one(projection={"hello"})) + self.assertTrue("hello" not in db.test.find_one(projection={"foo"})) self.assertTrue("hello" in db.test.find_one(projection=frozenset(["hello"]))) self.assertTrue("hello" not in db.test.find_one(projection=frozenset(["foo"]))) diff --git a/test/test_comment.py b/test/test_comment.py index 85e5470d74..ea44c74257 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -28,7 +28,7 @@ from pymongo.operations import IndexModel -class Empty(object): +class Empty: def __getattr__(self, item): try: self.__dict__[item] diff --git a/test/test_common.py b/test/test_common.py index ff50878ea1..76367ffa0c 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -148,14 +148,12 @@ def test_mongo_client(self): self.assertTrue(new_coll.insert_one(doc)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client( - "mongodb://%s/" % (pair,), replicaSet=client_context.replica_set_name - ) + m = rs_or_single_client(f"mongodb://{pair}/", replicaSet=client_context.replica_set_name) coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert_one, doc) m = rs_or_single_client( - "mongodb://%s/?w=0" % (pair,), replicaSet=client_context.replica_set_name + f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name ) coll = m.pymongo_test.write_concern_test @@ -163,7 +161,7 @@ def test_mongo_client(self): # Equality tests direct = connected(single_client(w=0)) - direct2 = connected(single_client("mongodb://%s/?w=0" % (pair,), **self.credentials)) + direct2 = connected(single_client(f"mongodb://{pair}/?w=0", **self.credentials)) self.assertEqual(direct, direct2) self.assertFalse(direct != direct2) diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index fd9f126551..e09ba72a5c 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -40,7 +40,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): @classmethod @client_context.require_replica_set def setUpClass(cls): - super(TestConnectionsSurvivePrimaryStepDown, cls).setUpClass() + super().setUpClass() cls.listener = CMAPListener() cls.client = rs_or_single_client( event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 diff --git a/test/test_crud_v1.py b/test/test_crud_v1.py index ca4b84c26d..589da0a7d7 100644 --- a/test/test_crud_v1.py +++ b/test/test_crud_v1.py @@ -55,7 +55,7 @@ def check_result(self, expected_result, result): if isinstance(result, _WriteResult): for res in expected_result: prop = camel_to_snake(res) - msg = "%s : %r != %r" % (prop, expected_result, result) + msg = f"{prop} : {expected_result!r} != {result!r}" # SPEC-869: Only BulkWriteResult has upserted_count. if prop == "upserted_count" and not isinstance(result, BulkWriteResult): if result.upserted_id is not None: # type: ignore diff --git a/test/test_cursor.py b/test/test_cursor.py index e96efb92b0..f8820f8aa2 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -945,7 +945,7 @@ def test_getitem_slice_index(self): for a, b in zip(count(99), self.db.test.find()[99:]): self.assertEqual(a, b["i"]) - for i in self.db.test.find()[1000:]: + for _i in self.db.test.find()[1000:]: self.fail() self.assertEqual(5, len(list(self.db.test.find()[20:25]))) @@ -1079,7 +1079,7 @@ def test_concurrent_close(self): def iterate_cursor(): while cursor.alive: - for doc in cursor: + for _doc in cursor: pass t = threading.Thread(target=iterate_cursor) @@ -1430,7 +1430,7 @@ def test_monitoring(self): class TestRawBatchCommandCursor(IntegrationTest): @classmethod def setUpClass(cls): - super(TestRawBatchCommandCursor, cls).setUpClass() + super().setUpClass() def test_aggregate_raw(self): c = self.db.test diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 676b3b6af0..14d7b4b05d 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -81,7 +81,7 @@ class DecimalCodec(DecimalDecoder, DecimalEncoder): DECIMAL_CODECOPTS = CodecOptions(type_registry=TypeRegistry([DecimalCodec()])) -class UndecipherableInt64Type(object): +class UndecipherableInt64Type: def __init__(self, value): self.value = value @@ -146,7 +146,7 @@ def transform_bson(self, value): return ResumeTokenToNanDecoder -class CustomBSONTypeTests(object): +class CustomBSONTypeTests: @no_type_check def roundtrip(self, doc): bsonbytes = encode(doc, codec_options=self.codecopts) @@ -164,9 +164,9 @@ def test_encode_decode_roundtrip(self): def test_decode_all(self): documents = [] for dec in range(3): - documents.append({"average": Decimal("56.4%s" % (dec,))}) + documents.append({"average": Decimal(f"56.4{dec}")}) - bsonstream = bytes() + bsonstream = b"" for doc in documents: bsonstream += encode(doc, codec_options=self.codecopts) @@ -287,7 +287,7 @@ def run_test(base, attrs, fail): else: codec() - class MyType(object): + class MyType: pass run_test( @@ -350,11 +350,11 @@ class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): @classmethod def setUpClass(cls): - class TypeA(object): + class TypeA: def __init__(self, x): self.value = x - class TypeB(object): + class TypeB: def __init__(self, x): self.value = x @@ -442,12 +442,12 @@ class TestTypeRegistry(unittest.TestCase): @classmethod def setUpClass(cls): - class MyIntType(object): + class MyIntType: def __init__(self, x): assert isinstance(x, int) self.x = x - class MyStrType(object): + class MyStrType: def __init__(self, x): assert isinstance(x, str) self.x = x @@ -553,18 +553,18 @@ def test_initialize_fail(self): with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry([type("AnyType", (object,), {})()]) - err_msg = "fallback_encoder %r is not a callable" % (True,) + err_msg = f"fallback_encoder {True!r} is not a callable" with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry([], True) # type: ignore[arg-type] - err_msg = "fallback_encoder %r is not a callable" % ("hello",) + err_msg = "fallback_encoder {!r} is not a callable".format("hello") with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry(fallback_encoder="hello") # type: ignore[arg-type] def test_type_registry_repr(self): codec_instances = [codec() for codec in self.codecs] type_registry = TypeRegistry(codec_instances) - r = "TypeRegistry(type_codecs=%r, fallback_encoder=%r)" % (codec_instances, None) + r = f"TypeRegistry(type_codecs={codec_instances!r}, fallback_encoder={None!r})" self.assertEqual(r, repr(type_registry)) def test_type_registry_eq(self): @@ -777,7 +777,7 @@ def test_grid_out_custom_opts(self): self.assertRaises(AttributeError, setattr, two, attr, 5) -class ChangeStreamsWCustomTypesTestMixin(object): +class ChangeStreamsWCustomTypesTestMixin: @no_type_check def change_stream(self, *args, **kwargs): return self.watched_target.watch(*args, **kwargs) @@ -899,7 +899,7 @@ class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCus @classmethod @client_context.require_change_streams def setUpClass(cls): - super(TestCollectionChangeStreamsWCustomTypes, cls).setUpClass() + super().setUpClass() cls.db.test.delete_many({}) def tearDown(self): @@ -918,7 +918,7 @@ class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCusto @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams def setUpClass(cls): - super(TestDatabaseChangeStreamsWCustomTypes, cls).setUpClass() + super().setUpClass() cls.db.test.delete_many({}) def tearDown(self): @@ -937,7 +937,7 @@ class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustom @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams def setUpClass(cls): - super(TestClusterChangeStreamsWCustomTypes, cls).setUpClass() + super().setUpClass() cls.db.test.delete_many({}) def tearDown(self): diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 4fa38435a3..ce210010bd 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -52,7 +52,7 @@ class TestDataLakeProse(IntegrationTest): @classmethod @client_context.require_data_lake def setUpClass(cls): - super(TestDataLakeProse, cls).setUpClass() + super().setUpClass() # Test killCursors def test_1(self): @@ -100,7 +100,7 @@ class DataLakeTestSpec(TestCrudV2): @classmethod @client_context.require_data_lake def setUpClass(cls): - super(DataLakeTestSpec, cls).setUpClass() + super().setUpClass() def setup_scenario(self, scenario_def): # Spec tests MUST NOT insert data/drop collection for diff --git a/test/test_database.py b/test/test_database.py index b6be380aab..140d169db3 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -137,7 +137,7 @@ def test_get_coll(self): def test_repr(self): self.assertEqual( repr(Database(self.client, "pymongo_test")), - "Database(%r, %s)" % (self.client, repr("pymongo_test")), + "Database({!r}, {})".format(self.client, repr("pymongo_test")), ) def test_create_collection(self): @@ -262,8 +262,8 @@ def test_list_collections(self): # Checking if is there any collection which don't exists. if ( - len(set(colls) - set(["test", "test.mike"])) == 0 - or len(set(colls) - set(["test", "test.mike", "system.indexes"])) == 0 + len(set(colls) - {"test", "test.mike"}) == 0 + or len(set(colls) - {"test", "test.mike", "system.indexes"}) == 0 ): self.assertTrue(True) else: @@ -301,10 +301,7 @@ def test_list_collections(self): coll_cnt = {} # Checking if is there any collection which don't exists. - if ( - len(set(colls) - set(["test"])) == 0 - or len(set(colls) - set(["test", "system.indexes"])) == 0 - ): + if len(set(colls) - {"test"}) == 0 or len(set(colls) - {"test", "system.indexes"}) == 0: self.assertTrue(True) else: self.assertTrue(False) @@ -439,7 +436,7 @@ def test_id_ordering(self): ) cursor = db.test.find() for x in cursor: - for (k, v) in x.items(): + for (k, _v) in x.items(): self.assertEqual(k, "_id") break diff --git a/test/test_dbref.py b/test/test_dbref.py index 281aef473f..107d95d230 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -64,7 +64,7 @@ def test_repr(self): ) self.assertEqual( repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))), - "DBRef(%s, ObjectId('1234567890abcdef12345678'))" % (repr("coll"),), + "DBRef({}, ObjectId('1234567890abcdef12345678'))".format(repr("coll")), ) self.assertEqual(repr(DBRef("coll", 5, foo="bar")), "DBRef('coll', 5, foo='bar')") self.assertEqual( diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 9af8185ab5..8a14ecfb2a 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -104,15 +104,15 @@ def got_app_error(topology, app_error): elif error_type == "timeout": raise NetworkTimeout("mock network timeout error") else: - raise AssertionError("unknown error type: %s" % (error_type,)) - assert False + raise AssertionError(f"unknown error type: {error_type}") + raise AssertionError except (AutoReconnect, NotPrimaryError, OperationFailure) as e: if when == "beforeHandshakeCompletes": completed_handshake = False elif when == "afterHandshakeCompletes": completed_handshake = True else: - assert False, "Unknown when field %s" % (when,) + raise AssertionError(f"Unknown when field {when}") topology.handle_error( server_address, @@ -201,7 +201,7 @@ def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) - with assertion_context("phase: %s" % (description,)): + with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): got_hello(c, common.partition_node(response[0]), response[1]) @@ -228,7 +228,7 @@ def create_tests(): # Construct test from scenario. new_test = create_test(scenario_def) - test_name = "test_%s_%s" % (dirname, os.path.splitext(filename)[0]) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) diff --git a/test/test_encryption.py b/test/test_encryption.py index af8f54cd07..95f18eb307 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -207,7 +207,7 @@ class EncryptionIntegrationTest(IntegrationTest): @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) def setUpClass(cls): - super(EncryptionIntegrationTest, cls).setUpClass() + super().setUpClass() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -295,7 +295,7 @@ def _test_auto_encrypt(self, opts): # Collection.distinct auto decrypts. decrypted_ssns = encrypted_coll.distinct("ssn") - self.assertEqual(set(decrypted_ssns), set(d["ssn"] for d in docs)) + self.assertEqual(set(decrypted_ssns), {d["ssn"] for d in docs}) # Make sure the field is actually encrypted. for encrypted_doc in self.db.test.find(): @@ -391,7 +391,7 @@ class TestClientMaxWireVersion(IntegrationTest): @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def setUpClass(cls): - super(TestClientMaxWireVersion, cls).setUpClass() + super().setUpClass() @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): @@ -585,7 +585,7 @@ class TestSpec(SpecRunner): @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def setUpClass(cls): - super(TestSpec, cls).setUpClass() + super().setUpClass() def parse_auto_encrypt_opts(self, opts): """Parse clientOptions.autoEncryptOpts.""" @@ -630,14 +630,14 @@ def parse_client_options(self, opts): if encrypt_opts: opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - return super(TestSpec, self).parse_client_options(opts) + return super().parse_client_options(opts) def get_object_name(self, op): """Default object is collection.""" return op.get("object", "collection") def maybe_skip_scenario(self, test): - super(TestSpec, self).maybe_skip_scenario(test) + super().maybe_skip_scenario(test) desc = test["description"].lower() if "type=symbol" in desc: self.skipTest("PyMongo does not support the symbol type") @@ -674,7 +674,7 @@ def setup_scenario(self, scenario_def): def allowable_errors(self, op): """Override expected error classes.""" - errors = super(TestSpec, self).allowable_errors(op) + errors = super().allowable_errors(op) # An updateOne test expects encryption to error when no $ operator # appears but pymongo raises a client side ValueError in this case. if op["name"] == "updateOne": @@ -773,7 +773,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): "No environment credentials are set", ) def setUpClass(cls): - super(TestDataKeyDoubleEncryption, cls).setUpClass() + super().setUpClass() cls.listener = OvertCommandListener() cls.client = rs_or_single_client(event_listeners=[cls.listener]) cls.client.db.coll.drop() @@ -818,7 +818,7 @@ def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] datakey_id = self.client_encryption.create_data_key( - provider_name, master_key=master_key, key_alt_names=["%s_altname" % (provider_name,)] + provider_name, master_key=master_key, key_alt_names=[f"{provider_name}_altname"] ) self.assertBinaryUUID(datakey_id) cmd = self.listener.started_events[-1] @@ -830,20 +830,20 @@ def run_test(self, provider_name): # Encrypt by key_id. encrypted = self.client_encryption.encrypt( - "hello %s" % (provider_name,), + f"hello {provider_name}", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=datakey_id, ) self.assertEncrypted(encrypted) self.client_encrypted.db.coll.insert_one({"_id": provider_name, "value": encrypted}) doc_decrypted = self.client_encrypted.db.coll.find_one({"_id": provider_name}) - self.assertEqual(doc_decrypted["value"], "hello %s" % (provider_name,)) # type: ignore + self.assertEqual(doc_decrypted["value"], f"hello {provider_name}") # type: ignore # Encrypt by key_alt_name. encrypted_altname = self.client_encryption.encrypt( - "hello %s" % (provider_name,), + f"hello {provider_name}", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_alt_name="%s_altname" % (provider_name,), + key_alt_name=f"{provider_name}_altname", ) self.assertEqual(encrypted_altname, encrypted) @@ -965,7 +965,7 @@ class TestCorpus(EncryptionIntegrationTest): @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def setUpClass(cls): - super(TestCorpus, cls).setUpClass() + super().setUpClass() @staticmethod def kms_providers(): @@ -1046,17 +1046,17 @@ def _test_corpus(self, opts): self.assertIn(kms, ("local", "aws", "azure", "gcp", "kmip")) if identifier == "id": if kms == "local": - kwargs = dict(key_id=LOCAL_KEY_ID) + kwargs = {"key_id": LOCAL_KEY_ID} elif kms == "aws": - kwargs = dict(key_id=AWS_KEY_ID) + kwargs = {"key_id": AWS_KEY_ID} elif kms == "azure": - kwargs = dict(key_id=AZURE_KEY_ID) + kwargs = {"key_id": AZURE_KEY_ID} elif kms == "gcp": - kwargs = dict(key_id=GCP_KEY_ID) + kwargs = {"key_id": GCP_KEY_ID} else: - kwargs = dict(key_id=KMIP_KEY_ID) + kwargs = {"key_id": KMIP_KEY_ID} else: - kwargs = dict(key_alt_name=kms) + kwargs = {"key_alt_name": kms} self.assertIn(value["algo"], ("det", "rand")) if value["algo"] == "det": @@ -1069,12 +1069,12 @@ def _test_corpus(self, opts): value["value"], algo, **kwargs # type: ignore[arg-type] ) if not value["allowed"]: - self.fail("encrypt should have failed: %r: %r" % (key, value)) + self.fail(f"encrypt should have failed: {key!r}: {value!r}") corpus_copied[key]["value"] = encrypted_val except Exception: if value["allowed"]: tb = traceback.format_exc() - self.fail("encrypt failed: %r: %r, traceback: %s" % (key, value, tb)) + self.fail(f"encrypt failed: {key!r}: {value!r}, traceback: {tb}") client_encrypted.db.coll.insert_one(corpus_copied) corpus_decrypted = client_encrypted.db.coll.find_one() @@ -1141,7 +1141,7 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): @classmethod def setUpClass(cls): - super(TestBsonSizeBatches, cls).setUpClass() + super().setUpClass() db = client_context.client.db cls.coll = db.coll cls.coll.drop() @@ -1172,7 +1172,7 @@ def setUpClass(cls): def tearDownClass(cls): cls.coll_encrypted.drop() cls.client_encrypted.close() - super(TestBsonSizeBatches, cls).tearDownClass() + super().tearDownClass() def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1242,7 +1242,7 @@ class TestCustomEndpoint(EncryptionIntegrationTest): "No environment credentials are set", ) def setUpClass(cls): - super(TestCustomEndpoint, cls).setUpClass() + super().setUpClass() def setUp(self): kms_providers = { @@ -1442,7 +1442,7 @@ def test_12_kmip_master_key_invalid_endpoint(self): self.client_encryption.create_data_key("kmip", key) -class AzureGCPEncryptionTestMixin(object): +class AzureGCPEncryptionTestMixin: DEK = None KMS_PROVIDER_MAP = None KEYVAULT_DB = "keyvault" @@ -1514,7 +1514,7 @@ def setUpClass(cls): cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} cls.DEK = json_data(BASE, "custom", "azure-dek.json") cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super(TestAzureEncryption, cls).setUpClass() + super().setUpClass() def test_explicit(self): return self._test_explicit( @@ -1540,7 +1540,7 @@ def setUpClass(cls): cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} cls.DEK = json_data(BASE, "custom", "gcp-dek.json") cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super(TestGCPEncryption, cls).setUpClass() + super().setUpClass() def test_explicit(self): return self._test_explicit( @@ -1985,7 +1985,7 @@ def listener(): class TestKmsTLSProse(EncryptionIntegrationTest): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def setUp(self): - super(TestKmsTLSProse, self).setUp() + super().setUp() self.patch_system_certs(CA_PEM) self.client_encrypted = ClientEncryption( {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS @@ -2023,7 +2023,7 @@ def test_invalid_hostname_in_kms_certificate(self): class TestKmsTLSOptions(EncryptionIntegrationTest): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def setUp(self): - super(TestKmsTLSOptions, self).setUp() + super().setUp() # 1, create client with only tlsCAFile. providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:8002" @@ -2391,7 +2391,7 @@ def run_test(self, src_provider, dst_provider): # https://github.com/mongodb/specifications/blob/5cf3ed/source/client-side-encryption/tests/README.rst#on-demand-aws-credentials class TestOnDemandAWSCredentials(EncryptionIntegrationTest): def setUp(self): - super(TestOnDemandAWSCredentials, self).setUp() + super().setUp() self.master_key = { "region": "us-east-1", "key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), diff --git a/test/test_examples.py b/test/test_examples.py index c08cb17e20..b9508d4f1e 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -34,7 +34,7 @@ class TestSampleShellCommands(IntegrationTest): @classmethod def setUpClass(cls): - super(TestSampleShellCommands, cls).setUpClass() + super().setUpClass() # Run once before any tests run. cls.db.inventory.drop() @@ -757,18 +757,18 @@ def insert_docs(): # 1. The database for reactive, real-time applications # Start Changestream Example 1 cursor = db.inventory.watch() - document = next(cursor) + next(cursor) # End Changestream Example 1 # Start Changestream Example 2 cursor = db.inventory.watch(full_document="updateLookup") - document = next(cursor) + next(cursor) # End Changestream Example 2 # Start Changestream Example 3 resume_token = cursor.resume_token cursor = db.inventory.watch(resume_after=resume_token) - document = next(cursor) + next(cursor) # End Changestream Example 3 # Start Changestream Example 4 @@ -777,7 +777,7 @@ def insert_docs(): {"$addFields": {"newField": "this is an added field!"}}, ] cursor = db.inventory.watch(pipeline=pipeline) - document = next(cursor) + next(cursor) # End Changestream Example 4 finally: done = True @@ -898,7 +898,7 @@ def test_misc(self): with client.start_session() as session: collection.insert_one({"_id": 1}, session=session) collection.update_one({"_id": 1}, {"$set": {"a": 1}}, session=session) - for doc in collection.find({}, session=session): + for _doc in collection.find({}, session=session): pass # 3. Exploiting the power of arrays @@ -1078,7 +1078,7 @@ def update_employee_info(session): with client.start_session() as session: try: run_transaction_with_retry(update_employee_info, session) - except Exception as exc: + except Exception: # Do something with error. raise @@ -1089,7 +1089,9 @@ def update_employee_info(session): self.assertIsNotNone(employee) self.assertEqual(employee["status"], "Inactive") - MongoClient = lambda _: rs_client() + def MongoClient(_): + return rs_client() + uriString = None # Start Transactions withTxn API Example 1 @@ -1179,25 +1181,27 @@ class TestVersionedApiExamples(IntegrationTest): @client_context.require_version_min(4, 7) def test_versioned_api(self): # Versioned API examples - MongoClient = lambda _, server_api: rs_client(server_api=server_api, connect=False) + def MongoClient(_, server_api): + return rs_client(server_api=server_api, connect=False) + uri = None # Start Versioned API Example 1 from pymongo.server_api import ServerApi - client = MongoClient(uri, server_api=ServerApi("1")) + MongoClient(uri, server_api=ServerApi("1")) # End Versioned API Example 1 # Start Versioned API Example 2 - client = MongoClient(uri, server_api=ServerApi("1", strict=True)) + MongoClient(uri, server_api=ServerApi("1", strict=True)) # End Versioned API Example 2 # Start Versioned API Example 3 - client = MongoClient(uri, server_api=ServerApi("1", strict=False)) + MongoClient(uri, server_api=ServerApi("1", strict=False)) # End Versioned API Example 3 # Start Versioned API Example 4 - client = MongoClient(uri, server_api=ServerApi("1", deprecation_errors=True)) + MongoClient(uri, server_api=ServerApi("1", deprecation_errors=True)) # End Versioned API Example 4 @unittest.skip("PYTHON-3167 count has been added to API version 1") @@ -1339,7 +1343,7 @@ def test_snapshot_query(self): # Start Snapshot Query Example 2 db = client.retail with client.start_session(snapshot=True) as s: - total = db.sales.aggregate( + db.sales.aggregate( [ { "$match": { diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 8b46133a60..04003289e6 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright 2009-present MongoDB, Inc. # @@ -14,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the grid_file module. -""" +"""Tests for the grid_file module.""" import datetime import io @@ -462,12 +460,10 @@ def test_multiple_reads(self): def test_readline(self): f = GridIn(self.db.fs, chunkSize=5) f.write( - ( - b"""Hello world, + b"""Hello world, How are you? Hope all is well. Bye""" - ) ) f.close() @@ -498,12 +494,10 @@ def test_readline(self): def test_readlines(self): f = GridIn(self.db.fs, chunkSize=5) f.write( - ( - b"""Hello world, + b"""Hello world, How are you? Hope all is well. Bye""" - ) ) f.close() diff --git a/test/test_gridfs.py b/test/test_gridfs.py index cfa6e43e85..4ba8467d22 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright 2009-present MongoDB, Inc. # @@ -14,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the gridfs package. -""" +"""Tests for the gridfs package.""" import datetime import sys @@ -90,7 +88,7 @@ class TestGridfs(IntegrationTest): @classmethod def setUpClass(cls): - super(TestGridfs, cls).setUpClass() + super().setUpClass() cls.fs = gridfs.GridFS(cls.db) cls.alt = gridfs.GridFS(cls.db, "alt") @@ -141,7 +139,7 @@ def test_list(self): self.fs.put(b"foo", filename="test") self.fs.put(b"", filename="hello world") - self.assertEqual(set(["mike", "test", "hello world"]), set(self.fs.list())) + self.assertEqual({"mike", "test", "hello world"}, set(self.fs.list())) def test_empty_file(self): oid = self.fs.put(b"") @@ -210,7 +208,7 @@ def test_alt_collection(self): self.alt.put(b"foo", filename="test") self.alt.put(b"", filename="hello world") - self.assertEqual(set(["mike", "test", "hello world"]), set(self.alt.list())) + self.assertEqual({"mike", "test", "hello world"}, set(self.alt.list())) def test_threaded_reads(self): self.fs.put(b"hello", _id="test") @@ -394,7 +392,7 @@ def test_missing_length_iter(self): f = self.fs.get_last_version(filename="empty") def iterate_file(grid_file): - for chunk in grid_file: + for _chunk in grid_file: pass return True @@ -496,7 +494,7 @@ class TestGridfsReplicaSet(IntegrationTest): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): - super(TestGridfsReplicaSet, cls).setUpClass() + super().setUpClass() @classmethod def tearDownClass(cls): diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index b6a33b4ecc..e5695f2c38 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Copyright 2015-present MongoDB, Inc. # @@ -14,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the gridfs package. -""" +"""Tests for the gridfs package.""" import datetime import itertools import threading @@ -75,7 +73,7 @@ class TestGridfs(IntegrationTest): @classmethod def setUpClass(cls): - super(TestGridfs, cls).setUpClass() + super().setUpClass() cls.fs = gridfs.GridFSBucket(cls.db) cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt") @@ -196,8 +194,8 @@ def test_alt_collection(self): self.alt.upload_from_stream("hello world", b"") self.assertEqual( - set(["mike", "test", "hello world", "foo"]), - set(k["filename"] for k in list(self.db.alt.files.find())), + {"mike", "test", "hello world", "foo"}, + {k["filename"] for k in list(self.db.alt.files.find())}, ) def test_threaded_reads(self): @@ -442,7 +440,7 @@ class TestGridfsBucketReplicaSet(IntegrationTest): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): - super(TestGridfsBucketReplicaSet, cls).setUpClass() + super().setUpClass() @classmethod def tearDownClass(cls): diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index d4de8debf5..df68b3e626 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -150,7 +150,7 @@ def test_session_gc(self): class PoolLocker(ExceptionCatchingThread): def __init__(self, pool): - super(PoolLocker, self).__init__(target=self.lock_pool) + super().__init__(target=self.lock_pool) self.pool = pool self.daemon = True self.locked = threading.Event() diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index e39940f56b..9e83e879a5 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -36,7 +36,7 @@ def setUpModule(): class SimpleOp(threading.Thread): def __init__(self, client): - super(SimpleOp, self).__init__() + super().__init__() self.client = client self.passed = False @@ -58,9 +58,9 @@ def do_simple_op(client, nthreads): def writable_addresses(topology): - return set( + return { server.description.address for server in topology.select_servers(writable_server_selector) - ) + } class TestMongosLoadBalancing(MockClientTest): @@ -133,7 +133,7 @@ def test_local_threshold(self): topology = client._topology # All are within a 30-ms latency window, see self.mock_client(). - self.assertEqual(set([("a", 1), ("b", 2), ("c", 3)]), writable_addresses(topology)) + self.assertEqual({("a", 1), ("b", 2), ("c", 3)}, writable_addresses(topology)) # No error client.admin.command("ping") @@ -143,7 +143,7 @@ def test_local_threshold(self): # No error client.db.command("ping") # Our chosen mongos goes down. - client.kill_host("%s:%s" % next(iter(client.nodes))) + client.kill_host("{}:{}".format(*next(iter(client.nodes)))) try: client.db.command("ping") except: @@ -174,13 +174,13 @@ def test_load_balancing(self): self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type) # a and b are within the 15-ms latency window, see self.mock_client(). - self.assertEqual(set([("a", 1), ("b", 2)]), writable_addresses(topology)) + self.assertEqual({("a", 1), ("b", 2)}, writable_addresses(topology)) client.mock_rtts["a:1"] = 0.045 # Discover only b is within latency window. wait_until( - lambda: set([("b", 2)]) == writable_addresses(topology), + lambda: {("b", 2)} == writable_addresses(topology), 'discover server "a" is too far', ) diff --git a/test/test_monitor.py b/test/test_monitor.py index 85cfb0bc40..9ee3c52ff5 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -66,7 +66,7 @@ def test_cleanup_executors_on_client_del(self): del client for ref, name in executor_refs: - wait_until(partial(unregistered, ref), "unregister executor: %s" % (name,), timeout=5) + wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) def test_cleanup_executors_on_client_close(self): client = create_client() @@ -76,9 +76,7 @@ def test_cleanup_executors_on_client_close(self): client.close() for executor in executors: - wait_until( - lambda: executor._stopped, "closed executor: %s" % (executor._name,), timeout=5 - ) + wait_until(lambda: executor._stopped, f"closed executor: {executor._name}", timeout=5) if __name__ == "__main__": diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 39b3d2f896..c7c793b382 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -39,18 +39,18 @@ class TestCommandMonitoring(IntegrationTest): @classmethod @client_context.require_connection def setUpClass(cls): - super(TestCommandMonitoring, cls).setUpClass() + super().setUpClass() cls.listener = EventListener() cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False) @classmethod def tearDownClass(cls): cls.client.close() - super(TestCommandMonitoring, cls).tearDownClass() + super().tearDownClass() def tearDown(self): self.listener.reset() - super(TestCommandMonitoring, self).tearDown() + super().tearDown() def test_started_simple(self): self.client.pymongo_test.command("ping") @@ -232,40 +232,40 @@ def _test_find_options(self, query, expected_cmd): tuple(cursor) def test_find_options(self): - query = dict( - filter={}, - hint=[("x", 1)], - max_time_ms=10000, - max={"x": 10}, - min={"x": -10}, - return_key=True, - show_record_id=True, - projection={"x": False}, - skip=1, - no_cursor_timeout=True, - sort=[("_id", 1)], - allow_partial_results=True, - comment="this is a test", - batch_size=2, - ) + query = { + "filter": {}, + "hint": [("x", 1)], + "max_time_ms": 10000, + "max": {"x": 10}, + "min": {"x": -10}, + "return_key": True, + "show_record_id": True, + "projection": {"x": False}, + "skip": 1, + "no_cursor_timeout": True, + "sort": [("_id", 1)], + "allow_partial_results": True, + "comment": "this is a test", + "batch_size": 2, + } - cmd = dict( - find="test", - filter={}, - hint=SON([("x", 1)]), - comment="this is a test", - maxTimeMS=10000, - max={"x": 10}, - min={"x": -10}, - returnKey=True, - showRecordId=True, - sort=SON([("_id", 1)]), - projection={"x": False}, - skip=1, - batchSize=2, - noCursorTimeout=True, - allowPartialResults=True, - ) + cmd = { + "find": "test", + "filter": {}, + "hint": SON([("x", 1)]), + "comment": "this is a test", + "maxTimeMS": 10000, + "max": {"x": 10}, + "min": {"x": -10}, + "returnKey": True, + "showRecordId": True, + "sort": SON([("_id", 1)]), + "projection": {"x": False}, + "skip": 1, + "batchSize": 2, + "noCursorTimeout": True, + "allowPartialResults": True, + } if client_context.version < (4, 1, 0, -1): query["max_scan"] = 10 @@ -276,9 +276,9 @@ def test_find_options(self): @client_context.require_version_max(3, 7, 2) def test_find_snapshot(self): # Test "snapshot" parameter separately, can't combine with "sort". - query = dict(filter={}, snapshot=True) + query = {"filter": {}, "snapshot": True} - cmd = dict(find="test", filter={}, snapshot=True) + cmd = {"find": "test", "filter": {}, "snapshot": True} self._test_find_options(query, cmd) @@ -1049,7 +1049,7 @@ def test_write_errors(self): errors.extend(succeed.reply["writeErrors"]) self.assertEqual(2, len(errors)) - fields = set(["index", "code", "errmsg"]) + fields = {"index", "code", "errmsg"} for error in errors: self.assertTrue(fields.issubset(set(error))) @@ -1113,7 +1113,7 @@ class TestGlobalListener(IntegrationTest): @classmethod @client_context.require_connection def setUpClass(cls): - super(TestGlobalListener, cls).setUpClass() + super().setUpClass() cls.listener = EventListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) @@ -1126,10 +1126,10 @@ def setUpClass(cls): def tearDownClass(cls): monitoring._LISTENERS = cls.saved_listeners cls.client.close() - super(TestGlobalListener, cls).tearDownClass() + super().tearDownClass() def setUp(self): - super(TestGlobalListener, self).setUp() + super().setUp() self.listener.reset() def test_simple(self): diff --git a/test/test_on_demand_csfle.py b/test/test_on_demand_csfle.py index d5668199a3..499dc64b3b 100644 --- a/test/test_on_demand_csfle.py +++ b/test/test_on_demand_csfle.py @@ -30,10 +30,10 @@ class TestonDemandGCPCredentials(IntegrationTest): @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) def setUpClass(cls): - super(TestonDemandGCPCredentials, cls).setUpClass() + super().setUpClass() def setUp(self): - super(TestonDemandGCPCredentials, self).setUp() + super().setUp() self.master_key = { "projectId": "devprod-drivers", "location": "global", @@ -72,10 +72,10 @@ class TestonDemandAzureCredentials(IntegrationTest): @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) def setUpClass(cls): - super(TestonDemandAzureCredentials, cls).setUpClass() + super().setUpClass() def setUp(self): - super(TestonDemandAzureCredentials, self).setUp() + super().setUp() self.master_key = { "keyVaultEndpoint": "https://keyvault-drivers-2411.vault.azure.net/keys/", "keyName": "KEY-NAME", diff --git a/test/test_pooling.py b/test/test_pooling.py index 923c89d83b..57c9b807a6 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -60,7 +60,7 @@ class MongoThread(threading.Thread): """A thread that uses a MongoClient.""" def __init__(self, client): - super(MongoThread, self).__init__() + super().__init__() self.daemon = True # Don't hang whole test if thread hangs. self.client = client self.db = self.client[DB] @@ -107,7 +107,7 @@ class SocketGetter(MongoThread): """ def __init__(self, client, pool): - super(SocketGetter, self).__init__(client) + super().__init__(client) self.state = "init" self.pool = pool self.sock = None @@ -132,7 +132,7 @@ def run_cases(client, cases): n_runs = 5 for case in cases: - for i in range(n_runs): + for _i in range(n_runs): t = case(client) t.start() threads.append(t) @@ -148,7 +148,7 @@ class _TestPoolingBase(IntegrationTest): """Base class for all connection-pool tests.""" def setUp(self): - super(_TestPoolingBase, self).setUp() + super().setUp() self.c = rs_or_single_client() db = self.c[DB] db.unique.drop() @@ -158,7 +158,7 @@ def setUp(self): def tearDown(self): self.c.close() - super(_TestPoolingBase, self).tearDown() + super().tearDown() def create_pool(self, pair=(client_context.host, client_context.port), *args, **kwargs): # Start the pool with the correct ssl options. @@ -329,7 +329,7 @@ def test_wait_queue_timeout(self): duration = time.time() - start self.assertTrue( abs(wait_queue_timeout - duration) < 1, - "Waited %.2f seconds for a socket, expected %f" % (duration, wait_queue_timeout), + f"Waited {duration:.2f} seconds for a socket, expected {wait_queue_timeout:f}", ) def test_no_wait_queue_timeout(self): @@ -440,7 +440,7 @@ def f(): with lock: self.n_passed += 1 - for i in range(nthreads): + for _i in range(nthreads): t = threading.Thread(target=f) threads.append(t) t.start() @@ -472,7 +472,7 @@ def f(): with lock: self.n_passed += 1 - for i in range(nthreads): + for _i in range(nthreads): t = threading.Thread(target=f) threads.append(t) t.start() @@ -500,7 +500,7 @@ def test_max_pool_size_with_connection_failure(self): # First call to get_socket fails; if pool doesn't release its semaphore # then the second call raises "ConnectionFailure: Timed out waiting for # socket from pool" instead of AutoReconnect. - for i in range(2): + for _i in range(2): with self.assertRaises(AutoReconnect) as context: with test_pool.get_socket(): pass diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 2230f2bef2..682fe03e72 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -33,7 +33,7 @@ class TestReadConcern(IntegrationTest): @classmethod @client_context.require_connection def setUpClass(cls): - super(TestReadConcern, cls).setUpClass() + super().setUpClass() cls.listener = OvertCommandListener() cls.client = rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test @@ -43,11 +43,11 @@ def setUpClass(cls): def tearDownClass(cls): cls.client.close() client_context.client.pymongo_test.drop_collection("coll") - super(TestReadConcern, cls).tearDownClass() + super().tearDownClass() def tearDown(self): self.listener.reset() - super(TestReadConcern, self).tearDown() + super().tearDown() def test_read_concern(self): rc = ReadConcern() @@ -65,7 +65,7 @@ def test_read_concern(self): self.assertRaises(TypeError, ReadConcern, 42) def test_read_concern_uri(self): - uri = "mongodb://%s/?readConcernLevel=majority" % (client_context.pair,) + uri = f"mongodb://{client_context.pair}/?readConcernLevel=majority" client = rs_or_single_client(uri, connect=False) self.assertEqual(ReadConcern("majority"), client.read_concern) diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 1362623dff..6156b6b3fc 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -90,10 +90,10 @@ class TestReadPreferencesBase(IntegrationTest): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): - super(TestReadPreferencesBase, cls).setUpClass() + super().setUpClass() def setUp(self): - super(TestReadPreferencesBase, self).setUp() + super().setUp() # Insert some data so we can use cursors in read_from_which_host self.client.pymongo_test.test.drop() self.client.get_database( @@ -119,16 +119,17 @@ def read_from_which_kind(self, client): return "secondary" else: self.fail( - "Cursor used address %s, expected either primary " - "%s or secondaries %s" % (address, client.primary, client.secondaries) + "Cursor used address {}, expected either primary " + "{} or secondaries {}".format(address, client.primary, client.secondaries) ) + return None def assertReadsFrom(self, expected, **kwargs): c = rs_client(**kwargs) wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes") used = self.read_from_which_kind(c) - self.assertEqual(expected, used, "Cursor used %s, expected %s" % (used, expected)) + self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}") class TestSingleSecondaryOk(TestReadPreferencesBase): @@ -271,7 +272,7 @@ def test_nearest(self): self.assertFalse( not_used, "Expected to use primary and all secondaries for mode NEAREST," - " but didn't use %s\nlatencies: %s" % (not_used, latencies), + " but didn't use {}\nlatencies: {}".format(not_used, latencies), ) @@ -280,18 +281,18 @@ def __init__(self, *args, **kwargs): self.has_read_from = set() client_options = client_context.client_options client_options.update(kwargs) - super(ReadPrefTester, self).__init__(*args, **client_options) + super().__init__(*args, **client_options) @contextlib.contextmanager def _socket_for_reads(self, read_preference, session): - context = super(ReadPrefTester, self)._socket_for_reads(read_preference, session) + context = super()._socket_for_reads(read_preference, session) with context as (sock_info, read_preference): self.record_a_read(sock_info.address) yield sock_info, read_preference @contextlib.contextmanager def _socket_from_server(self, read_preference, server, session): - context = super(ReadPrefTester, self)._socket_from_server(read_preference, server, session) + context = super()._socket_from_server(read_preference, server, session) with context as (sock_info, read_preference): self.record_a_read(sock_info.address) yield sock_info, read_preference @@ -317,7 +318,7 @@ class TestCommandAndReadPreference(IntegrationTest): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): - super(TestCommandAndReadPreference, cls).setUpClass() + super().setUpClass() cls.c = ReadPrefTester( client_context.pair, # Ignore round trip times, to test ReadPreference modes only. @@ -360,7 +361,7 @@ def _test_fn(self, server_type, fn): break assert self.c.primary is not None - unused = self.c.secondaries.union(set([self.c.primary])).difference(used) + unused = self.c.secondaries.union({self.c.primary}).difference(used) if unused: self.fail("Some members not used for NEAREST: %s" % (unused)) else: @@ -373,7 +374,10 @@ def _test_primary_helper(self, func): def _test_coll_helper(self, secondary_ok, coll, meth, *args, **kwargs): for mode, server_type in _PREF_MAP: new_coll = coll.with_options(read_preference=mode()) - func = lambda: getattr(new_coll, meth)(*args, **kwargs) + + def func(): + return getattr(new_coll, meth)(*args, **kwargs) + if secondary_ok: self._test_fn(server_type, func) else: @@ -383,7 +387,10 @@ def test_command(self): # Test that the generic command helper obeys the read preference # passed to it. for mode, server_type in _PREF_MAP: - func = lambda: self.c.pymongo_test.command("dbStats", read_preference=mode()) + + def func(): + return self.c.pymongo_test.command("dbStats", read_preference=mode()) + self._test_fn(server_type, func) def test_create_collection(self): @@ -536,7 +543,7 @@ def test_send_hedge(self): client = rs_client(event_listeners=[listener]) self.addCleanup(client.close) client.admin.command("ping") - for mode, cls in cases.items(): + for _mode, cls in cases.items(): pref = cls(hedge={"enabled": True}) coll = client.test.get_collection("test", read_preference=pref) listener.reset() diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 26bc111f00..2b39f7d04e 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -89,16 +89,16 @@ def insert_command_default_write_concern(): f() self.assertGreaterEqual(len(listener.started_events), 1) - for i, event in enumerate(listener.started_events): + for _i, event in enumerate(listener.started_events): self.assertNotIn( "readConcern", event.command, - "%s sent default readConcern with %s" % (name, event.command_name), + f"{name} sent default readConcern with {event.command_name}", ) self.assertNotIn( "writeConcern", event.command, - "%s sent default writeConcern with %s" % (name, event.command_name), + f"{name} sent default writeConcern with {event.command_name}", ) def assertWriteOpsRaise(self, write_concern, expected_exception): @@ -307,7 +307,7 @@ def create_tests(): fname = os.path.splitext(filename)[0] for test_case in test_cases: new_test = create_test(test_case) - test_name = "test_%s_%s_%s" % ( + test_name = "test_{}_{}_{}".format( dirname.replace("-", "_"), fname.replace("-", "_"), str(test_case["description"].lower().replace(" ", "_")), diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 898be99d4d..bdeaeb06a3 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -83,7 +83,7 @@ def test_replica_set_client(self): c.mock_members.remove("c:3") c.mock_standalones.append("c:3") - wait_until(lambda: set([("b", 2)]) == c.secondaries, "update the list of secondaries") + wait_until(lambda: {("b", 2)} == c.secondaries, "update the list of secondaries") self.assertEqual(("a", 1), c.primary) @@ -106,7 +106,7 @@ def test_replica_set_client(self): # C is removed. c.mock_hello_hosts.remove("c:3") - wait_until(lambda: set([("b", 2)]) == c.secondaries, "update list of secondaries") + wait_until(lambda: {("b", 2)} == c.secondaries, "update list of secondaries") self.assertEqual(("a", 1), c.primary) @@ -148,7 +148,7 @@ def test_client(self): # MongoClient connects to primary by default. self.assertEqual(c.address, ("a", 1)) - self.assertEqual(set([("a", 1), ("b", 2)]), c.nodes) + self.assertEqual({("a", 1), ("b", 2)}, c.nodes) # C is added. c.mock_members.append("c:3") @@ -159,7 +159,7 @@ def test_client(self): self.assertEqual(c.address, ("a", 1)) wait_until( - lambda: set([("a", 1), ("b", 2), ("c", 3)]) == c.nodes, "reconnect to both secondaries" + lambda: {("a", 1), ("b", 2), ("c", 3)} == c.nodes, "reconnect to both secondaries" ) def test_replica_set_client(self): @@ -169,13 +169,13 @@ def test_replica_set_client(self): self.addCleanup(c.close) wait_until(lambda: ("a", 1) == c.primary, "discover the primary") - wait_until(lambda: set([("b", 2)]) == c.secondaries, "discover the secondary") + wait_until(lambda: {("b", 2)} == c.secondaries, "discover the secondary") # C is added. c.mock_members.append("c:3") c.mock_hello_hosts.append("c:3") - wait_until(lambda: set([("b", 2), ("c", 3)]) == c.secondaries, "discover the new secondary") + wait_until(lambda: {("b", 2), ("c", 3)} == c.secondaries, "discover the new secondary") self.assertEqual(("a", 1), c.primary) diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 517e1122b0..ee12c524c9 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -76,14 +76,14 @@ class TestSpec(SpecRunner): # TODO: remove this once PYTHON-1948 is done. @client_context.require_no_mmap def setUpClass(cls): - super(TestSpec, cls).setUpClass() + super().setUpClass() def maybe_skip_scenario(self, test): - super(TestSpec, self).maybe_skip_scenario(test) + super().maybe_skip_scenario(test) skip_names = ["listCollectionObjects", "listIndexNames", "listDatabaseObjects"] for name in skip_names: if name.lower() in test["description"].lower(): - self.skipTest("PyMongo does not support %s" % (name,)) + self.skipTest(f"PyMongo does not support {name}") # Serverless does not support $out and collation. if client_context.serverless: @@ -107,7 +107,7 @@ def get_scenario_coll_name(self, scenario_def): """Override a test's collection name to support GridFS tests.""" if "bucket_name" in scenario_def: return scenario_def["bucket_name"] - return super(TestSpec, self).get_scenario_coll_name(scenario_def) + return super().get_scenario_coll_name(scenario_def) def setup_scenario(self, scenario_def): """Override a test's setup to support GridFS tests.""" @@ -127,7 +127,7 @@ def setup_scenario(self, scenario_def): db.get_collection("fs.chunks").drop() db.get_collection("fs.files", write_concern=wc).drop() else: - super(TestSpec, self).setup_scenario(scenario_def) + super().setup_scenario(scenario_def) def create_test(scenario_def, test, name): diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 1e978f21be..32841a8227 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -68,7 +68,7 @@ class InsertEventListener(EventListener): def succeeded(self, event: CommandSucceededEvent) -> None: - super(InsertEventListener, self).succeeded(event) + super().succeeded(event) if ( event.command_name == "insert" and event.reply.get("writeConcernError", {}).get("code", None) == 91 @@ -108,7 +108,7 @@ def run_test_ops(self, sessions, collection, test): if "result" in outcome: operation["result"] = outcome["result"] test["operations"] = [operation] - super(TestAllScenarios, self).run_test_ops(sessions, collection, test) + super().run_test_ops(sessions, collection, test) def create_test(scenario_def, test, name): @@ -168,13 +168,13 @@ class IgnoreDeprecationsTest(IntegrationTest): @classmethod def setUpClass(cls): - super(IgnoreDeprecationsTest, cls).setUpClass() + super().setUpClass() cls.deprecation_filter = DeprecationFilter() @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() - super(IgnoreDeprecationsTest, cls).tearDownClass() + super().tearDownClass() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): @@ -182,7 +182,7 @@ class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): @classmethod def setUpClass(cls): - super(TestRetryableWritesMMAPv1, cls).setUpClass() + super().setUpClass() # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() @@ -193,7 +193,7 @@ def setUpClass(cls): def tearDownClass(cls): cls.knobs.disable() cls.client.close() - super(TestRetryableWritesMMAPv1, cls).tearDownClass() + super().tearDownClass() @client_context.require_no_standalone def test_actionable_error_message(self): @@ -217,7 +217,7 @@ class TestRetryableWrites(IgnoreDeprecationsTest): @classmethod @client_context.require_no_mmap def setUpClass(cls): - super(TestRetryableWrites, cls).setUpClass() + super().setUpClass() # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() @@ -229,7 +229,7 @@ def setUpClass(cls): def tearDownClass(cls): cls.knobs.disable() cls.client.close() - super(TestRetryableWrites, cls).tearDownClass() + super().tearDownClass() def setUp(self): if client_context.is_rs and client_context.test_commands_enabled: @@ -248,20 +248,20 @@ def test_supported_single_statement_no_retry(self): client = rs_or_single_client(retryWrites=False, event_listeners=[listener]) self.addCleanup(client.close) for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): - msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" listener.reset() method(*args, **kwargs) for event in listener.started_events: self.assertNotIn( "txnNumber", event.command, - "%s sent txnNumber with %s" % (msg, event.command_name), + f"{msg} sent txnNumber with {event.command_name}", ) @client_context.require_no_standalone def test_supported_single_statement_supported_cluster(self): for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): - msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" self.listener.reset() method(*args, **kwargs) commands_started = self.listener.started_events @@ -270,13 +270,13 @@ def test_supported_single_statement_supported_cluster(self): self.assertIn( "lsid", first_attempt.command, - "%s sent no lsid with %s" % (msg, first_attempt.command_name), + f"{msg} sent no lsid with {first_attempt.command_name}", ) initial_session_id = first_attempt.command["lsid"] self.assertIn( "txnNumber", first_attempt.command, - "%s sent no txnNumber with %s" % (msg, first_attempt.command_name), + f"{msg} sent no txnNumber with {first_attempt.command_name}", ) # There should be no retry when the failpoint is not active. @@ -289,13 +289,13 @@ def test_supported_single_statement_supported_cluster(self): self.assertIn( "lsid", retry_attempt.command, - "%s sent no lsid with %s" % (msg, first_attempt.command_name), + f"{msg} sent no lsid with {first_attempt.command_name}", ) self.assertEqual(retry_attempt.command["lsid"], initial_session_id, msg) self.assertIn( "txnNumber", retry_attempt.command, - "%s sent no txnNumber with %s" % (msg, first_attempt.command_name), + f"{msg} sent no txnNumber with {first_attempt.command_name}", ) self.assertEqual(retry_attempt.command["txnNumber"], initial_transaction_id, msg) @@ -304,7 +304,7 @@ def test_supported_single_statement_unsupported_cluster(self): raise SkipTest("This cluster supports retryable writes") for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): - msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" self.listener.reset() method(*args, **kwargs) @@ -312,7 +312,7 @@ def test_supported_single_statement_unsupported_cluster(self): self.assertNotIn( "txnNumber", event.command, - "%s sent txnNumber with %s" % (msg, event.command_name), + f"{msg} sent txnNumber with {event.command_name}", ) def test_unsupported_single_statement(self): @@ -322,7 +322,7 @@ def test_unsupported_single_statement(self): for method, args, kwargs in non_retryable_single_statement_ops( coll ) + retryable_single_statement_ops(coll_w0): - msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" self.listener.reset() method(*args, **kwargs) started_events = self.listener.started_events @@ -332,7 +332,7 @@ def test_unsupported_single_statement(self): self.assertNotIn( "txnNumber", event.command, - "%s sent txnNumber with %s" % (msg, event.command_name), + f"{msg} sent txnNumber with {event.command_name}", ) def test_server_selection_timeout_not_retried(self): @@ -345,7 +345,7 @@ def test_server_selection_timeout_not_retried(self): event_listeners=[listener], ) for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): - msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" listener.reset() with self.assertRaises(ServerSelectionTimeoutError, msg=msg): method(*args, **kwargs) @@ -374,7 +374,7 @@ def raise_error(*args, **kwargs): return server for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): - msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" listener.reset() topology.select_server = mock_select_server with self.assertRaises(ConnectionFailure, msg=msg): @@ -479,7 +479,7 @@ class TestWriteConcernError(IntegrationTest): @client_context.require_no_mmap @client_context.require_failCommand_fail_point def setUpClass(cls): - super(TestWriteConcernError, cls).setUpClass() + super().setUpClass() cls.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, @@ -668,7 +668,7 @@ def raise_connection_err_select_server(*args, **kwargs): with client.start_session() as session: kwargs = copy.deepcopy(kwargs) kwargs["session"] = session - msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) + msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" initial_txn_id = session._server_session.transaction_id # Each operation should fail on the first attempt and succeed diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index d7b3744399..2587ae7965 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -44,12 +44,12 @@ def compare_server_descriptions(expected, actual): - if (not expected["address"] == "%s:%s" % actual.address) or ( + if (not expected["address"] == "{}:{}".format(*actual.address)) or ( not server_name_to_type(expected["type"]) == actual.server_type ): return False expected_hosts = set(expected["arbiters"] + expected["passives"] + expected["hosts"]) - return expected_hosts == set("%s:%s" % s for s in actual.all_hosts) + return expected_hosts == {"{}:{}".format(*s) for s in actual.all_hosts} def compare_topology_descriptions(expected, actual): @@ -60,7 +60,7 @@ def compare_topology_descriptions(expected, actual): if len(expected) != len(actual): return False for exp_server in expected: - for address, actual_server in actual.items(): + for _address, actual_server in actual.items(): if compare_server_descriptions(exp_server, actual_server): break else: @@ -79,22 +79,22 @@ def compare_events(expected_dict, actual): if expected_type == "server_opening_event": if not isinstance(actual, monitoring.ServerOpeningEvent): return False, "Expected ServerOpeningEvent, got %s" % (actual.__class__) - if not expected["address"] == "%s:%s" % actual.server_address: + if not expected["address"] == "{}:{}".format(*actual.server_address): return ( False, "ServerOpeningEvent published with wrong address (expected" - " %s, got %s" % (expected["address"], actual.server_address), + " {}, got {}".format(expected["address"], actual.server_address), ) elif expected_type == "server_description_changed_event": if not isinstance(actual, monitoring.ServerDescriptionChangedEvent): return (False, "Expected ServerDescriptionChangedEvent, got %s" % (actual.__class__)) - if not expected["address"] == "%s:%s" % actual.server_address: + if not expected["address"] == "{}:{}".format(*actual.server_address): return ( False, "ServerDescriptionChangedEvent has wrong address" - " (expected %s, got %s" % (expected["address"], actual.server_address), + " (expected {}, got {}".format(expected["address"], actual.server_address), ) if not compare_server_descriptions(expected["newDescription"], actual.new_description): @@ -110,11 +110,11 @@ def compare_events(expected_dict, actual): elif expected_type == "server_closed_event": if not isinstance(actual, monitoring.ServerClosedEvent): return False, "Expected ServerClosedEvent, got %s" % (actual.__class__) - if not expected["address"] == "%s:%s" % actual.server_address: + if not expected["address"] == "{}:{}".format(*actual.server_address): return ( False, "ServerClosedEvent published with wrong address" - " (expected %s, got %s" % (expected["address"], actual.server_address), + " (expected {}, got {}".format(expected["address"], actual.server_address), ) elif expected_type == "topology_opening_event": @@ -145,7 +145,7 @@ def compare_events(expected_dict, actual): return False, "Expected TopologyClosedEvent, got %s" % (actual.__class__) else: - return False, "Incorrect event: expected %s, actual %s" % (expected_type, actual) + return False, f"Incorrect event: expected {expected_type}, actual {actual}" return True, "" @@ -170,7 +170,7 @@ def compare_multiple_events(i, expected_results, actual_results): class TestAllScenarios(IntegrationTest): def setUp(self): - super(TestAllScenarios, self).setUp() + super().setUp() self.all_listener = ServerAndTopologyEventListener() @@ -235,7 +235,7 @@ def _run(self): # Assert no extra events. extra_events = self.all_listener.results[expected_len:] if extra_events: - self.fail("Extra events %r" % (extra_events,)) + self.fail(f"Extra events {extra_events!r}") self.all_listener.reset() finally: @@ -251,7 +251,7 @@ def create_tests(): scenario_def = json.load(scenario_stream, object_hook=object_hook) # Construct test from scenario. new_test = create_test(scenario_def) - test_name = "test_%s" % (os.path.splitext(filename)[0],) + test_name = f"test_{os.path.splitext(filename)[0]}" new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) @@ -268,7 +268,7 @@ class TestSdamMonitoring(IntegrationTest): @classmethod @client_context.require_failCommand_fail_point def setUpClass(cls): - super(TestSdamMonitoring, cls).setUpClass() + super().setUpClass() # Speed up the tests by decreasing the event publish frequency. cls.knobs = client_knobs(events_queue_frequency=0.1) cls.knobs.enable() @@ -284,7 +284,7 @@ def setUpClass(cls): def tearDownClass(cls): cls.test_client.close() cls.knobs.disable() - super(TestSdamMonitoring, cls).tearDownClass() + super().tearDownClass() def setUp(self): self.listener.reset() diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 8d4ffe5e9b..30b82769bc 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -48,7 +48,7 @@ ) -class SelectionStoreSelector(object): +class SelectionStoreSelector: """No-op selector that keeps track of what was passed to it.""" def __init__(self): @@ -103,7 +103,7 @@ def all_hosts_started(): def test_invalid_server_selector(self): # Client initialization must fail if server_selector is not callable. - for selector_candidate in [list(), 10, "string", {}]: + for selector_candidate in [[], 10, "string", {}]: with self.assertRaisesRegex(ValueError, "must be a callable"): MongoClient(connect=False, server_selector=selector_candidate) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index d076ae77b3..63769a6457 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -46,7 +46,7 @@ def run_scenario(self, scenario_def): server.pool.operation_count = mock["operation_count"] pref = ReadPreference.NEAREST - counts = dict((address, 0) for address in topology._description.server_descriptions()) + counts = {address: 0 for address in topology._description.server_descriptions()} # Number of times to repeat server selection iterations = scenario_def["iterations"] @@ -91,7 +91,7 @@ def tests(self, scenario_def): class FinderThread(threading.Thread): def __init__(self, collection, iterations): - super(FinderThread, self).__init__() + super().__init__() self.daemon = True self.collection = collection self.iterations = iterations diff --git a/test/test_server_selection_rtt.py b/test/test_server_selection_rtt.py index d2d8768809..5c2a8a6fba 100644 --- a/test/test_server_selection_rtt.py +++ b/test/test_server_selection_rtt.py @@ -57,7 +57,7 @@ def create_tests(): # Construct test from scenario. new_test = create_test(scenario_def) - test_name = "test_%s_%s" % (dirname, os.path.splitext(filename)[0]) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) diff --git a/test/test_session.py b/test/test_session.py index 25d209ebaf..18d0122dae 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -47,15 +47,15 @@ class SessionTestListener(EventListener): def started(self, event): if not event.command_name.startswith("sasl"): - super(SessionTestListener, self).started(event) + super().started(event) def succeeded(self, event): if not event.command_name.startswith("sasl"): - super(SessionTestListener, self).succeeded(event) + super().succeeded(event) def failed(self, event): if not event.command_name.startswith("sasl"): - super(SessionTestListener, self).failed(event) + super().failed(event) def first_command_started(self): assert len(self.started_events) >= 1, "No command-started events" @@ -74,7 +74,7 @@ class TestSession(IntegrationTest): @classmethod @client_context.require_sessions def setUpClass(cls): - super(TestSession, cls).setUpClass() + super().setUpClass() # Create a second client so we can make sure clients cannot share # sessions. cls.client2 = rs_or_single_client() @@ -87,7 +87,7 @@ def setUpClass(cls): def tearDownClass(cls): monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) cls.client2.close() - super(TestSession, cls).tearDownClass() + super().tearDownClass() def setUp(self): self.listener = SessionTestListener() @@ -97,7 +97,7 @@ def setUp(self): ) self.addCleanup(self.client.close) self.db = self.client.pymongo_test - self.initial_lsids = set(s["id"] for s in session_ids(self.client)) + self.initial_lsids = {s["id"] for s in session_ids(self.client)} def tearDown(self): """All sessions used in the test must be returned to the pool.""" @@ -107,7 +107,7 @@ def tearDown(self): if "lsid" in event.command: used_lsids.add(event.command["lsid"]["id"]) - current_lsids = set(s["id"] for s in session_ids(self.client)) + current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) def _test_ops(self, client, *ops): @@ -129,13 +129,13 @@ def _test_ops(self, client, *ops): for event in listener.started_events: self.assertTrue( "lsid" in event.command, - "%s sent no lsid with %s" % (f.__name__, event.command_name), + f"{f.__name__} sent no lsid with {event.command_name}", ) self.assertEqual( s.session_id, event.command["lsid"], - "%s sent wrong lsid with %s" % (f.__name__, event.command_name), + f"{f.__name__} sent wrong lsid with {event.command_name}", ) self.assertFalse(s.has_ended) @@ -164,7 +164,7 @@ def _test_ops(self, client, *ops): for event in listener.started_events: self.assertTrue( "lsid" in event.command, - "%s sent no lsid with %s" % (f.__name__, event.command_name), + f"{f.__name__} sent no lsid with {event.command_name}", ) lsids.append(event.command["lsid"]) @@ -176,7 +176,7 @@ def _test_ops(self, client, *ops): self.assertIn( lsid, session_ids(client), - "%s did not return implicit session to pool" % (f.__name__,), + f"{f.__name__} did not return implicit session to pool", ) def test_implicit_sessions_checkout(self): @@ -405,13 +405,13 @@ def test_cursor(self): for event in listener.started_events: self.assertTrue( "lsid" in event.command, - "%s sent no lsid with %s" % (name, event.command_name), + f"{name} sent no lsid with {event.command_name}", ) self.assertEqual( s.session_id, event.command["lsid"], - "%s sent wrong lsid with %s" % (name, event.command_name), + f"{name} sent wrong lsid with {event.command_name}", ) with self.assertRaisesRegex(InvalidOperation, "ended session"): @@ -423,20 +423,20 @@ def test_cursor(self): f(session=None) event0 = listener.first_command_started() self.assertTrue( - "lsid" in event0.command, "%s sent no lsid with %s" % (name, event0.command_name) + "lsid" in event0.command, f"{name} sent no lsid with {event0.command_name}" ) lsid = event0.command["lsid"] for event in listener.started_events[1:]: self.assertTrue( - "lsid" in event.command, "%s sent no lsid with %s" % (name, event.command_name) + "lsid" in event.command, f"{name} sent no lsid with {event.command_name}" ) self.assertEqual( lsid, event.command["lsid"], - "%s sent wrong lsid with %s" % (name, event.command_name), + f"{name} sent wrong lsid with {event.command_name}", ) def test_gridfs(self): @@ -693,7 +693,7 @@ def _test_unacknowledged_ops(self, client, *ops): kw = copy.copy(kw) kw["session"] = s with self.assertRaises( - ConfigurationError, msg="%s did not raise ConfigurationError" % (f.__name__,) + ConfigurationError, msg=f"{f.__name__} did not raise ConfigurationError" ): f(*args, **kw) if f.__name__ == "create_collection": @@ -703,11 +703,11 @@ def _test_unacknowledged_ops(self, client, *ops): self.assertIn( "lsid", event.command, - "%s sent no lsid with %s" % (f.__name__, event.command_name), + f"{f.__name__} sent no lsid with {event.command_name}", ) # Should not run any command before raising an error. - self.assertFalse(listener.started_events, "%s sent command" % (f.__name__,)) + self.assertFalse(listener.started_events, f"{f.__name__} sent command") self.assertTrue(s.has_ended) @@ -724,12 +724,12 @@ def _test_unacknowledged_ops(self, client, *ops): self.assertIn( "lsid", event.command, - "%s sent no lsid with %s" % (f.__name__, event.command_name), + f"{f.__name__} sent no lsid with {event.command_name}", ) for event in listener.started_events: self.assertNotIn( - "lsid", event.command, "%s sent lsid with %s" % (f.__name__, event.command_name) + "lsid", event.command, f"{f.__name__} sent lsid with {event.command_name}" ) def test_unacknowledged_writes(self): @@ -792,7 +792,7 @@ def tearDownClass(cls): @client_context.require_sessions def setUp(self): - super(TestCausalConsistency, self).setUp() + super().setUp() @client_context.require_no_standalone def test_core(self): @@ -1072,7 +1072,7 @@ def test_cluster_time_no_server_support(self): class TestClusterTime(IntegrationTest): def setUp(self): - super(TestClusterTime, self).setUp() + super().setUp() if "$clusterTime" not in client_context.hello: raise SkipTest("$clusterTime not supported") @@ -1128,7 +1128,7 @@ def insert_and_aggregate(): ("rename_and_drop", rename_and_drop), ] - for name, f in ops: + for _name, f in ops: listener.reset() # Call f() twice, insert to advance clusterTime, call f() again. f() @@ -1140,21 +1140,20 @@ def insert_and_aggregate(): for i, event in enumerate(listener.started_events): self.assertTrue( "$clusterTime" in event.command, - "%s sent no $clusterTime with %s" % (f.__name__, event.command_name), + f"{f.__name__} sent no $clusterTime with {event.command_name}", ) if i > 0: succeeded = listener.succeeded_events[i - 1] self.assertTrue( "$clusterTime" in succeeded.reply, - "%s received no $clusterTime with %s" - % (f.__name__, succeeded.command_name), + f"{f.__name__} received no $clusterTime with {succeeded.command_name}", ) self.assertTrue( event.command["$clusterTime"]["clusterTime"] >= succeeded.reply["$clusterTime"]["clusterTime"], - "%s sent wrong $clusterTime with %s" % (f.__name__, event.command_name), + f"{f.__name__} sent wrong $clusterTime with {event.command_name}", ) diff --git a/test/test_son.py b/test/test_son.py index 5c1f43594d..5e62ffb176 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -47,7 +47,7 @@ def test_equality(self): self.assertEqual(a1, SON({"hello": "world"})) self.assertEqual(b2, SON((("hello", "world"), ("mike", "awesome"), ("hello_", "mike")))) - self.assertEqual(b2, dict((("hello_", "mike"), ("mike", "awesome"), ("hello", "world")))) + self.assertEqual(b2, {"hello_": "mike", "mike": "awesome", "hello": "world"}) self.assertNotEqual(a1, b2) self.assertNotEqual(b2, SON((("hello_", "mike"), ("mike", "awesome"), ("hello", "world")))) @@ -55,7 +55,7 @@ def test_equality(self): # Explicitly test inequality self.assertFalse(a1 != SON({"hello": "world"})) self.assertFalse(b2 != SON((("hello", "world"), ("mike", "awesome"), ("hello_", "mike")))) - self.assertFalse(b2 != dict((("hello_", "mike"), ("mike", "awesome"), ("hello", "world")))) + self.assertFalse(b2 != {"hello_": "mike", "mike": "awesome", "hello": "world"}) # Embedded SON. d4 = SON([("blah", {"foo": SON()})]) @@ -97,10 +97,10 @@ def test_pickle_backwards_compatability(self): # This string was generated by pickling a SON object in pymongo # version 2.1.1 pickled_with_2_1_1 = ( - "ccopy_reg\n_reconstructor\np0\n(cbson.son\nSON\np1\n" - "c__builtin__\ndict\np2\n(dp3\ntp4\nRp5\n(dp6\n" - "S'_SON__keys'\np7\n(lp8\nsb." - ).encode("utf8") + b"ccopy_reg\n_reconstructor\np0\n(cbson.son\nSON\np1\n" + b"c__builtin__\ndict\np2\n(dp3\ntp4\nRp5\n(dp6\n" + b"S'_SON__keys'\np7\n(lp8\nsb." + ) son_2_1_1 = pickle.loads(pickled_with_2_1_1) self.assertEqual(son_2_1_1, SON([])) @@ -138,18 +138,14 @@ def test_copying(self): self.assertEqual(id(reflexive_son1), id(reflexive_son1["reflexive"])) def test_iteration(self): - """ - Test __iter__ - """ + """Test __iter__""" # test success case test_son = SON([(1, 100), (2, 200), (3, 300)]) for ele in test_son: self.assertEqual(ele * 100, test_son[ele]) def test_contains_has(self): - """ - has_key and __contains__ - """ + """has_key and __contains__""" test_son = SON([(1, 100), (2, 200), (3, 300)]) self.assertIn(1, test_son) self.assertTrue(2 in test_son, "in failed") @@ -158,9 +154,7 @@ def test_contains_has(self): self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't") # noqa def test_clears(self): - """ - Test clear() - """ + """Test clear()""" test_son = SON([(1, 100), (2, 200), (3, 300)]) test_son.clear() self.assertNotIn(1, test_son) @@ -169,9 +163,7 @@ def test_clears(self): self.assertEqual({}, test_son.to_dict()) def test_len(self): - """ - Test len - """ + """Test len""" test_son = SON() self.assertEqual(0, len(test_son)) test_son = SON([(1, 100), (2, 200), (3, 300)]) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 7a6c61ad21..8bf81f4de9 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -32,7 +32,7 @@ WAIT_TIME = 0.1 -class SrvPollingKnobs(object): +class SrvPollingKnobs: def __init__( self, ttl_time=None, diff --git a/test/test_ssl.py b/test/test_ssl.py index bf151578cb..e6df2a1c24 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -142,7 +142,7 @@ def assertClientWorks(self, client): @classmethod @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") def setUpClass(cls): - super(TestSSL, cls).setUpClass() + super().setUpClass() # MongoClient should connect to the primary by default. cls.saved_port = MongoClient.PORT MongoClient.PORT = client_context.port @@ -150,7 +150,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): MongoClient.PORT = cls.saved_port - super(TestSSL, cls).tearDownClass() + super().tearDownClass() @client_context.require_tls def test_simple_ssl(self): diff --git a/test/test_threads.py b/test/test_threads.py index 899392e1a0..b948bf9249 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -111,7 +111,7 @@ def test_threading(self): self.db.test.insert_many([{"x": i} for i in range(1000)]) threads = [] - for i in range(10): + for _i in range(10): t = SaveAndFind(self.db.test) t.start() threads.append(t) diff --git a/test/test_topology.py b/test/test_topology.py index e09d7c3691..adbf19f571 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -89,7 +89,7 @@ class TopologyTest(unittest.TestCase): """Disables periodic monitoring, to make tests deterministic.""" def setUp(self): - super(TopologyTest, self).setUp() + super().setUp() self.client_knobs = client_knobs(heartbeat_frequency=999999) self.client_knobs.enable() self.addCleanup(self.client_knobs.disable) @@ -647,13 +647,13 @@ def test_topology_repr(self): ) self.assertEqual( repr(t.description), - ", " ", " "]>" % (t._topology_id,), + " rtt: None>]>".format(t._topology_id), ) def test_unexpected_load_balancer(self): @@ -734,7 +734,7 @@ def _check_with_socket(self, *args, **kwargs): if hello_count[0] in (1, 3): return Hello({"ok": 1, "maxWireVersion": 6}), 0 else: - raise AutoReconnect("mock monitor error #%s" % (hello_count[0],)) + raise AutoReconnect(f"mock monitor error #{hello_count[0]}") t = create_mock_topology(monitor_class=TestMonitor) self.addCleanup(t.close) diff --git a/test/test_transactions.py b/test/test_transactions.py index dc58beb930..9b51927d67 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -63,19 +63,19 @@ class TransactionsBase(SpecRunner): @classmethod def setUpClass(cls): - super(TransactionsBase, cls).setUpClass() + super().setUpClass() if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(single_client("%s:%s" % address)) + cls.mongos_clients.append(single_client("{}:{}".format(*address))) @classmethod def tearDownClass(cls): for client in cls.mongos_clients: client.close() - super(TransactionsBase, cls).tearDownClass() + super().tearDownClass() def maybe_skip_scenario(self, test): - super(TransactionsBase, self).maybe_skip_scenario(test) + super().maybe_skip_scenario(test) if ( "secondary" in self.id() and not client_context.is_mongos @@ -390,7 +390,7 @@ def test_transaction_direct_connection(self): list(res) -class PatchSessionTimeout(object): +class PatchSessionTimeout: """Patches the client_session's with_transaction timeout for testing.""" def __init__(self, mock_timeout): @@ -416,7 +416,7 @@ class _MyException(Exception): pass def raise_error(_): - raise _MyException() + raise _MyException with self.client.start_session() as s: with self.assertRaises(_MyException): diff --git a/test/test_typing.py b/test/test_typing.py index 0aebc707cd..27597bb2c8 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -13,7 +13,8 @@ # limitations under the License. """Test that each file in mypy_fails/ actually fails mypy, and test some -sample client code that uses PyMongo typings.""" +sample client code that uses PyMongo typings. +""" import os import sys import tempfile @@ -39,7 +40,7 @@ class ImplicitMovie(TypedDict): name: str year: int -except ImportError as exc: +except ImportError: Movie = dict # type:ignore[misc,assignment] ImplicitMovie = dict # type: ignore[assignment,misc] MovieWithId = dict # type: ignore[assignment,misc] @@ -164,12 +165,12 @@ def test_bulk_write_heterogeneous(self): def test_command(self) -> None: result: Dict = self.client.admin.command("ping") - items = result.items() + result.items() def test_list_collections(self) -> None: cursor = self.client.test.list_collections() value = cursor.next() - items = value.items() + value.items() def test_list_databases(self) -> None: cursor = self.client.list_databases() @@ -237,7 +238,7 @@ def foo(self): assert rt_document2.foo() == "bar" codec_options2 = CodecOptions(document_class=RawBSONDocument) - bsonbytes3 = encode(doc, codec_options=codec_options2) + encode(doc, codec_options=codec_options2) rt_document3 = decode(bsonbytes2, codec_options=codec_options2) assert rt_document3.raw @@ -463,7 +464,7 @@ def test_son_document_type(self) -> None: retrieved["a"] = 1 def test_son_document_type_runtime(self) -> None: - client = MongoClient(document_class=SON[str, Any], connect=False) + MongoClient(document_class=SON[str, Any], connect=False) @only_type_check def test_create_index(self) -> None: diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index 2f81e3b512..e2dd17ec26 100644 --- a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -508,7 +508,7 @@ def test_redact_AWS_SESSION_TOKEN(self): def test_special_chars(self): user = "user@ /9+:?~!$&'()*+,;=" pwd = "pwd@ /9+:?~!$&'()*+,;=" - uri = "mongodb://%s:%s@localhost" % (quote_plus(user), quote_plus(pwd)) + uri = f"mongodb://{quote_plus(user)}:{quote_plus(pwd)}@localhost" res = parse_uri(uri) self.assertEqual(user, res["username"]) self.assertEqual(pwd, res["password"]) diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index d12abf3b91..5b68c80401 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -13,7 +13,8 @@ # limitations under the License. """Test that the pymongo.uri_parser module is compliant with the connection -string and uri options specifications.""" +string and uri options specifications. +""" import json import os @@ -73,7 +74,7 @@ def setUp(self): def get_error_message_template(expected, artefact): - return "%s %s for test '%s'" % ("Expected" if expected else "Unexpected", artefact, "%s") + return "{} {} for test '{}'".format("Expected" if expected else "Unexpected", artefact, "%s") def run_scenario_in_dir(target_workdir): @@ -133,13 +134,15 @@ def run_scenario(self): for exp, actual in zip(test["hosts"], options["nodelist"]): self.assertEqual( - exp["host"], actual[0], "Expected host %s but got %s" % (exp["host"], actual[0]) + exp["host"], + actual[0], + "Expected host {} but got {}".format(exp["host"], actual[0]), ) if exp["port"] is not None: self.assertEqual( exp["port"], actual[1], - "Expected port %s but got %s" % (exp["port"], actual), + "Expected port {} but got {}".format(exp["port"], actual), ) # Compare auth options. @@ -157,7 +160,7 @@ def run_scenario(self): self.assertEqual( auth[elm], options[elm], - "Expected %s but got %s" % (auth[elm], options[elm]), + f"Expected {auth[elm]} but got {options[elm]}", ) # Compare URI options. @@ -183,7 +186,7 @@ def run_scenario(self): ), ) else: - self.fail("Missing expected option %s" % (opt,)) + self.fail(f"Missing expected option {opt}") return run_scenario_in_dir(test_workdir)(run_scenario) @@ -209,7 +212,7 @@ def create_tests(test_path): continue testmethod = create_test(testcase, dirpath) - testname = "test_%s_%s_%s" % ( + testname = "test_{}_{}_{}".format( dirname, os.path.splitext(filename)[0], str(dsc).replace(" ", "_"), diff --git a/test/test_write_concern.py b/test/test_write_concern.py index 02c562a348..822f3a4d1d 100644 --- a/test/test_write_concern.py +++ b/test/test_write_concern.py @@ -40,7 +40,7 @@ def test_equality_to_none(self): self.assertTrue(concern != None) # noqa def test_equality_compatible_type(self): - class _FakeWriteConcern(object): + class _FakeWriteConcern: def __init__(self, **document): self.document = document diff --git a/test/unified_format.py b/test/unified_format.py index 584ee04ddd..90cb442b28 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -131,7 +131,7 @@ # Build up a placeholder map. -PLACEHOLDER_MAP = dict() +PLACEHOLDER_MAP = {} for (provider_name, provider_data) in [ ("local", {"key": LOCAL_MASTER_KEY}), ("aws", AWS_CREDS), @@ -257,7 +257,7 @@ def parse_bulk_write_error_result(error): return parse_bulk_write_result(write_result) -class NonLazyCursor(object): +class NonLazyCursor: """A find cursor proxy that creates the remote cursor when initialized.""" def __init__(self, find_cursor, client): @@ -289,7 +289,7 @@ class EventListenerUtil(CMAPListener, CommandListener, ServerListener): def __init__( self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map ): - self._event_types = set(name.lower() for name in observe_events) + self._event_types = {name.lower() for name in observe_events} if observe_sensitive_commands: self._observe_sensitive_commands = True self._ignore_commands = set(ignore_commands) @@ -306,7 +306,7 @@ def __init__( for i in events: self._event_mapping[i].append(id) self.entity_map[id] = [] - super(EventListenerUtil, self).__init__() + super().__init__() def get_events(self, event_type): assert event_type in ("command", "cmap", "sdam", "all"), event_type @@ -321,7 +321,7 @@ def get_events(self, event_type): def add_event(self, event): event_name = type(event).__name__.lower() if event_name in self._event_types: - super(EventListenerUtil, self).add_event(event) + super().add_event(event) for id in self._event_mapping[event_name]: self.entity_map[id].append( { @@ -332,7 +332,7 @@ def add_event(self, event): ) def _command_event(self, event): - if not event.command_name.lower() in self._ignore_commands: + if event.command_name.lower() not in self._ignore_commands: self.add_event(event) def started(self, event): @@ -364,9 +364,10 @@ def closed(self, event: ServerClosedEvent) -> None: self.add_event(event) -class EntityMapUtil(object): +class EntityMapUtil: """Utility class that implements an entity map as per the unified - test format specification.""" + test format specification. + """ def __init__(self, test_class): self._entities: Dict[str, Any] = {} @@ -384,14 +385,14 @@ def __getitem__(self, item): try: return self._entities[item] except KeyError: - self.test.fail("Could not find entity named %s in map" % (item,)) + self.test.fail(f"Could not find entity named {item} in map") def __setitem__(self, key, value): if not isinstance(key, str): self.test.fail("Expected entity name of type str, got %s" % (type(key))) if key in self._entities: - self.test.fail("Entity named %s already in map" % (key,)) + self.test.fail(f"Entity named {key} already in map") self._entities[key] = value @@ -410,9 +411,7 @@ def _handle_placeholders(self, spec: dict, current: dict, path: str) -> Any: def _create_entity(self, entity_spec, uri=None): if len(entity_spec) != 1: - self.test.fail( - "Entity spec %s did not contain exactly one top-level key" % (entity_spec,) - ) + self.test.fail(f"Entity spec {entity_spec} did not contain exactly one top-level key") entity_type, spec = next(iter(entity_spec.items())) spec = self._handle_placeholders(spec, spec, "") @@ -454,8 +453,9 @@ def _create_entity(self, entity_spec, uri=None): client = self[spec["client"]] if not isinstance(client, MongoClient): self.test.fail( - "Expected entity %s to be of type MongoClient, got %s" - % (spec["client"], type(client)) + "Expected entity {} to be of type MongoClient, got {}".format( + spec["client"], type(client) + ) ) options = parse_collection_or_database_options(spec.get("databaseOptions", {})) self[spec["id"]] = client.get_database(spec["databaseName"], **options) @@ -464,8 +464,9 @@ def _create_entity(self, entity_spec, uri=None): database = self[spec["database"]] if not isinstance(database, Database): self.test.fail( - "Expected entity %s to be of type Database, got %s" - % (spec["database"], type(database)) + "Expected entity {} to be of type Database, got {}".format( + spec["database"], type(database) + ) ) options = parse_collection_or_database_options(spec.get("collectionOptions", {})) self[spec["id"]] = database.get_collection(spec["collectionName"], **options) @@ -474,8 +475,9 @@ def _create_entity(self, entity_spec, uri=None): client = self[spec["client"]] if not isinstance(client, MongoClient): self.test.fail( - "Expected entity %s to be of type MongoClient, got %s" - % (spec["client"], type(client)) + "Expected entity {} to be of type MongoClient, got {}".format( + spec["client"], type(client) + ) ) opts = camel_to_snake_args(spec.get("sessionOptions", {})) if "default_transaction_options" in opts: @@ -522,7 +524,7 @@ def drop(self: GridFSBucket, *args: Any, **kwargs: Any) -> None: self[name] = thread return - self.test.fail("Unable to create entity of unknown type %s" % (entity_type,)) + self.test.fail(f"Unable to create entity of unknown type {entity_type}") def create_entities_from_spec(self, entity_spec, uri=None): for spec in entity_spec: @@ -532,12 +534,12 @@ def get_listener_for_client(self, client_name: str) -> EventListenerUtil: client = self[client_name] if not isinstance(client, MongoClient): self.test.fail( - "Expected entity %s to be of type MongoClient, got %s" % (client_name, type(client)) + f"Expected entity {client_name} to be of type MongoClient, got {type(client)}" ) listener = self._listeners.get(client_name) if not listener: - self.test.fail("No listeners configured for client %s" % (client_name,)) + self.test.fail(f"No listeners configured for client {client_name}") return listener @@ -545,8 +547,7 @@ def get_lsid_for_session(self, session_name): session = self[session_name] if not isinstance(session, ClientSession): self.test.fail( - "Expected entity %s to be of type ClientSession, got %s" - % (session_name, type(session)) + f"Expected entity {session_name} to be of type ClientSession, got {type(session)}" ) try: @@ -587,9 +588,10 @@ def get_lsid_for_session(self, session_name): } -class MatchEvaluatorUtil(object): +class MatchEvaluatorUtil: """Utility class that implements methods for evaluating matches as per - the unified test format specification.""" + the unified test format specification. + """ def __init__(self, test_class): self.test = test_class @@ -606,11 +608,11 @@ def _operation_exists(self, spec, actual, key_to_compare): else: self.test.assertNotIn(key_to_compare, actual) else: - self.test.fail("Expected boolean value for $$exists operator, got %s" % (spec,)) + self.test.fail(f"Expected boolean value for $$exists operator, got {spec}") def __type_alias_to_type(self, alias): if alias not in BSON_TYPE_ALIAS_MAP: - self.test.fail("Unrecognized BSON type alias %s" % (alias,)) + self.test.fail(f"Unrecognized BSON type alias {alias}") return BSON_TYPE_ALIAS_MAP[alias] def _operation_type(self, spec, actual, key_to_compare): @@ -653,11 +655,11 @@ def _operation_lte(self, spec, actual, key_to_compare): self.test.assertLessEqual(actual[key_to_compare], spec) def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): - method_name = "_operation_%s" % (opname.strip("$"),) + method_name = "_operation_{}".format(opname.strip("$")) try: method = getattr(self, method_name) except AttributeError: - self.test.fail("Unsupported special matching operator %s" % (opname,)) + self.test.fail(f"Unsupported special matching operator {opname}") else: method(spec, actual, key_to_compare) @@ -668,7 +670,8 @@ def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=Non If given, ``key_to_compare`` is assumed to be the key in ``expectation`` whose corresponding value needs to be evaluated for a possible special operation. ``key_to_compare`` - is ignored when ``expectation`` has only one key.""" + is ignored when ``expectation`` has only one key. + """ if not isinstance(expectation, abc.Mapping): return False @@ -730,14 +733,16 @@ def match_result(self, expectation, actual, in_recursive_call=False): self._match_document(e, a, is_root=not in_recursive_call) else: self.match_result(e, a, in_recursive_call=True) - return + return None # account for flexible numerics in element-wise comparison if isinstance(expectation, int) or isinstance(expectation, float): self.test.assertEqual(expectation, actual) + return None else: self.test.assertIsInstance(actual, type(expectation)) self.test.assertEqual(expectation, actual) + return None def assertHasServiceId(self, spec, actual): if "hasServiceId" in spec: @@ -828,7 +833,7 @@ def match_event(self, event_type, expectation, actual): if "newDescription" in spec: self.match_server_description(actual.new_description, spec["newDescription"]) else: - raise Exception("Unsupported event type %s" % (name,)) + raise Exception(f"Unsupported event type {name}") def coerce_result(opname, result): @@ -840,7 +845,7 @@ def coerce_result(opname, result): if opname == "insertOne": return {"insertedId": result.inserted_id} if opname == "insertMany": - return {idx: _id for idx, _id in enumerate(result.inserted_ids)} + return dict(enumerate(result.inserted_ids)) if opname in ("deleteOne", "deleteMany"): return {"deletedCount": result.deleted_count} if opname in ("updateOne", "updateMany", "replaceOne"): @@ -904,11 +909,11 @@ def insert_initial_data(self, initial_data): @classmethod def setUpClass(cls): # super call creates internal client cls.client - super(UnifiedSpecTestMixinV1, cls).setUpClass() + super().setUpClass() # process file-level runOnRequirements run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) if not cls.should_run_on(run_on_spec): - raise unittest.SkipTest("%s runOnRequirements not satisfied" % (cls.__name__,)) + raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": @@ -916,7 +921,7 @@ def setUpClass(cls): raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") def setUp(self): - super(UnifiedSpecTestMixinV1, self).setUp() + super().setUp() # process schemaVersion # note: we check major schema version during class generation # note: we do this here because we cannot run assertions in setUpClass @@ -924,7 +929,7 @@ def setUp(self): self.assertLessEqual( version, self.SCHEMA_VERSION, - "expected schema version %s or lower, got %s" % (self.SCHEMA_VERSION, version), + f"expected schema version {self.SCHEMA_VERSION} or lower, got {version}", ) # initialize internals @@ -1044,20 +1049,18 @@ def process_error(self, exception, spec): if error_labels_omit: for err_label in error_labels_omit: if exception.has_error_label(err_label): - self.fail("Exception '%s' unexpectedly had label '%s'" % (exception, err_label)) + self.fail(f"Exception '{exception}' unexpectedly had label '{err_label}'") if expect_result: if isinstance(exception, BulkWriteError): result = parse_bulk_write_error_result(exception) self.match_evaluator.match_result(expect_result, result) else: - self.fail( - "expectResult can only be specified with %s exceptions" % (BulkWriteError,) - ) + self.fail(f"expectResult can only be specified with {BulkWriteError} exceptions") def __raise_if_unsupported(self, opname, target, *target_types): if not isinstance(target, target_types): - self.fail("Operation %s not supported for entity of type %s" % (opname, type(target))) + self.fail(f"Operation {opname} not supported for entity of type {type(target)}") def __entityOperation_createChangeStream(self, target, *args, **kwargs): if client_context.storage_engine == "mmapv1": @@ -1153,6 +1156,7 @@ def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs): return next(target) except StopIteration: pass + return None def _cursor_close(self, target, *args, **kwargs): self.__raise_if_unsupported("close", target, NonLazyCursor) @@ -1182,8 +1186,8 @@ def _clientEncryptionOperation_rewrapManyDataKey(self, target, *args, **kwargs): kwargs["master_key"] = opts.get("masterKey") data = target.rewrap_many_data_key(*args, **kwargs) if data.bulk_write_result: - return dict(bulkWriteResult=parse_bulk_write_result(data.bulk_write_result)) - return dict() + return {"bulkWriteResult": parse_bulk_write_result(data.bulk_write_result)} + return {} def _bucketOperation_download(self, target: GridFSBucket, *args: Any, **kwargs: Any) -> bytes: with target.open_download_stream(*args, **kwargs) as gout: @@ -1234,30 +1238,30 @@ def run_entity_operation(self, spec): arguments = {} if isinstance(target, MongoClient): - method_name = "_clientOperation_%s" % (opname,) + method_name = f"_clientOperation_{opname}" elif isinstance(target, Database): - method_name = "_databaseOperation_%s" % (opname,) + method_name = f"_databaseOperation_{opname}" elif isinstance(target, Collection): - method_name = "_collectionOperation_%s" % (opname,) + method_name = f"_collectionOperation_{opname}" # contentType is always stored in metadata in pymongo. if target.name.endswith(".files") and opname == "find": for doc in spec.get("expectResult", []): if "contentType" in doc: doc.setdefault("metadata", {})["contentType"] = doc.pop("contentType") elif isinstance(target, ChangeStream): - method_name = "_changeStreamOperation_%s" % (opname,) + method_name = f"_changeStreamOperation_{opname}" elif isinstance(target, NonLazyCursor): - method_name = "_cursor_%s" % (opname,) + method_name = f"_cursor_{opname}" elif isinstance(target, ClientSession): - method_name = "_sessionOperation_%s" % (opname,) + method_name = f"_sessionOperation_{opname}" elif isinstance(target, GridFSBucket): - method_name = "_bucketOperation_%s" % (opname,) + method_name = f"_bucketOperation_{opname}" if "id" in arguments: arguments["file_id"] = arguments.pop("id") # MD5 is always disabled in pymongo. arguments.pop("disable_md5", None) elif isinstance(target, ClientEncryption): - method_name = "_clientEncryptionOperation_%s" % (opname,) + method_name = f"_clientEncryptionOperation_{opname}" else: method_name = "doesNotExist" @@ -1270,7 +1274,7 @@ def run_entity_operation(self, spec): try: cmd = getattr(target, target_opname) except AttributeError: - self.fail("Unsupported operation %s on entity %s" % (opname, target)) + self.fail(f"Unsupported operation {opname} on entity {target}") else: cmd = functools.partial(method, target) @@ -1286,15 +1290,13 @@ def run_entity_operation(self, spec): # Ignore all operation errors but to avoid masking bugs don't # ignore things like TypeError and ValueError. if ignore and isinstance(exc, (PyMongoError,)): - return + return None if expect_error: return self.process_error(exc, expect_error) raise else: if expect_error: - self.fail( - 'Excepted error %s but "%s" succeeded: %s' % (expect_error, opname, result) - ) + self.fail(f'Excepted error {expect_error} but "{opname}" succeeded: {result}') if expect_result: actual = coerce_result(opname, result) @@ -1302,6 +1304,8 @@ def run_entity_operation(self, spec): if save_as_entity: self.entity_map[save_as_entity] = result + return None + return None def __set_fail_point(self, client, command_args): if not client_context.test_commands_enabled: @@ -1324,10 +1328,10 @@ def _testOperation_targetedFailPoint(self, spec): if not session._pinned_address: self.fail( "Cannot use targetedFailPoint operation with unpinned " - "session %s" % (spec["session"],) + "session {}".format(spec["session"]) ) - client = single_client("%s:%s" % session._pinned_address) + client = single_client("{}:{}".format(*session._pinned_address)) self.addCleanup(client.close) self.__set_fail_point(client=client, command_args=spec["failPoint"]) @@ -1422,9 +1426,7 @@ def _testOperation_assertEventCount(self, spec): Assert the given event was published exactly `count` times. """ client, event, count = spec["client"], spec["event"], spec["count"] - self.assertEqual( - self._event_count(client, event), count, "expected %s not %r" % (count, event) - ) + self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}") def _testOperation_waitForEvent(self, spec): """Run the waitForEvent test operation. @@ -1434,7 +1436,7 @@ def _testOperation_waitForEvent(self, spec): client, event, count = spec["client"], spec["event"], spec["count"] wait_until( lambda: self._event_count(client, event) >= count, - "find %s %s event(s)" % (count, event), + f"find {count} {event} event(s)", ) def _testOperation_wait(self, spec): @@ -1485,7 +1487,7 @@ def _testOperation_waitForThread(self, spec): thread.join(10) if thread.exc: raise thread.exc - self.assertFalse(thread.is_alive(), "Thread %s is still running" % (spec["thread"],)) + self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) def _testOperation_loop(self, spec): failure_key = spec.get("storeFailuresAsEntity") @@ -1527,11 +1529,11 @@ def _testOperation_loop(self, spec): def run_special_operation(self, spec): opname = spec["name"] - method_name = "_testOperation_%s" % (opname,) + method_name = f"_testOperation_{opname}" try: method = getattr(self, method_name) except AttributeError: - self.fail("Unsupported special test operation %s" % (opname,)) + self.fail(f"Unsupported special test operation {opname}") else: method(spec["arguments"]) @@ -1604,8 +1606,10 @@ def run_scenario(self, spec, uri=None): self.setUp() continue raise + return None else: self._run_scenario(spec, uri) + return None def _run_scenario(self, spec, uri=None): # maybe skip test manually @@ -1619,7 +1623,7 @@ def _run_scenario(self, spec, uri=None): # process skipReason skip_reason = spec.get("skipReason", None) if skip_reason is not None: - raise unittest.SkipTest("%s" % (skip_reason,)) + raise unittest.SkipTest(f"{skip_reason}") # process createEntities self._uri = uri @@ -1648,7 +1652,7 @@ class UnifiedSpecTestMeta(type): EXPECTED_FAILURES: Any def __init__(cls, *args, **kwargs): - super(UnifiedSpecTestMeta, cls).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def create_test(spec): def test_case(self): @@ -1658,7 +1662,9 @@ def test_case(self): for test_spec in cls.TEST_SPEC["tests"]: description = test_spec["description"] - test_name = "test_%s" % (description.strip(". ").replace(" ", "_").replace(".", "_"),) + test_name = "test_{}".format( + description.strip(". ").replace(" ", "_").replace(".", "_") + ) test_method = create_test(copy.deepcopy(test_spec)) test_method.__name__ = str(test_name) @@ -1690,13 +1696,15 @@ def generate_test_classes( **kwargs, ): """Method for generating test classes. Returns a dictionary where keys are - the names of test classes and values are the test class objects.""" + the names of test classes and values are the test class objects. + """ test_klasses = {} def test_base_class_factory(test_spec): """Utility that creates the base class to use for test generation. This is needed to ensure that cls.TEST_SPEC is appropriately set when - the metaclass __init__ is invoked.""" + the metaclass __init__ is invoked. + """ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore TEST_SPEC = test_spec @@ -1716,7 +1724,7 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore scenario_def = json_util.loads(scenario_stream.read(), json_options=opts) test_type = os.path.splitext(filename)[0] - snake_class_name = "Test%s_%s_%s" % ( + snake_class_name = "Test{}_{}_{}".format( class_name_prefix, dirname.replace("-", "_"), test_type.replace("-", "_").replace(".", "_"), @@ -1728,8 +1736,7 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(schema_version[0]) if mixin_class is None: raise ValueError( - "test file '%s' has unsupported schemaVersion '%s'" - % (fpath, schema_version) + f"test file '{fpath}' has unsupported schemaVersion '{schema_version}'" ) module_dict = {"__module__": module} module_dict.update(kwargs) diff --git a/test/utils.py b/test/utils.py index b39375925c..810a02b872 100644 --- a/test/utils.py +++ b/test/utils.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utilities for testing pymongo -""" +"""Utilities for testing pymongo""" import contextlib import copy @@ -65,7 +64,7 @@ IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) -class BaseListener(object): +class BaseListener: def __init__(self): self.events = [] @@ -91,7 +90,7 @@ def matching(self, matcher): def wait_for_event(self, event, count): """Wait for a number of events to be published, or fail.""" - wait_until(lambda: self.event_count(event) >= count, "find %s %s event(s)" % (count, event)) + wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): @@ -142,7 +141,7 @@ def pool_closed(self, event): class EventListener(BaseListener, monitoring.CommandListener): def __init__(self): - super(EventListener, self).__init__() + super().__init__() self.results = defaultdict(list) @property @@ -176,7 +175,7 @@ def started_command_names(self) -> List[str]: def reset(self) -> None: """Reset the state of this listener.""" self.results.clear() - super(EventListener, self).reset() + super().reset() class TopologyEventListener(monitoring.TopologyListener): @@ -200,19 +199,19 @@ def reset(self): class AllowListEventListener(EventListener): def __init__(self, *commands): self.commands = set(commands) - super(AllowListEventListener, self).__init__() + super().__init__() def started(self, event): if event.command_name in self.commands: - super(AllowListEventListener, self).started(event) + super().started(event) def succeeded(self, event): if event.command_name in self.commands: - super(AllowListEventListener, self).succeeded(event) + super().succeeded(event) def failed(self, event): if event.command_name in self.commands: - super(AllowListEventListener, self).failed(event) + super().failed(event) class OvertCommandListener(EventListener): @@ -222,18 +221,18 @@ class OvertCommandListener(EventListener): def started(self, event): if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super(OvertCommandListener, self).started(event) + super().started(event) def succeeded(self, event): if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super(OvertCommandListener, self).succeeded(event) + super().succeeded(event) def failed(self, event): if event.command_name.lower() not in _SENSITIVE_COMMANDS: - super(OvertCommandListener, self).failed(event) + super().failed(event) -class _ServerEventListener(object): +class _ServerEventListener: """Listens to all events.""" def __init__(self): @@ -280,7 +279,7 @@ def failed(self, event): self.add_event(event) -class MockSocketInfo(object): +class MockSocketInfo: def __init__(self): self.cancel_context = _CancellationContext() self.more_to_come = False @@ -295,7 +294,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass -class MockPool(object): +class MockPool: def __init__(self, address, options, handshake=True): self.gen = _PoolGeneration() self._lock = _create_lock() @@ -357,7 +356,7 @@ def __getitem__(self, item): return ScenarioDict({}) -class CompareType(object): +class CompareType: """Class that compares equal to any object of the given type(s).""" def __init__(self, types): @@ -367,7 +366,7 @@ def __eq__(self, other): return isinstance(other, self.types) -class FunctionCallRecorder(object): +class FunctionCallRecorder: """Utility class to wrap a callable and record its invocations.""" def __init__(self, function): @@ -392,7 +391,7 @@ def call_count(self): return len(self._call_list) -class TestCreator(object): +class TestCreator: """Class to create test cases from specifications.""" def __init__(self, create_test, test_class, test_path): @@ -415,7 +414,8 @@ def __init__(self, create_test, test_class, test_path): def _ensure_min_max_server_version(self, scenario_def, method): """Test modifier that enforces a version range for the server on a - test case.""" + test case. + """ if "minServerVersion" in scenario_def: min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) if min_ver is not None: @@ -524,7 +524,7 @@ def create_tests(self): # Construct test from scenario. for test_def in self.tests(scenario_def): - test_name = "test_%s_%s_%s" % ( + test_name = "test_{}_{}_{}".format( dirname, test_type.replace("-", "_").replace(".", "_"), str(test_def["description"].replace(" ", "_").replace(".", "_")), @@ -539,9 +539,9 @@ def create_tests(self): def _connection_string(h): - if h.startswith("mongodb://") or h.startswith("mongodb+srv://"): + if h.startswith(("mongodb://", "mongodb+srv://")): return h - return "mongodb://%s" % (str(h),) + return f"mongodb://{str(h)}" def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs): @@ -620,7 +620,7 @@ def ensure_all_connected(client: MongoClient) -> None: raise ConfigurationError("cluster is not a replica set") target_host_list = set(hello["hosts"] + hello.get("passives", [])) - connected_host_list = set([hello["me"]]) + connected_host_list = {hello["me"]} # Run hello until we have connected to each host at least once. def discover(): @@ -821,7 +821,7 @@ def assertRaisesExactly(cls, fn, *args, **kwargs): try: fn(*args, **kwargs) except Exception as e: - assert e.__class__ == cls, "got %s, expected %s" % (e.__class__.__name__, cls.__name__) + assert e.__class__ == cls, f"got {e.__class__.__name__}, expected {cls.__name__}" else: raise AssertionError("%s not raised" % cls) @@ -848,7 +848,7 @@ def wrapper(*args, **kwargs): return _ignore_deprecations() -class DeprecationFilter(object): +class DeprecationFilter: def __init__(self, action="ignore"): """Start filtering deprecations.""" self.warn_context = warnings.catch_warnings() @@ -922,7 +922,7 @@ def lazy_client_trial(reset, target, test, get_client): collection = client_context.client.pymongo_test.test with frequent_thread_switches(): - for i in range(NTRIALS): + for _i in range(NTRIALS): reset(collection) lazy_client = get_client() lazy_collection = lazy_client.pymongo_test.test @@ -972,11 +972,11 @@ class ExceptionCatchingThread(threading.Thread): def __init__(self, *args, **kwargs): self.exc = None - super(ExceptionCatchingThread, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def run(self): try: - super(ExceptionCatchingThread, self).run() + super().run() except BaseException as exc: self.exc = exc raise @@ -1147,6 +1147,6 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac elif cursor_type == "tailableAwait": arguments["cursor_type"] = CursorType.TAILABLE else: - assert False, f"Unsupported cursorType: {cursor_type}" + raise AssertionError(f"Unsupported cursorType: {cursor_type}") else: arguments[c2s] = arguments.pop(arg_name) diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index e693fc25f0..ccb3897966 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -109,9 +109,11 @@ def get_topology_type_name(scenario_def): def get_topology_settings_dict(**kwargs): - settings = dict( - monitor_class=DummyMonitor, heartbeat_frequency=HEARTBEAT_FREQUENCY, pool_class=MockPool - ) + settings = { + "monitor_class": DummyMonitor, + "heartbeat_frequency": HEARTBEAT_FREQUENCY, + "pool_class": MockPool, + } settings.update(kwargs) return settings @@ -255,7 +257,7 @@ class TestAllScenarios(unittest.TestCase): # Construct test from scenario. new_test = create_test(scenario_def) - test_name = "test_%s_%s" % (dirname, os.path.splitext(filename)[0]) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 6530f39da6..4ca6f1cc58 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -49,7 +49,7 @@ class SpecRunnerThread(threading.Thread): def __init__(self, name): - super(SpecRunnerThread, self).__init__() + super().__init__() self.name = name self.exc = None self.daemon = True @@ -88,7 +88,7 @@ class SpecRunner(IntegrationTest): @classmethod def setUpClass(cls): - super(SpecRunner, cls).setUpClass() + super().setUpClass() cls.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. @@ -98,10 +98,10 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): cls.knobs.disable() - super(SpecRunner, cls).tearDownClass() + super().tearDownClass() def setUp(self): - super(SpecRunner, self).setUp() + super().setUp() self.targets = {} self.listener = None # type: ignore self.pool_listener = None @@ -170,7 +170,7 @@ def assertErrorLabelsContain(self, exc, expected_labels): def assertErrorLabelsOmit(self, exc, omit_labels): for label in omit_labels: self.assertFalse( - exc.has_error_label(label), msg="error labels should not contain %s" % (label,) + exc.has_error_label(label), msg=f"error labels should not contain {label}" ) def kill_all_sessions(self): @@ -242,6 +242,7 @@ def _helper(expected_result, result): self.assertEqual(expected_result, result) _helper(expected_result, result) + return None def get_object_name(self, op): """Allow subclasses to override handling of 'object' @@ -335,7 +336,7 @@ def _run_op(self, sessions, collection, op, in_with_transaction): expected_result = op.get("result") if expect_error(op): with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: - out = self.run_operation(sessions, collection, op.copy()) + self.run_operation(sessions, collection, op.copy()) exc = context.exception if expect_error_message(expected_result): if isinstance(exc, BulkWriteError): @@ -425,9 +426,9 @@ def check_events(self, test, listener, session_ids): for key, val in expected.items(): if val is None: if key in actual: - self.fail("Unexpected key [%s] in %r" % (key, actual)) + self.fail(f"Unexpected key [{key}] in {actual!r}") elif key not in actual: - self.fail("Expected key [%s] in %r" % (key, actual)) + self.fail(f"Expected key [{key}] in {actual!r}") else: # Workaround an incorrect command started event in fle2v2-CreateCollection.yml # added in DRIVERS-2524. @@ -436,7 +437,7 @@ def check_events(self, test, listener, session_ids): if val.get(n) is None: val.pop(n, None) self.assertEqual( - val, decode_raw(actual[key]), "Key [%s] in %s" % (key, actual) + val, decode_raw(actual[key]), f"Key [{key}] in {actual}" ) else: self.assertEqual(actual, expected) @@ -459,7 +460,8 @@ def get_outcome_coll_name(self, outcome, collection): def run_test_ops(self, sessions, collection, test): """Added to allow retryable writes spec to override a test's - operation.""" + operation. + """ self.run_operations(sessions, collection, test["operations"]) def parse_client_options(self, opts): diff --git a/test/version.py b/test/version.py index e102db7111..1dd1bec5f9 100644 --- a/test/version.py +++ b/test/version.py @@ -18,7 +18,7 @@ class Version(tuple): def __new__(cls, *version): padded_version = cls._padded(version, 4) - return super(Version, cls).__new__(cls, tuple(padded_version)) + return super().__new__(cls, tuple(padded_version)) @classmethod def _padded(cls, iter, length, padding=0):