diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index 6980c9cee419..746a393273a3 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -64,6 +64,7 @@ RUnion, bytes_rprimitive, dict_rprimitive, + exact_dict_rprimitive, int_rprimitive, is_float_rprimitive, is_object_rprimitive, @@ -177,6 +178,7 @@ def check_op_sources_valid(fn: FuncIR) -> list[FnError]: int_rprimitive.name, bytes_rprimitive.name, str_rprimitive.name, + exact_dict_rprimitive.name, dict_rprimitive.name, list_rprimitive.name, set_rprimitive.name, @@ -197,7 +199,10 @@ def can_coerce_to(src: RType, dest: RType) -> bool: if isinstance(src, RPrimitive): # If either src or dest is a disjoint type, then they must both be. if src.name in disjoint_types and dest.name in disjoint_types: - return src.name == dest.name + return src.name == dest.name or ( + src.name in ("builtins.dict", "builtins.dict[exact]") + and dest.name in ("builtins.dict", "builtins.dict[exact]") + ) return src.size == dest.size if isinstance(src, RInstance): return is_object_rprimitive(dest) diff --git a/mypyc/annotate.py b/mypyc/annotate.py index bc282fc3ea6c..f6aa0f006151 100644 --- a/mypyc/annotate.py +++ b/mypyc/annotate.py @@ -216,7 +216,7 @@ def function_annotations(func_ir: FuncIR, tree: MypyFile) -> dict[int, list[Anno ann = "Dynamic method call." elif name in op_hints: ann = op_hints[name] - elif name in ("CPyDict_GetItem", "CPyDict_SetItem"): + elif name in ("CPyDict_GetItemUnsafe", "PyDict_SetItem"): if ( isinstance(op.args[0], LoadStatic) and isinstance(op.args[1], LoadLiteral) diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 34824a59cd5c..82fa5c790a10 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -487,8 +487,15 @@ def __hash__(self) -> int: "builtins.list", is_unboxed=False, is_refcounted=True, may_be_immortal=False ) -# Python dict object (or an instance of a subclass of dict). +# Python dict object. +exact_dict_rprimitive: Final = RPrimitive( + "builtins.dict[exact]", is_unboxed=False, is_refcounted=True +) +"""A primitive for dicts that are confirmed to be actual instances of builtins.dict, not a subclass.""" + +# An instance of a subclass of dict. dict_rprimitive: Final = RPrimitive("builtins.dict", is_unboxed=False, is_refcounted=True) +"""A primitive that represents instances of builtins.dict or subclasses of dict.""" # Python set object (or an instance of a subclass of set). set_rprimitive: Final = RPrimitive("builtins.set", is_unboxed=False, is_refcounted=True) @@ -608,7 +615,14 @@ def is_list_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: def is_dict_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: - return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict" + return isinstance(rtype, RPrimitive) and rtype.name in ( + "builtins.dict", + "builtins.dict[exact]", + ) + + +def is_exact_dict_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict[exact]" def is_set_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]: diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 4f2f539118d7..8ba8cd1395ad 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -93,7 +93,7 @@ bitmap_rprimitive, bytes_rprimitive, c_pyssize_t_rprimitive, - dict_rprimitive, + exact_dict_rprimitive, int_rprimitive, is_float_rprimitive, is_list_rprimitive, @@ -125,7 +125,7 @@ ) from mypyc.irbuild.util import bytes_from_str, is_constant from mypyc.options import CompilerOptions -from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op +from mypyc.primitives.dict_ops import dict_set_item_op, exact_dict_get_item_op from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list from mypyc.primitives.misc_ops import check_unpack_count_op, get_module_dict_op, import_op @@ -436,6 +436,8 @@ def add_to_non_ext_dict( ) -> None: # Add an attribute entry into the class dict of a non-extension class. key_unicode = self.load_str(key) + # must use `dict_set_item_op` instead of `exact_dict_set_item_op` because + # it breaks enums, and probably other stuff, if we take the fast path. self.primitive_op(dict_set_item_op, [non_ext.dict, key_unicode, val], line) # It's important that accessing class dictionary items from multiple threads @@ -471,7 +473,7 @@ def get_module(self, module: str, line: int) -> Value: # Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :( mod_dict = self.call_c(get_module_dict_op, [], line) # Get module object from modules dict. - return self.primitive_op(dict_get_item_op, [mod_dict, self.load_str(module)], line) + return self.primitive_op(exact_dict_get_item_op, [mod_dict, self.load_str(module)], line) def get_module_attr(self, module: str, attr: str, line: int) -> Value: """Look up an attribute of a module without storing it in the local namespace. @@ -1388,10 +1390,10 @@ def load_global(self, expr: NameExpr) -> Value: def load_global_str(self, name: str, line: int) -> Value: _globals = self.load_globals_dict() reg = self.load_str(name) - return self.primitive_op(dict_get_item_op, [_globals, reg], line) + return self.primitive_op(exact_dict_get_item_op, [_globals, reg], line) def load_globals_dict(self) -> Value: - return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name)) + return self.add(LoadStatic(exact_dict_rprimitive, "globals", self.module_name)) def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value: module, _, name = fullname.rpartition(".") diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 324b44b95dc4..7cb7a9c71efb 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -50,7 +50,7 @@ from mypyc.ir.rtypes import ( RType, bool_rprimitive, - dict_rprimitive, + exact_dict_rprimitive, is_none_rprimitive, is_object_rprimitive, is_optional_type, @@ -271,7 +271,7 @@ def finalize(self, ir: ClassIR) -> None: ) # Add the non-extension class to the dict - self.builder.call_c( + self.builder.primitive_op( exact_dict_set_item_op, [ self.builder.load_globals_dict(), @@ -487,7 +487,7 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value: builder.add(InitStatic(tp, cdef.name, builder.module_name, NAMESPACE_TYPE)) # Add it to the dict - builder.call_c( + builder.primitive_op( exact_dict_set_item_op, [builder.load_globals_dict(), builder.load_str(cdef.name), tp], cdef.line, @@ -611,7 +611,7 @@ def setup_non_ext_dict( py_hasattr_op, [metaclass, builder.load_str("__prepare__")], cdef.line ) - non_ext_dict = Register(dict_rprimitive) + non_ext_dict = Register(exact_dict_rprimitive) true_block, false_block, exit_block = BasicBlock(), BasicBlock(), BasicBlock() builder.add_bool_branch(has_prepare, true_block, false_block) @@ -674,7 +674,7 @@ def add_non_ext_class_attr_ann( typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line)) key = builder.load_str(lvalue.name) - builder.call_c(exact_dict_set_item_op, [non_ext.anns, key, typ], stmt.line) + builder.primitive_op(exact_dict_set_item_op, [non_ext.anns, key, typ], stmt.line) def add_non_ext_class_attr( diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 4409b1acff26..83d18f4479fa 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -100,7 +100,7 @@ ) from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization from mypyc.primitives.bytes_ops import bytes_slice_op -from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op +from mypyc.primitives.dict_ops import dict_new_op, exact_dict_get_item_op, exact_dict_set_item_op from mypyc.primitives.generic_ops import iter_op, name_op from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op @@ -186,7 +186,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value: # instead load the module separately on each access. mod_dict = builder.call_c(get_module_dict_op, [], expr.line) obj = builder.primitive_op( - dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line + exact_dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line ) return obj else: @@ -1081,7 +1081,7 @@ def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehe def gen_inner_stmts() -> None: k = builder.accept(o.key) v = builder.accept(o.value) - builder.call_c(exact_dict_set_item_op, [builder.read(d), k, v], o.line) + builder.primitive_op(exact_dict_set_item_op, [builder.read(d), k, v], o.line) comprehension_helper(builder, loop_params, gen_inner_stmts, o.line) return builder.read(d) diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 762b41866a05..43c72bbda6f5 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -47,6 +47,7 @@ c_pyssize_t_rprimitive, int_rprimitive, is_dict_rprimitive, + is_exact_dict_rprimitive, is_fixed_width_rtype, is_immutable_rprimitive, is_list_rprimitive, @@ -70,6 +71,11 @@ dict_next_key_op, dict_next_value_op, dict_value_iter_op, + exact_dict_check_size_op, + exact_dict_iter_fast_path_op, + exact_dict_next_item_op, + exact_dict_next_key_op, + exact_dict_next_value_op, ) from mypyc.primitives.exc_ops import no_err_occurred_op, propagate_if_error_op from mypyc.primitives.generic_ops import aiter_op, anext_op, iter_op, next_op @@ -416,8 +422,10 @@ def make_for_loop_generator( # Special case "for k in ". expr_reg = builder.accept(expr) target_type = builder.get_dict_key_type(expr) - - for_dict = ForDictionaryKeys(builder, index, body_block, loop_exit, line, nested) + for_loop_cls = ( + ForExactDictionaryKeys if is_exact_dict_rprimitive(rtyp) else ForDictionaryKeys + ) + for_dict = for_loop_cls(builder, index, body_block, loop_exit, line, nested) for_dict.init(expr_reg, target_type) return for_dict @@ -499,13 +507,22 @@ def make_for_loop_generator( for_dict_type: type[ForGenerator] | None = None if expr.callee.name == "keys": target_type = builder.get_dict_key_type(expr.callee.expr) - for_dict_type = ForDictionaryKeys + if is_exact_dict_rprimitive(rtype): + for_dict_type = ForExactDictionaryKeys + else: + for_dict_type = ForDictionaryKeys elif expr.callee.name == "values": target_type = builder.get_dict_value_type(expr.callee.expr) - for_dict_type = ForDictionaryValues + if is_exact_dict_rprimitive(rtype): + for_dict_type = ForExactDictionaryValues + else: + for_dict_type = ForDictionaryValues else: target_type = builder.get_dict_item_type(expr.callee.expr) - for_dict_type = ForDictionaryItems + if is_exact_dict_rprimitive(rtype): + for_dict_type = ForExactDictionaryItems + else: + for_dict_type = ForDictionaryItems for_dict_gen = for_dict_type(builder, index, body_block, loop_exit, line, nested) for_dict_gen.init(expr_reg, target_type) return for_dict_gen @@ -886,6 +903,7 @@ class ForDictionaryCommon(ForGenerator): dict_next_op: ClassVar[CFunctionDescription] dict_iter_op: ClassVar[CFunctionDescription] + dict_size_op: ClassVar[CFunctionDescription] = dict_check_size_op def need_cleanup(self) -> bool: # Technically, a dict subclass can raise an unrelated exception @@ -932,7 +950,7 @@ def gen_step(self) -> None: line = self.line # Technically, we don't need a new primitive for this, but it is simpler. builder.call_c( - dict_check_size_op, + self.dict_size_op, [builder.read(self.expr_target, line), builder.read(self.size, line)], line, ) @@ -1010,6 +1028,30 @@ def begin_body(self) -> None: builder.assign(target, rvalue, line) +class ForExactDictionaryKeys(ForDictionaryKeys): + """Generate optimized IR for a for loop over dictionary items without type checks.""" + + dict_next_op = exact_dict_next_key_op + dict_iter_op = exact_dict_iter_fast_path_op + dict_size_op = exact_dict_check_size_op + + +class ForExactDictionaryValues(ForDictionaryValues): + """Generate optimized IR for a for loop over dictionary items without type checks.""" + + dict_next_op = exact_dict_next_value_op + dict_iter_op = exact_dict_iter_fast_path_op + dict_size_op = exact_dict_check_size_op + + +class ForExactDictionaryItems(ForDictionaryItems): + """Generate optimized IR for a for loop over dictionary items without type checks.""" + + dict_next_op = exact_dict_next_item_op + dict_iter_op = exact_dict_iter_fast_path_op + dict_size_op = exact_dict_check_size_op + + class ForRange(ForGenerator): """Generate optimized IR for a for loop over an integer range.""" diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index f0fc424aea54..edd4e15e9a7b 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -56,7 +56,7 @@ from mypyc.ir.rtypes import ( RInstance, bool_rprimitive, - dict_rprimitive, + exact_dict_rprimitive, int_rprimitive, object_rprimitive, ) @@ -77,8 +77,8 @@ from mypyc.irbuild.generator import gen_generator_func, gen_generator_func_body from mypyc.irbuild.targets import AssignmentTarget from mypyc.primitives.dict_ops import ( - dict_get_method_with_none, dict_new_op, + exact_dict_get_method_with_none, exact_dict_set_item_op, ) from mypyc.primitives.generic_ops import py_setattr_op @@ -127,7 +127,7 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: if decorated_func is not None: # Set the callable object representing the decorated function as a global. - builder.call_c( + builder.primitive_op( exact_dict_set_item_op, [builder.load_globals_dict(), builder.load_str(dec.func.name), decorated_func], decorated_func.line, @@ -814,10 +814,12 @@ def generate_singledispatch_dispatch_function( arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line) dispatch_cache = builder.builder.get_attr( - dispatch_func_obj, "dispatch_cache", dict_rprimitive, line + dispatch_func_obj, "dispatch_cache", exact_dict_rprimitive, line ) call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock() - get_result = builder.primitive_op(dict_get_method_with_none, [dispatch_cache, arg_type], line) + get_result = builder.primitive_op( + exact_dict_get_method_with_none, [dispatch_cache, arg_type], line + ) is_not_none = builder.translate_is_op(get_result, builder.none_object(), "is not", line) impl_to_use = Register(object_rprimitive) builder.add_bool_branch(is_not_none, use_cache, call_find_impl) @@ -830,7 +832,7 @@ def generate_singledispatch_dispatch_function( find_impl = builder.load_module_attr_by_fullname("functools._find_impl", line) registry = load_singledispatch_registry(builder, dispatch_func_obj, line) uncached_impl = builder.py_call(find_impl, [arg_type, registry], line) - builder.call_c(exact_dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line) + builder.primitive_op(exact_dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line) builder.assign(impl_to_use, uncached_impl, line) builder.goto(call_func) @@ -894,8 +896,8 @@ def gen_dispatch_func_ir( """ builder.enter(FuncInfo(fitem, dispatch_name)) setup_callable_class(builder) - builder.fn_info.callable_class.ir.attributes["registry"] = dict_rprimitive - builder.fn_info.callable_class.ir.attributes["dispatch_cache"] = dict_rprimitive + builder.fn_info.callable_class.ir.attributes["registry"] = exact_dict_rprimitive + builder.fn_info.callable_class.ir.attributes["dispatch_cache"] = exact_dict_rprimitive builder.fn_info.callable_class.ir.has_dict = True builder.fn_info.callable_class.ir.needs_getseters = True generate_singledispatch_callable_class_ctor(builder) @@ -958,7 +960,7 @@ def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) def load_singledispatch_registry(builder: IRBuilder, dispatch_func_obj: Value, line: int) -> Value: - return builder.builder.get_attr(dispatch_func_obj, "registry", dict_rprimitive, line) + return builder.builder.get_attr(dispatch_func_obj, "registry", exact_dict_rprimitive, line) def singledispatch_main_func_name(orig_name: str) -> str: @@ -1007,9 +1009,9 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None: registry = load_singledispatch_registry(builder, dispatch_func_obj, line) for typ in types: loaded_type = load_type(builder, typ, None, line) - builder.call_c(exact_dict_set_item_op, [registry, loaded_type, to_insert], line) + builder.primitive_op(exact_dict_set_item_op, [registry, loaded_type, to_insert], line) dispatch_cache = builder.builder.get_attr( - dispatch_func_obj, "dispatch_cache", dict_rprimitive, line + dispatch_func_obj, "dispatch_cache", exact_dict_rprimitive, line ) builder.gen_method_call(dispatch_cache, "clear", [], None, line) diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 4b85c13892c1..343878714ef0 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -94,11 +94,14 @@ c_pyssize_t_rprimitive, c_size_t_rprimitive, check_native_int_range, + dict_rprimitive, + exact_dict_rprimitive, float_rprimitive, int_rprimitive, is_bool_or_bit_rprimitive, is_bytes_rprimitive, is_dict_rprimitive, + is_exact_dict_rprimitive, is_fixed_width_rtype, is_float_rprimitive, is_frozenset_rprimitive, @@ -134,6 +137,9 @@ dict_new_op, dict_ssize_t_size_op, dict_update_in_display_op, + exact_dict_copy_op, + exact_dict_ssize_t_size_op, + exact_dict_update_in_display_op, ) from mypyc.primitives.exc_ops import err_occurred_op, keep_propagating_op from mypyc.primitives.float_ops import copysign_op, int_to_float_op @@ -168,6 +174,7 @@ bool_op, buf_init_item, debug_print_op, + dict_is_true_op, fast_isinstance_op, none_object_op, not_implemented_op, @@ -852,8 +859,11 @@ def _construct_varargs( ) star2_result = self._create_dict(star2_keys, star2_values, line) - - self.call_c(dict_update_in_display_op, [star2_result, value], line=line) + if is_exact_dict_rprimitive(value.type): + op = exact_dict_update_in_display_op + else: + op = dict_update_in_display_op + self.call_c(op, [star2_result, value], line=line) else: nullable = kind.is_optional() maybe_pos = kind.is_positional() and has_star @@ -1815,9 +1825,18 @@ def make_dict(self, key_value_pairs: Sequence[DictEntry], line: int) -> Value: else: # **value if result is None: + if len(key_value_pairs) == 1 and is_exact_dict_rprimitive(value.type): + # fast path for cases like `my_func(**dict(zip(iterable, other)))` and similar + return self.call_c(exact_dict_copy_op, [value], line=line) + result = self._create_dict(keys, values, line) - self.call_c(dict_update_in_display_op, [result, value], line=line) + if is_exact_dict_rprimitive(value.type): + op = exact_dict_update_in_display_op + else: + op = dict_update_in_display_op + + self.call_c(op, [result, value], line=line) if result is None: result = self._create_dict(keys, values, line) @@ -1921,10 +1940,12 @@ def bool_value(self, value: Value) -> Value: result = self.add(ComparisonOp(value, zero, ComparisonOp.NEQ)) elif is_str_rprimitive(value.type): result = self.call_c(str_check_if_true, [value], value.line) + elif is_same_type(value.type, exact_dict_rprimitive): + result = self.primitive_op(dict_is_true_op, [value], line=value.line) elif ( - is_list_rprimitive(value.type) - or is_dict_rprimitive(value.type) - or is_tuple_rprimitive(value.type) + is_same_type(value.type, list_rprimitive) + or is_same_type(value.type, dict_rprimitive) + or is_same_type(value.type, tuple_rprimitive) ): length = self.builtin_len(value, value.line) zero = Integer(0) @@ -2428,6 +2449,8 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val elem_address = self.add(GetElementPtr(val, PySetObject, "used")) size_value = self.load_mem(elem_address, c_pyssize_t_rprimitive) self.add(KeepAlive([val])) + elif is_exact_dict_rprimitive(typ): + size_value = self.call_c(exact_dict_ssize_t_size_op, [val], line) elif is_dict_rprimitive(typ): size_value = self.call_c(dict_ssize_t_size_op, [val], line) elif is_str_rprimitive(typ): diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 05aa0e45c569..248ab7d1046f 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -33,6 +33,7 @@ bool_rprimitive, bytes_rprimitive, dict_rprimitive, + exact_dict_rprimitive, float_rprimitive, frozenset_rprimitive, int16_rprimitive, @@ -90,6 +91,9 @@ def type_to_rtype(self, typ: Type | None) -> RType: return bytes_rprimitive elif typ.type.fullname == "builtins.list": return list_rprimitive + # TODO: figure out why this breaks tests, fix, and uncomment + # elif typ.type.fullname == "builtins.dict": + # return exact_dict_rprimitive # Dict subclasses are at least somewhat common and we # specifically support them, so make sure that dict operations # get optimized on them. @@ -154,7 +158,7 @@ def type_to_rtype(self, typ: Type | None) -> RType: elif isinstance(typ, Overloaded): return object_rprimitive elif isinstance(typ, TypedDictType): - return dict_rprimitive + return exact_dict_rprimitive elif isinstance(typ, LiteralType): return self.type_to_rtype(typ.fallback) elif isinstance(typ, (UninhabitedType, UnboundType)): @@ -169,7 +173,7 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType: if kind == ARG_STAR: return tuple_rprimitive elif kind == ARG_STAR2: - return dict_rprimitive + return exact_dict_rprimitive else: return self.type_to_rtype(typ) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index c2ca9cfd32ff..89e02228310c 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -22,7 +22,7 @@ from mypyc.irbuild.builder import IRBuilder from mypyc.primitives.dict_ops import ( dict_copy, - dict_del_item, + exact_dict_del_item, mapping_has_key, supports_mapping_protocol, ) @@ -239,7 +239,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.assign(target, rest, pattern.rest.line) for i, key_name in enumerate(keys): - self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line) + self.builder.call_c(exact_dict_del_item, [rest, key_name], pattern.keys[i].line) self.builder.goto(self.code_block) diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 61e3e5b95cf4..611c776c07ae 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -54,7 +54,7 @@ from mypyc.ir.rtypes import ( RInstance, RType, - dict_rprimitive, + exact_dict_rprimitive, none_rprimitive, object_pointer_rprimitive, object_rprimitive, @@ -626,7 +626,7 @@ def prepare_init_method(cdef: ClassDef, ir: ClassIR, module_name: str, mapper: M [ init_sig.args[0], RuntimeArg("args", tuple_rprimitive, ARG_STAR), - RuntimeArg("kwargs", dict_rprimitive, ARG_STAR2), + RuntimeArg("kwargs", exact_dict_rprimitive, ARG_STAR2), ], init_sig.ret_type, ) diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 0880c62bc7a5..cb65ff0b8932 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -53,12 +53,14 @@ bytes_rprimitive, c_int_rprimitive, dict_rprimitive, + exact_dict_rprimitive, int16_rprimitive, int32_rprimitive, int64_rprimitive, int_rprimitive, is_bool_rprimitive, is_dict_rprimitive, + is_exact_dict_rprimitive, is_fixed_width_rtype, is_float_rprimitive, is_int16_rprimitive, @@ -91,6 +93,9 @@ dict_keys_op, dict_setdefault_spec_init_op, dict_values_op, + exact_dict_items_op, + exact_dict_keys_op, + exact_dict_values_op, isinstance_dict, ) from mypyc.primitives.float_ops import isinstance_float @@ -254,11 +259,22 @@ def dict_methods_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) # so the corresponding helpers in CPy.h fallback to (inlined) # generic logic. if attr == "keys": - return builder.call_c(dict_keys_op, [obj], expr.line) + if is_exact_dict_rprimitive(rtype): + op = exact_dict_keys_op + else: + op = dict_keys_op elif attr == "values": - return builder.call_c(dict_values_op, [obj], expr.line) + if is_exact_dict_rprimitive(rtype): + op = exact_dict_values_op + else: + op = dict_values_op else: - return builder.call_c(dict_items_op, [obj], expr.line) + if is_exact_dict_rprimitive(rtype): + op = exact_dict_items_op + else: + op = dict_items_op + + return builder.call_c(op, [obj], expr.line) @specialize_function("builtins.list") @@ -367,6 +383,7 @@ def faster_min_max(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value @specialize_function("join", str_rprimitive) @specialize_function("extend", list_rprimitive) @specialize_function("update", dict_rprimitive) +@specialize_function("update", exact_dict_rprimitive) @specialize_function("update", set_rprimitive) def translate_safe_generator_call( builder: IRBuilder, expr: CallExpr, callee: RefExpr @@ -608,6 +625,7 @@ def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> @specialize_function("setdefault", dict_rprimitive) +@specialize_function("setdefault", exact_dict_rprimitive) def translate_dict_setdefault(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: """Special case for 'dict.setdefault' which would only construct default empty collection when needed. diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 5dec7509ac7b..2e49d67fa5b2 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -711,6 +711,17 @@ tuple_T3CIO CPyDict_NextValue(PyObject *dict_or_iter, CPyTagged offset); tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset); int CPyMapping_Check(PyObject *obj); +// Unsafe dict operations (assume PyDict_CheckExact(dict) is always true) +int CPyDict_UpdateFromAnyUnsafe(PyObject *dict, PyObject *stuff); +PyObject *CPyDict_GetItemUnsafe(PyObject *dict, PyObject *key); +tuple_T3CIO CPyDict_NextKeyUnsafe(PyObject *dict_or_iter, CPyTagged offset); +tuple_T3CIO CPyDict_NextValueUnsafe(PyObject *dict_or_iter, CPyTagged offset); +tuple_T4CIOO CPyDict_NextItemUnsafe(PyObject *dict_or_iter, CPyTagged offset); +PyObject *CPyDict_KeysViewUnsafe(PyObject *dict); +PyObject *CPyDict_ValuesViewUnsafe(PyObject *dict); +PyObject *CPyDict_ItemsViewUnsafe(PyObject *dict); +PyObject *_CPyDict_GetIterUnsafe(PyObject *dict); + // Check that dictionary didn't change size during iteration. static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) { if (!PyDict_CheckExact(dict)) { @@ -725,6 +736,25 @@ static inline char CPyDict_CheckSize(PyObject *dict, Py_ssize_t size) { return 1; } +// Unsafe because it assumes dict is actually a dict. +static inline char CPyDict_CheckSizeUnsafe(PyObject *dict, Py_ssize_t size) { + Py_ssize_t dict_size = PyDict_Size(dict); + if (size != dict_size) { + PyErr_SetString(PyExc_RuntimeError, "dictionary changed size during iteration"); + return 0; + } + return 1; +} + + +static inline char CPyDict_IsTrue(PyObject *dict) { + Py_ssize_t size = PyDict_Size(dict); + if (size != 0) { + return 1; + } + return 0; +} + // Str operations diff --git a/mypyc/lib-rt/dict_ops.c b/mypyc/lib-rt/dict_ops.c index b102aba57307..547e8c779c32 100644 --- a/mypyc/lib-rt/dict_ops.c +++ b/mypyc/lib-rt/dict_ops.c @@ -15,15 +15,7 @@ // some indirections. PyObject *CPyDict_GetItem(PyObject *dict, PyObject *key) { if (PyDict_CheckExact(dict)) { - PyObject *res = PyDict_GetItemWithError(dict, key); - if (!res) { - if (!PyErr_Occurred()) { - PyErr_SetObject(PyExc_KeyError, key); - } - } else { - Py_INCREF(res); - } - return res; + return CPyDict_GetItemUnsafe(dict, key); } else { return PyObject_GetItem(dict, key); } @@ -169,12 +161,7 @@ int CPyDict_Update(PyObject *dict, PyObject *stuff) { int CPyDict_UpdateFromAny(PyObject *dict, PyObject *stuff) { if (PyDict_CheckExact(dict)) { // Argh this sucks - _Py_IDENTIFIER(keys); - if (PyDict_Check(stuff) || _CPyObject_HasAttrId(stuff, &PyId_keys)) { - return PyDict_Update(dict, stuff); - } else { - return PyDict_MergeFromSeq2(dict, stuff, 1); - } + return CPyDict_UpdateFromAnyUnsafe(dict, stuff); } else { return CPyDict_UpdateGeneral(dict, stuff); } @@ -348,9 +335,7 @@ PyObject *CPyDict_GetKeysIter(PyObject *dict) { PyObject *CPyDict_GetItemsIter(PyObject *dict) { if (PyDict_CheckExact(dict)) { - // Return dict itself to indicate we can use fast path instead. - Py_INCREF(dict); - return dict; + return _CPyDict_GetIterUnsafe(dict); } _Py_IDENTIFIER(items); PyObject *name = _PyUnicode_FromId(&PyId_items); /* borrowed */ @@ -485,7 +470,108 @@ tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset) { Py_INCREF(ret.f3); return ret; } +tuple_T3CIO CPyDict_NextKeyUnsafe(PyObject *dict_or_iter, CPyTagged offset) { + tuple_T3CIO ret; + Py_ssize_t py_offset = CPyTagged_AsSsize_t(offset); + PyObject *dummy; + + ret.f0 = PyDict_Next(dict_or_iter, &py_offset, &ret.f2, &dummy); + if (ret.f0) { + ret.f1 = CPyTagged_FromSsize_t(py_offset); + } else { + // Set key to None, so mypyc can manage refcounts. + ret.f1 = 0; + ret.f2 = Py_None; + } + // PyDict_Next() returns borrowed references. + Py_INCREF(ret.f2); + return ret; +} + +tuple_T3CIO CPyDict_NextValueUnsafe(PyObject *dict_or_iter, CPyTagged offset) { + tuple_T3CIO ret; + Py_ssize_t py_offset = CPyTagged_AsSsize_t(offset); + PyObject *dummy; + + ret.f0 = PyDict_Next(dict_or_iter, &py_offset, &dummy, &ret.f2); + if (ret.f0) { + ret.f1 = CPyTagged_FromSsize_t(py_offset); + } else { + // Set value to None, so mypyc can manage refcounts. + ret.f1 = 0; + ret.f2 = Py_None; + } + // PyDict_Next() returns borrowed references. + Py_INCREF(ret.f2); + return ret; +} + +tuple_T4CIOO CPyDict_NextItemUnsafe(PyObject *dict_or_iter, CPyTagged offset) { + tuple_T4CIOO ret; + Py_ssize_t py_offset = CPyTagged_AsSsize_t(offset); + + ret.f0 = PyDict_Next(dict_or_iter, &py_offset, &ret.f2, &ret.f3); + if (ret.f0) { + ret.f1 = CPyTagged_FromSsize_t(py_offset); + } else { + // Set key and value to None, so mypyc can manage refcounts. + ret.f1 = 0; + ret.f2 = Py_None; + ret.f3 = Py_None; + } + // PyDict_Next() returns borrowed references. + Py_INCREF(ret.f2); + Py_INCREF(ret.f3); + return ret; +} int CPyMapping_Check(PyObject *obj) { return Py_TYPE(obj)->tp_flags & Py_TPFLAGS_MAPPING; } + +// ======================= +// Unsafe dict operations +// ======================= + +// Unsafe: assumes dict is a true dict (PyDict_CheckExact(dict) is always true) + +int CPyDict_UpdateFromAnyUnsafe(PyObject *dict, PyObject *stuff) { + // Argh this sucks + _Py_IDENTIFIER(keys); + if (PyDict_Check(stuff) || _CPyObject_HasAttrId(stuff, &PyId_keys)) { + return PyDict_Update(dict, stuff); + } else { + return PyDict_MergeFromSeq2(dict, stuff, 1); + } +} + +PyObject *CPyDict_GetItemUnsafe(PyObject *dict, PyObject *key) { + // No type check, direct call + PyObject *res = PyDict_GetItemWithError(dict, key); + if (!res) { + if (!PyErr_Occurred()) { + PyErr_SetObject(PyExc_KeyError, key); + } + } else { + Py_INCREF(res); + } + return res; +} + +PyObject *_CPyDict_GetIterUnsafe(PyObject *dict) { + // No type check, no-op to prepare for fast path. Returns dict to pass directly to fast path for further handling. + Py_INCREF(dict); + return dict; +} + +PyObject *CPyDict_KeysViewUnsafe(PyObject *dict) { + return _CPyDictView_New(dict, &PyDictKeys_Type); +} + +PyObject *CPyDict_ValuesViewUnsafe(PyObject *dict) { + return _CPyDictView_New(dict, &PyDictValues_Type); +} + +PyObject *CPyDict_ItemsViewUnsafe(PyObject *dict) { + return _CPyDictView_New(dict, &PyDictItems_Type); +} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index c88e89d1a2ba..b0c664e75062 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -10,6 +10,7 @@ c_int_rprimitive, c_pyssize_t_rprimitive, dict_rprimitive, + exact_dict_rprimitive, int_rprimitive, list_rprimitive, object_rprimitive, @@ -30,7 +31,7 @@ # bytes(obj) function_op( name="builtins.bytes", - arg_types=[RUnion([list_rprimitive, dict_rprimitive, str_rprimitive])], + arg_types=[RUnion([list_rprimitive, dict_rprimitive, exact_dict_rprimitive, str_rprimitive])], return_type=bytes_rprimitive, c_function_name="PyBytes_FromObject", error_kind=ERR_MAGIC, diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index f98bcc8ac2ec..04fa3d5af920 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -11,9 +11,11 @@ dict_next_rtuple_pair, dict_next_rtuple_single, dict_rprimitive, + exact_dict_rprimitive, int_rprimitive, list_rprimitive, object_rprimitive, + void_rtype, ) from mypyc.primitives.registry import ( ERR_NEG_INT, @@ -31,14 +33,17 @@ function_op( name="builtins.dict", arg_types=[], - return_type=dict_rprimitive, + return_type=exact_dict_rprimitive, c_function_name="PyDict_New", error_kind=ERR_MAGIC, ) # Construct an empty dictionary. dict_new_op = custom_op( - arg_types=[], return_type=dict_rprimitive, c_function_name="PyDict_New", error_kind=ERR_MAGIC + arg_types=[], + return_type=exact_dict_rprimitive, + c_function_name="PyDict_New", + error_kind=ERR_MAGIC, ) # Construct a dictionary from keys and values. @@ -46,7 +51,7 @@ # Variable arguments are (key1, value1, ..., keyN, valueN). dict_build_op = custom_op( arg_types=[c_pyssize_t_rprimitive], - return_type=dict_rprimitive, + return_type=exact_dict_rprimitive, c_function_name="CPyDict_Build", error_kind=ERR_MAGIC, var_arg_type=object_rprimitive, @@ -54,6 +59,15 @@ # Construct a dictionary from another dictionary. dict_copy_op = function_op( + name="builtins.dict", + arg_types=[exact_dict_rprimitive], + return_type=exact_dict_rprimitive, + c_function_name="PyDict_Copy", + error_kind=ERR_MAGIC, + priority=2, +) + +function_op( name="builtins.dict", arg_types=[dict_rprimitive], return_type=dict_rprimitive, @@ -66,7 +80,7 @@ dict_copy = function_op( name="builtins.dict", arg_types=[object_rprimitive], - return_type=dict_rprimitive, + return_type=exact_dict_rprimitive, c_function_name="CPyDict_FromAny", error_kind=ERR_MAGIC, ) @@ -81,6 +95,16 @@ ) # dict[key] +exact_dict_get_item_op = method_op( + name="__getitem__", + arg_types=[exact_dict_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyDict_GetItemUnsafe", + error_kind=ERR_MAGIC, + priority=2, +) + +# dictsubclass[key] dict_get_item_op = method_op( name="__getitem__", arg_types=[dict_rprimitive, object_rprimitive], @@ -100,14 +124,28 @@ # dict[key] = value (exact dict only, no subclasses) # NOTE: this is currently for internal use only, and not used for CallExpr specialization -exact_dict_set_item_op = custom_op( - arg_types=[dict_rprimitive, object_rprimitive, object_rprimitive], +exact_dict_set_item_op = method_op( + name="__setitem__", + arg_types=[exact_dict_rprimitive, object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, c_function_name="PyDict_SetItem", error_kind=ERR_NEG_INT, + priority=2, ) # key in dict +binary_op( + name="in", + arg_types=[object_rprimitive, exact_dict_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_Contains", + error_kind=ERR_NEG_INT, + truncated_type=bool_rprimitive, + ordering=[1, 0], + priority=2, +) + +# key in dict or dict subclass binary_op( name="in", arg_types=[object_rprimitive, dict_rprimitive], @@ -119,6 +157,36 @@ ) # dict1.update(dict2) +exact_dict_update_op = method_op( + name="update", + arg_types=[exact_dict_rprimitive, exact_dict_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_Update", + error_kind=ERR_NEG_INT, + priority=5, +) + +# dictorsubclass.update(dict) +dict_update_from_exact_dict_op = method_op( + name="update", + arg_types=[dict_rprimitive, exact_dict_rprimitive], + return_type=c_int_rprimitive, + c_function_name="CPyDict_Update", + error_kind=ERR_NEG_INT, + priority=3, +) + +# dict.update(dictsubclass) +exact_dict_update_from_dict_op = method_op( + name="update", + arg_types=[exact_dict_rprimitive, dict_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_Update", + error_kind=ERR_NEG_INT, + priority=4, +) + +# dictsubclass1.update(dictsubclass2) dict_update_op = method_op( name="update", arg_types=[dict_rprimitive, dict_rprimitive], @@ -128,6 +196,15 @@ priority=2, ) +# Operation used for **value in with exact dictionary `value`. +# This is mostly like dict.update(obj), but has customized error handling. +exact_dict_update_in_display_op = custom_op( + arg_types=[exact_dict_rprimitive, exact_dict_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_Update", + error_kind=ERR_NEG_INT, +) + # Operation used for **value in dict displays. # This is mostly like dict.update(obj), but has customized error handling. dict_update_in_display_op = custom_op( @@ -138,6 +215,15 @@ ) # dict.update(obj) +method_op( + name="update", + arg_types=[exact_dict_rprimitive, object_rprimitive], + return_type=c_int_rprimitive, + c_function_name="CPyDict_UpdateFromAnyUnsafe", + error_kind=ERR_NEG_INT, +) + +# dictorsubclass.update(obj) method_op( name="update", arg_types=[dict_rprimitive, object_rprimitive], @@ -156,6 +242,16 @@ ) # dict.get(key) +exact_dict_get_method_with_none = method_op( + name="get", + arg_types=[exact_dict_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyDict_GetWithNone", + error_kind=ERR_MAGIC, + priority=2, +) + +# dictorsubclass.get(key) dict_get_method_with_none = method_op( name="get", arg_types=[dict_rprimitive, object_rprimitive], @@ -165,6 +261,16 @@ ) # dict.setdefault(key, default) +exact_dict_setdefault_op = method_op( + name="setdefault", + arg_types=[exact_dict_rprimitive, object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyDict_SetDefault", + error_kind=ERR_NEVER, + priority=2, +) + +# dictorsubclass.setdefault(key, default) dict_setdefault_op = method_op( name="setdefault", arg_types=[dict_rprimitive, object_rprimitive, object_rprimitive], @@ -194,6 +300,16 @@ ) # dict.keys() +method_op( + name="keys", + arg_types=[exact_dict_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyDict_KeysViewUnsafe", + error_kind=ERR_MAGIC, + priority=2, +) + +# dictorsubclass.keys() method_op( name="keys", arg_types=[dict_rprimitive], @@ -203,6 +319,16 @@ ) # dict.values() +method_op( + name="values", + arg_types=[exact_dict_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyDict_ValuesViewUnsafe", + error_kind=ERR_MAGIC, + priority=2, +) + +# dictorsubclass.values() method_op( name="values", arg_types=[dict_rprimitive], @@ -212,6 +338,16 @@ ) # dict.items() +method_op( + name="items", + arg_types=[exact_dict_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyDict_ItemsViewUnsafe", + error_kind=ERR_MAGIC, + priority=2, +) + +# dictorsubclass.items() method_op( name="items", arg_types=[dict_rprimitive], @@ -221,6 +357,16 @@ ) # dict.clear() +method_op( + name="clear", + arg_types=[exact_dict_rprimitive], + return_type=void_rtype, + c_function_name="PyDict_Clear", + error_kind=ERR_NEVER, + priority=2, +) + +# dictsubclass.clear() method_op( name="clear", arg_types=[dict_rprimitive], @@ -230,6 +376,16 @@ ) # dict.copy() +method_op( + name="copy", + arg_types=[exact_dict_rprimitive], + return_type=exact_dict_rprimitive, + c_function_name="PyDict_Copy", + error_kind=ERR_NEVER, + priority=2, +) + +# dictsubclass.copy() method_op( name="copy", arg_types=[dict_rprimitive], @@ -238,7 +394,23 @@ error_kind=ERR_MAGIC, ) +# dict.copy() custom_op +exact_dict_copy_op = custom_op( + arg_types=[exact_dict_rprimitive], + return_type=exact_dict_rprimitive, + c_function_name="PyDict_Copy", + error_kind=ERR_NEVER, +) + # list(dict.keys()) +exact_dict_keys_op = custom_op( + arg_types=[exact_dict_rprimitive], + return_type=list_rprimitive, + c_function_name="PyDict_Keys", + error_kind=ERR_NEVER, +) + +# list(dictorsubclass.keys()) dict_keys_op = custom_op( arg_types=[dict_rprimitive], return_type=list_rprimitive, @@ -247,6 +419,14 @@ ) # list(dict.values()) +exact_dict_values_op = custom_op( + arg_types=[exact_dict_rprimitive], + return_type=list_rprimitive, + c_function_name="PyDict_Values", + error_kind=ERR_NEVER, +) + +# list(dictorsubclass.values()) dict_values_op = custom_op( arg_types=[dict_rprimitive], return_type=list_rprimitive, @@ -255,6 +435,14 @@ ) # list(dict.items()) +exact_dict_items_op = custom_op( + arg_types=[exact_dict_rprimitive], + return_type=list_rprimitive, + c_function_name="PyDict_Items", + error_kind=ERR_NEVER, +) + +# list(dictorsubclass.items()) dict_items_op = custom_op( arg_types=[dict_rprimitive], return_type=list_rprimitive, @@ -263,6 +451,14 @@ ) # PyDict_Next() fast iteration +exact_dict_iter_fast_path_op = custom_op( + arg_types=[exact_dict_rprimitive], + return_type=exact_dict_rprimitive, + c_function_name="_CPyDict_GetIterUnsafe", + error_kind=ERR_NEVER, +) + +# PyDict_Next() fast iteration for subclass dict_key_iter_op = custom_op( arg_types=[dict_rprimitive], return_type=object_rprimitive, @@ -305,7 +501,36 @@ error_kind=ERR_NEVER, ) +exact_dict_next_key_op = custom_op( + arg_types=[object_rprimitive, int_rprimitive], + return_type=dict_next_rtuple_single, + c_function_name="CPyDict_NextKeyUnsafe", + error_kind=ERR_NEVER, +) + +exact_dict_next_value_op = custom_op( + arg_types=[object_rprimitive, int_rprimitive], + return_type=dict_next_rtuple_single, + c_function_name="CPyDict_NextValueUnsafe", + error_kind=ERR_NEVER, +) + +exact_dict_next_item_op = custom_op( + arg_types=[exact_dict_rprimitive, int_rprimitive], + return_type=dict_next_rtuple_pair, + c_function_name="CPyDict_NextItemUnsafe", + error_kind=ERR_NEVER, +) + # check that len(dict) == const during iteration +exact_dict_check_size_op = custom_op( + arg_types=[exact_dict_rprimitive, c_pyssize_t_rprimitive], + return_type=bit_rprimitive, + c_function_name="CPyDict_CheckSizeUnsafe", + error_kind=ERR_FALSE, +) + +# check that len(dictorsubclass) == const during iteration dict_check_size_op = custom_op( arg_types=[dict_rprimitive, c_pyssize_t_rprimitive], return_type=bit_rprimitive, @@ -313,6 +538,13 @@ error_kind=ERR_FALSE, ) +exact_dict_ssize_t_size_op = custom_op( + arg_types=[exact_dict_rprimitive], + return_type=c_pyssize_t_rprimitive, + c_function_name="PyDict_Size", + error_kind=ERR_NEVER, +) + dict_ssize_t_size_op = custom_op( arg_types=[dict_rprimitive], return_type=c_pyssize_t_rprimitive, @@ -321,6 +553,13 @@ ) # Delete an item from a dict +exact_dict_del_item = custom_op( + arg_types=[exact_dict_rprimitive, object_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_DelItem", + error_kind=ERR_NEG_INT, +) + dict_del_item = custom_op( arg_types=[object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 8e6e450c64dc..bb92f0807aa0 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -13,6 +13,7 @@ c_pyssize_t_rprimitive, cstring_rprimitive, dict_rprimitive, + exact_dict_rprimitive, float_rprimitive, int_rprimitive, none_rprimitive, @@ -161,7 +162,7 @@ # Get the sys.modules dictionary get_module_dict_op = custom_op( arg_types=[], - return_type=dict_rprimitive, + return_type=exact_dict_rprimitive, c_function_name="PyImport_GetModuleDict", error_kind=ERR_NEVER, is_borrowed=True, @@ -188,6 +189,15 @@ priority=0, ) +# bool(dict) +dict_is_true_op = function_op( + name="builtins.bool", + arg_types=[exact_dict_rprimitive], + return_type=bit_rprimitive, + c_function_name="CPyDict_IsTrue", + error_kind=ERR_NEVER, +) + # bool(obj) with unboxed result bool_op = function_op( name="builtins.bool", diff --git a/mypyc/rt_subtype.py b/mypyc/rt_subtype.py index 004e56ed75bc..01619158a954 100644 --- a/mypyc/rt_subtype.py +++ b/mypyc/rt_subtype.py @@ -27,6 +27,8 @@ RVoid, is_bit_rprimitive, is_bool_rprimitive, + is_dict_rprimitive, + is_exact_dict_rprimitive, is_int_rprimitive, is_short_int_rprimitive, ) @@ -58,6 +60,8 @@ def visit_rprimitive(self, left: RPrimitive) -> bool: return True if is_bit_rprimitive(left) and is_bool_rprimitive(self.right): return True + if is_exact_dict_rprimitive(left) and is_dict_rprimitive(self.right): + return True return left is self.right def visit_rtuple(self, left: RTuple) -> bool: diff --git a/mypyc/subtype.py b/mypyc/subtype.py index 726a48d7a01d..6feb4b83b5cf 100644 --- a/mypyc/subtype.py +++ b/mypyc/subtype.py @@ -14,6 +14,8 @@ RVoid, is_bit_rprimitive, is_bool_rprimitive, + is_dict_rprimitive, + is_exact_dict_rprimitive, is_fixed_width_rtype, is_int_rprimitive, is_object_rprimitive, @@ -67,6 +69,9 @@ def visit_rprimitive(self, left: RPrimitive) -> bool: elif is_fixed_width_rtype(left): if is_int_rprimitive(right): return True + elif is_exact_dict_rprimitive(left): + if is_dict_rprimitive(right): + return True return left is right def visit_rtuple(self, left: RTuple) -> bool: diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 612f3266fd79..a8c6f0db8d43 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -526,11 +526,11 @@ def __top_level__(): r11 :: native_int[4] r12 :: c_ptr r13 :: object - r14 :: dict + r14 :: dict[exact] r15, r16 :: str r17 :: bit r18 :: str - r19 :: dict + r19 :: dict[exact] r20 :: str r21 :: i32 r22 :: bit @@ -540,7 +540,7 @@ def __top_level__(): r26 :: native_int[1] r27 :: c_ptr r28 :: object - r29 :: dict + r29 :: dict[exact] r30, r31 :: str r32 :: bit r33 :: object @@ -572,7 +572,7 @@ L2: r18 = 'filler' r19 = __main__.globals :: static r20 = '_' - r21 = CPyDict_SetItem(r19, r20, r18) + r21 = PyDict_SetItem(r19, r20, r18) r22 = r21 >= 0 :: signed r23 = load_address single :: module r24 = [r23] @@ -604,25 +604,25 @@ def h() -> int: [out] def f(x): x :: int - r0 :: dict + r0 :: dict[exact] r1 :: str r2, r3 :: object r4 :: object[1] r5 :: object_ptr r6 :: object r7 :: int - r8 :: dict + r8 :: dict[exact] r9 :: str r10, r11 :: object r12, r13 :: int - r14 :: dict + r14 :: dict[exact] r15 :: str r16, r17 :: object r18, r19 :: int L0: r0 = __main__.globals :: static r1 = 'g' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = box(int, x) r4 = [r3] r5 = load_address r4 @@ -631,13 +631,13 @@ L0: r7 = unbox(int, r6) r8 = __main__.globals :: static r9 = 'h' - r10 = CPyDict_GetItem(r8, r9) + r10 = CPyDict_GetItemUnsafe(r8, r9) r11 = PyObject_Vectorcall(r10, 0, 0, 0) r12 = unbox(int, r11) r13 = CPyTagged_Add(r7, r12) r14 = __main__.globals :: static r15 = 'two' - r16 = CPyDict_GetItem(r14, r15) + r16 = CPyDict_GetItemUnsafe(r14, r15) r17 = PyObject_Vectorcall(r16, 0, 0, 0) r18 = unbox(int, r17) r19 = CPyTagged_Add(r13, r18) @@ -648,10 +648,10 @@ def __top_level__(): r3 :: str r4, r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8, r9, r10 :: object r11 :: str - r12 :: dict + r12 :: dict[exact] r13 :: object L0: r0 = builtins :: module @@ -1137,7 +1137,7 @@ L0: return r0 def call_python_function(x): x :: int - r0 :: dict + r0 :: dict[exact] r1 :: str r2, r3 :: object r4 :: object[1] @@ -1147,7 +1147,7 @@ def call_python_function(x): L0: r0 = __main__.globals :: static r1 = 'f' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = box(int, x) r4 = [r3] r5 = load_address r4 @@ -1159,13 +1159,13 @@ def return_float(): L0: return 5.0 def return_callable_type(): - r0 :: dict + r0 :: dict[exact] r1 :: str r2 :: object L0: r0 = __main__.globals :: static r1 = 'return_float' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) return r2 def call_callable_type(): r0, f, r1 :: object @@ -1423,7 +1423,7 @@ def f() -> None: print(x) [out] def f(): - r0 :: dict + r0 :: dict[exact] r1 :: str r2 :: object r3 :: int @@ -1436,7 +1436,7 @@ def f(): L0: r0 = __main__.globals :: static r1 = 'x' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = unbox(int, r2) r4 = builtins :: module r5 = 'print' @@ -1452,12 +1452,12 @@ def __top_level__(): r2 :: bit r3 :: str r4 :: object - r5 :: dict + r5 :: dict[exact] r6 :: str r7 :: object r8 :: i32 r9 :: bit - r10 :: dict + r10 :: dict[exact] r11 :: str r12 :: object r13 :: int @@ -1480,11 +1480,11 @@ L2: r5 = __main__.globals :: static r6 = 'x' r7 = object 1 - r8 = CPyDict_SetItem(r5, r6, r7) + r8 = PyDict_SetItem(r5, r6, r7) r9 = r8 >= 0 :: signed r10 = __main__.globals :: static r11 = 'x' - r12 = CPyDict_GetItem(r10, r11) + r12 = CPyDict_GetItemUnsafe(r10, r11) r13 = unbox(int, r12) r14 = builtins :: module r15 = 'print' @@ -1672,7 +1672,7 @@ L0: return r0 def g(): r0 :: tuple[int, int, int] - r1 :: dict + r1 :: dict[exact] r2 :: str r3, r4, r5 :: object r6 :: tuple[int, int, int] @@ -1687,7 +1687,7 @@ L0: return r6 def h(): r0 :: tuple[int, int] - r1 :: dict + r1 :: dict[exact] r2 :: str r3 :: object r4 :: list @@ -1701,7 +1701,7 @@ L0: r0 = (4, 6) r1 = __main__.globals :: static r2 = 'f' - r3 = CPyDict_GetItem(r1, r2) + r3 = CPyDict_GetItemUnsafe(r1, r2) r4 = PyList_New(1) r5 = object 1 r6 = list_items r4 @@ -1732,7 +1732,7 @@ L0: def g(): r0, r1, r2 :: str r3, r4, r5 :: object - r6, r7 :: dict + r6, r7 :: dict[exact] r8 :: str r9 :: object r10 :: tuple @@ -1749,7 +1749,7 @@ L0: r6 = CPyDict_Build(3, r0, r3, r1, r4, r2, r5) r7 = __main__.globals :: static r8 = 'f' - r9 = CPyDict_GetItem(r7, r8) + r9 = CPyDict_GetItemUnsafe(r7, r8) r10 = CPyTuple_LoadEmptyTupleConstant() r11 = PyDict_Copy(r6) r12 = PyObject_Call(r9, r10, r11) @@ -1758,10 +1758,10 @@ L0: def h(): r0, r1 :: str r2, r3 :: object - r4, r5 :: dict + r4, r5 :: dict[exact] r6 :: str r7 :: object - r8 :: dict + r8 :: dict[exact] r9 :: i32 r10 :: bit r11 :: object @@ -1776,9 +1776,9 @@ L0: r4 = CPyDict_Build(2, r0, r2, r1, r3) r5 = __main__.globals :: static r6 = 'f' - r7 = CPyDict_GetItem(r5, r6) + r7 = CPyDict_GetItemUnsafe(r5, r6) r8 = PyDict_New() - r9 = CPyDict_UpdateInDisplay(r8, r4) + r9 = PyDict_Update(r8, r4) r10 = r9 >= 0 :: signed r11 = object 1 r12 = PyTuple_Pack(1, r11) @@ -1931,7 +1931,7 @@ def f() -> Dict[int, int]: return {x: x*x for x in [1,2,3] if x != 2 if x != 3} [out] def f(): - r0 :: dict + r0 :: dict[exact] r1 :: list r2, r3, r4 :: object r5 :: ptr @@ -2149,7 +2149,7 @@ def __top_level__(): r3 :: str r4, r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8 :: object r9, r10 :: str r11 :: object @@ -2161,53 +2161,53 @@ def __top_level__(): r17 :: object r18 :: tuple[object, object] r19 :: object - r20 :: dict + r20 :: dict[exact] r21 :: str r22 :: object r23 :: object[2] r24 :: object_ptr r25 :: object - r26 :: dict + r26 :: dict[exact] r27 :: str r28 :: i32 r29 :: bit r30 :: str - r31 :: dict + r31 :: dict[exact] r32 :: str r33, r34 :: object r35 :: object[2] r36 :: object_ptr r37 :: object r38 :: tuple - r39 :: dict + r39 :: dict[exact] r40 :: str r41 :: i32 r42 :: bit - r43 :: dict + r43 :: dict[exact] r44 :: str r45, r46, r47 :: object - r48 :: dict + r48 :: dict[exact] r49 :: str r50 :: i32 r51 :: bit r52 :: str - r53 :: dict + r53 :: dict[exact] r54 :: str r55 :: object - r56 :: dict + r56 :: dict[exact] r57 :: str r58 :: object r59 :: object[2] r60 :: object_ptr r61 :: object - r62 :: dict + r62 :: dict[exact] r63 :: str r64 :: i32 r65 :: bit r66 :: list r67, r68, r69 :: object r70 :: ptr - r71 :: dict + r71 :: dict[exact] r72 :: str r73 :: i32 r74 :: bit @@ -2239,19 +2239,19 @@ L2: r19 = box(tuple[object, object], r18) r20 = __main__.globals :: static r21 = 'NamedTuple' - r22 = CPyDict_GetItem(r20, r21) + r22 = CPyDict_GetItemUnsafe(r20, r21) r23 = [r9, r19] r24 = load_address r23 r25 = PyObject_Vectorcall(r22, r24, 2, 0) keep_alive r9, r19 r26 = __main__.globals :: static r27 = 'Lol' - r28 = CPyDict_SetItem(r26, r27, r25) + r28 = PyDict_SetItem(r26, r27, r25) r29 = r28 >= 0 :: signed r30 = '' r31 = __main__.globals :: static r32 = 'Lol' - r33 = CPyDict_GetItem(r31, r32) + r33 = CPyDict_GetItemUnsafe(r31, r32) r34 = object 1 r35 = [r34, r30] r36 = load_address r35 @@ -2260,31 +2260,31 @@ L2: r38 = cast(tuple, r37) r39 = __main__.globals :: static r40 = 'x' - r41 = CPyDict_SetItem(r39, r40, r38) + r41 = PyDict_SetItem(r39, r40, r38) r42 = r41 >= 0 :: signed r43 = __main__.globals :: static r44 = 'List' - r45 = CPyDict_GetItem(r43, r44) + r45 = CPyDict_GetItemUnsafe(r43, r44) r46 = load_address PyLong_Type r47 = PyObject_GetItem(r45, r46) r48 = __main__.globals :: static r49 = 'Foo' - r50 = CPyDict_SetItem(r48, r49, r47) + r50 = PyDict_SetItem(r48, r49, r47) r51 = r50 >= 0 :: signed r52 = 'Bar' r53 = __main__.globals :: static r54 = 'Foo' - r55 = CPyDict_GetItem(r53, r54) + r55 = CPyDict_GetItemUnsafe(r53, r54) r56 = __main__.globals :: static r57 = 'NewType' - r58 = CPyDict_GetItem(r56, r57) + r58 = CPyDict_GetItemUnsafe(r56, r57) r59 = [r52, r55] r60 = load_address r59 r61 = PyObject_Vectorcall(r58, r60, 2, 0) keep_alive r52, r55 r62 = __main__.globals :: static r63 = 'Bar' - r64 = CPyDict_SetItem(r62, r63, r61) + r64 = PyDict_SetItem(r62, r63, r61) r65 = r64 >= 0 :: signed r66 = PyList_New(3) r67 = object 1 @@ -2297,7 +2297,7 @@ L2: keep_alive r66 r71 = __main__.globals :: static r72 = 'y' - r73 = CPyDict_SetItem(r71, r72, r66) + r73 = PyDict_SetItem(r71, r72, r66) r74 = r73 >= 0 :: signed return 1 @@ -2581,19 +2581,19 @@ def c(): r0 :: __main__.c_env r1 :: __main__.d_c_obj r2 :: bool - r3 :: dict + r3 :: dict[exact] r4 :: str r5 :: object r6 :: object[1] r7 :: object_ptr r8 :: object - r9 :: dict + r9 :: dict[exact] r10 :: str r11 :: object r12 :: object[1] r13 :: object_ptr r14, d :: object - r15 :: dict + r15 :: dict[exact] r16 :: str r17 :: i32 r18 :: bit @@ -2610,14 +2610,14 @@ L0: r1.__mypyc_env__ = r0; r2 = is_error r3 = __main__.globals :: static r4 = 'b' - r5 = CPyDict_GetItem(r3, r4) + r5 = CPyDict_GetItemUnsafe(r3, r4) r6 = [r1] r7 = load_address r6 r8 = PyObject_Vectorcall(r5, r7, 1, 0) keep_alive r1 r9 = __main__.globals :: static r10 = 'a' - r11 = CPyDict_GetItem(r9, r10) + r11 = CPyDict_GetItemUnsafe(r9, r10) r12 = [r8] r13 = load_address r12 r14 = PyObject_Vectorcall(r11, r13, 1, 0) @@ -2643,24 +2643,24 @@ def __top_level__(): r3 :: str r4, r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8 :: object - r9 :: dict + r9 :: dict[exact] r10 :: str r11 :: object - r12 :: dict + r12 :: dict[exact] r13 :: str r14 :: object r15 :: object[1] r16 :: object_ptr r17 :: object - r18 :: dict + r18 :: dict[exact] r19 :: str r20 :: object r21 :: object[1] r22 :: object_ptr r23 :: object - r24 :: dict + r24 :: dict[exact] r25 :: str r26 :: i32 r27 :: bit @@ -2681,17 +2681,17 @@ L2: typing = r8 :: module r9 = __main__.globals :: static r10 = 'c' - r11 = CPyDict_GetItem(r9, r10) + r11 = CPyDict_GetItemUnsafe(r9, r10) r12 = __main__.globals :: static r13 = 'b' - r14 = CPyDict_GetItem(r12, r13) + r14 = CPyDict_GetItemUnsafe(r12, r13) r15 = [r11] r16 = load_address r15 r17 = PyObject_Vectorcall(r14, r16, 1, 0) keep_alive r11 r18 = __main__.globals :: static r19 = 'a' - r20 = CPyDict_GetItem(r18, r19) + r20 = CPyDict_GetItemUnsafe(r18, r19) r21 = [r17] r22 = load_address r21 r23 = PyObject_Vectorcall(r20, r22, 1, 0) @@ -2784,7 +2784,7 @@ def __top_level__(): r3 :: str r4, r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8 :: object L0: r0 = builtins :: module @@ -3284,24 +3284,24 @@ x = 1 [file p/m.py] [out] def root(): - r0 :: dict + r0 :: dict[exact] r1, r2 :: object r3 :: bit r4 :: str r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8 :: str r9 :: object r10 :: i32 r11 :: bit - r12 :: dict + r12 :: dict[exact] r13, r14 :: object r15 :: bit r16 :: str r17 :: object r18 :: str - r19 :: dict + r19 :: dict[exact] r20 :: str r21 :: object r22 :: i32 @@ -3320,8 +3320,8 @@ L2: r6 = 'dataclasses' r7 = PyImport_GetModuleDict() r8 = 'dataclasses' - r9 = CPyDict_GetItem(r7, r8) - r10 = CPyDict_SetItem(r0, r6, r9) + r9 = CPyDict_GetItemUnsafe(r7, r8) + r10 = PyDict_SetItem(r0, r6, r9) r11 = r10 >= 0 :: signed r12 = __main__.globals :: static r13 = enum :: module @@ -3336,23 +3336,23 @@ L4: r18 = 'enum' r19 = PyImport_GetModuleDict() r20 = 'enum' - r21 = CPyDict_GetItem(r19, r20) - r22 = CPyDict_SetItem(r12, r18, r21) + r21 = CPyDict_GetItemUnsafe(r19, r20) + r22 = PyDict_SetItem(r12, r18, r21) r23 = r22 >= 0 :: signed return 1 def submodule(): - r0 :: dict + r0 :: dict[exact] r1, r2 :: object r3 :: bit r4 :: str r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8 :: str r9 :: object r10 :: i32 r11 :: bit - r12 :: dict + r12 :: dict[exact] r13 :: str r14 :: object r15 :: str @@ -3372,12 +3372,12 @@ L2: r6 = 'p' r7 = PyImport_GetModuleDict() r8 = 'p' - r9 = CPyDict_GetItem(r7, r8) - r10 = CPyDict_SetItem(r0, r6, r9) + r9 = CPyDict_GetItemUnsafe(r7, r8) + r10 = PyDict_SetItem(r0, r6, r9) r11 = r10 >= 0 :: signed r12 = PyImport_GetModuleDict() r13 = 'p' - r14 = CPyDict_GetItem(r12, r13) + r14 = CPyDict_GetItemUnsafe(r12, r13) r15 = 'x' r16 = CPyObject_GetAttr(r14, r15) r17 = unbox(int, r16) diff --git a/mypyc/test-data/irbuild-bool.test b/mypyc/test-data/irbuild-bool.test index 5eac6d8db24f..49edfdc0647e 100644 --- a/mypyc/test-data/irbuild-bool.test +++ b/mypyc/test-data/irbuild-bool.test @@ -59,13 +59,17 @@ L0: return r1 [case testConversionToBool] -from typing import List, Optional +from typing import List, Optional, TypedDict class C: pass class D: def __bool__(self) -> bool: return True +class E(TypedDict): + a: str + b: int + def list_to_bool(l: List[str]) -> bool: return bool(l) @@ -80,6 +84,10 @@ def optional_truthy_to_bool(o: Optional[C]) -> bool: def optional_maybe_falsey_to_bool(o: Optional[D]) -> bool: return bool(o) + +def typeddict_to_bool(o: E) -> bool: + return bool(o) +[typing fixtures/typing-full.pyi] [out] def D.__bool__(self): self :: __main__.D @@ -139,6 +147,12 @@ L2: r4 = 0 L3: return r4 +def typeddict_to_bool(o): + o :: dict[exact] + r0 :: bit +L0: + r0 = CPyDict_IsTrue(o) + return r0 [case testBoolComparisons] def eq(x: bool, y: bool) -> bool: diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 78ca7b68cefb..fcbfec095f49 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -201,19 +201,19 @@ def __top_level__(): r3 :: str r4, r5 :: object r6 :: str - r7 :: dict + r7 :: dict[exact] r8, r9 :: object r10 :: str - r11 :: dict + r11 :: dict[exact] r12 :: object r13 :: str - r14 :: dict + r14 :: dict[exact] r15 :: str r16 :: object r17 :: object[1] r18 :: object_ptr r19 :: object - r20 :: dict + r20 :: dict[exact] r21 :: str r22 :: i32 r23 :: bit @@ -225,7 +225,7 @@ def __top_level__(): r30 :: tuple r31 :: i32 r32 :: bit - r33 :: dict + r33 :: dict[exact] r34 :: str r35 :: i32 r36 :: bit @@ -236,15 +236,15 @@ def __top_level__(): r42 :: tuple r43 :: i32 r44 :: bit - r45 :: dict + r45 :: dict[exact] r46 :: str r47 :: i32 r48 :: bit r49, r50 :: object - r51 :: dict + r51 :: dict[exact] r52 :: str r53 :: object - r54 :: dict + r54 :: dict[exact] r55 :: str r56, r57 :: object r58 :: tuple @@ -255,7 +255,7 @@ def __top_level__(): r65 :: tuple r66 :: i32 r67 :: bit - r68 :: dict + r68 :: dict[exact] r69 :: str r70 :: i32 r71 :: bit @@ -282,14 +282,14 @@ L2: r13 = 'T' r14 = __main__.globals :: static r15 = 'TypeVar' - r16 = CPyDict_GetItem(r14, r15) + r16 = CPyDict_GetItemUnsafe(r14, r15) r17 = [r13] r18 = load_address r17 r19 = PyObject_Vectorcall(r16, r18, 1, 0) keep_alive r13 r20 = __main__.globals :: static r21 = 'T' - r22 = CPyDict_SetItem(r20, r21, r19) + r22 = PyDict_SetItem(r20, r21, r19) r23 = r22 >= 0 :: signed r24 = :: object r25 = '__main__' @@ -322,10 +322,10 @@ L2: r50 = __main__.S :: type r51 = __main__.globals :: static r52 = 'Generic' - r53 = CPyDict_GetItem(r51, r52) + r53 = CPyDict_GetItemUnsafe(r51, r52) r54 = __main__.globals :: static r55 = 'T' - r56 = CPyDict_GetItem(r54, r55) + r56 = CPyDict_GetItemUnsafe(r54, r55) r57 = PyObject_GetItem(r53, r56) r58 = PyTuple_Pack(3, r49, r50, r57) r59 = '__main__' @@ -1064,7 +1064,7 @@ L0: return 1 def B.__mypyc_defaults_setup(__mypyc_self__): __mypyc_self__ :: __main__.B - r0 :: dict + r0 :: dict[exact] r1 :: str r2 :: object r3 :: str @@ -1073,7 +1073,7 @@ L0: __mypyc_self__.x = 20 r0 = __main__.globals :: static r1 = 'LOL' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = cast(str, r2) __mypyc_self__.y = r3 r4 = box(None, 1) diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index e0c014f07813..5519c15a1b66 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -36,7 +36,8 @@ def f() -> None: d = {} # type: Dict[bool, int] [out] def f(): - r0, d :: dict + r0 :: dict[exact] + d :: dict L0: r0 = PyDict_New() d = r0 @@ -49,7 +50,8 @@ def f() -> None: [out] def f(): - r0, d :: dict + r0 :: dict[exact] + d :: dict L0: r0 = PyDict_New() d = r0 @@ -63,7 +65,8 @@ def f(x): x :: object r0 :: str r1, r2 :: object - r3, d :: dict + r3 :: dict[exact] + d :: dict L0: r0 = '' r1 = object 1 @@ -198,7 +201,7 @@ def f(x, y): y :: dict r0 :: str r1 :: object - r2 :: dict + r2 :: dict[exact] r3 :: i32 r4 :: bit r5 :: object @@ -211,7 +214,7 @@ L0: r3 = CPyDict_UpdateInDisplay(r2, y) r4 = r3 >= 0 :: signed r5 = object 3 - r6 = CPyDict_SetItem(r2, r0, r5) + r6 = PyDict_SetItem(r2, r0, r5) r7 = r6 >= 0 :: signed return r2 @@ -323,7 +326,9 @@ L11: L12: return 1 def union_of_dicts(d): - d, r0, new :: dict + d :: dict + r0 :: dict[exact] + new :: dict r1 :: short_int r2 :: native_int r3 :: object @@ -379,10 +384,10 @@ L4: L5: return 1 def typeddict(d): - d :: dict + d :: dict[exact] r0 :: short_int r1 :: native_int - r2 :: object + r2 :: dict[exact] r3 :: tuple[bool, short_int, object, object] r4 :: short_int r5 :: bool @@ -396,9 +401,9 @@ def typeddict(d): L0: r0 = 0 r1 = PyDict_Size(d) - r2 = CPyDict_GetItemsIter(d) + r2 = _CPyDict_GetIterUnsafe(d) L1: - r3 = CPyDict_NextItem(r2, r0) + r3 = CPyDict_NextItemUnsafe(r2, r0) r4 = r3[1] r0 = r4 r5 = r3[0] @@ -416,7 +421,7 @@ L3: name = v L4: L5: - r11 = CPyDict_CheckSize(d, r1) + r11 = CPyDict_CheckSizeUnsafe(d, r1) goto L1 L6: r12 = CPy_NoErrOccurred() @@ -547,7 +552,7 @@ def f4(d, flag): r1 :: object r2, r3 :: str r4 :: object - r5 :: dict + r5 :: dict[exact] r6, r7 :: object L0: if flag goto L1 else goto L2 :: bool diff --git a/mypyc/test-data/irbuild-generics.test b/mypyc/test-data/irbuild-generics.test index 96437a0079c9..5e75105951e7 100644 --- a/mypyc/test-data/irbuild-generics.test +++ b/mypyc/test-data/irbuild-generics.test @@ -246,7 +246,7 @@ def fn_mapping(m): r35, x_3 :: str r36 :: i32 r37, r38, r39 :: bit - r40 :: dict + r40 :: dict[exact] r41 :: short_int r42 :: native_int r43 :: object @@ -391,7 +391,7 @@ def fn_union(m): r34, x_3 :: str r35 :: i32 r36, r37, r38 :: bit - r39 :: dict + r39 :: dict[exact] r40 :: short_int r41 :: native_int r42 :: object @@ -499,11 +499,11 @@ L19: L20: return 1 def fn_typeddict(t): - t :: dict + t :: dict[exact] r0 :: list r1 :: short_int r2 :: native_int - r3 :: object + r3 :: dict[exact] r4 :: tuple[bool, short_int, object] r5 :: short_int r6 :: bool @@ -514,7 +514,7 @@ def fn_typeddict(t): r13 :: list r14 :: short_int r15 :: native_int - r16 :: object + r16 :: dict[exact] r17 :: tuple[bool, short_int, object] r18 :: short_int r19 :: bool @@ -524,7 +524,7 @@ def fn_typeddict(t): r25 :: set r26 :: short_int r27 :: native_int - r28 :: object + r28 :: dict[exact] r29 :: tuple[bool, short_int, object] r30 :: short_int r31 :: bool @@ -532,10 +532,10 @@ def fn_typeddict(t): r33, x_3 :: str r34 :: i32 r35, r36, r37 :: bit - r38 :: dict + r38 :: dict[exact] r39 :: short_int r40 :: native_int - r41 :: object + r41 :: dict[exact] r42 :: tuple[bool, short_int, object, object] r43 :: short_int r44 :: bool @@ -548,9 +548,9 @@ L0: r0 = PyList_New(0) r1 = 0 r2 = PyDict_Size(t) - r3 = CPyDict_GetKeysIter(t) + r3 = _CPyDict_GetIterUnsafe(t) L1: - r4 = CPyDict_NextKey(r3, r1) + r4 = CPyDict_NextKeyUnsafe(r3, r1) r5 = r4[1] r1 = r5 r6 = r4[0] @@ -562,7 +562,7 @@ L2: r9 = PyList_Append(r0, x) r10 = r9 >= 0 :: signed L3: - r11 = CPyDict_CheckSize(t, r2) + r11 = CPyDict_CheckSizeUnsafe(t, r2) goto L1 L4: r12 = CPy_NoErrOccurred() @@ -570,9 +570,9 @@ L5: r13 = PyList_New(0) r14 = 0 r15 = PyDict_Size(t) - r16 = CPyDict_GetValuesIter(t) + r16 = _CPyDict_GetIterUnsafe(t) L6: - r17 = CPyDict_NextValue(r16, r14) + r17 = CPyDict_NextValueUnsafe(r16, r14) r18 = r17[1] r14 = r18 r19 = r17[0] @@ -583,7 +583,7 @@ L7: r21 = PyList_Append(r13, x_2) r22 = r21 >= 0 :: signed L8: - r23 = CPyDict_CheckSize(t, r15) + r23 = CPyDict_CheckSizeUnsafe(t, r15) goto L6 L9: r24 = CPy_NoErrOccurred() @@ -591,9 +591,9 @@ L10: r25 = PySet_New(0) r26 = 0 r27 = PyDict_Size(t) - r28 = CPyDict_GetKeysIter(t) + r28 = _CPyDict_GetIterUnsafe(t) L11: - r29 = CPyDict_NextKey(r28, r26) + r29 = CPyDict_NextKeyUnsafe(r28, r26) r30 = r29[1] r26 = r30 r31 = r29[0] @@ -605,7 +605,7 @@ L12: r34 = PySet_Add(r25, x_3) r35 = r34 >= 0 :: signed L13: - r36 = CPyDict_CheckSize(t, r27) + r36 = CPyDict_CheckSizeUnsafe(t, r27) goto L11 L14: r37 = CPy_NoErrOccurred() @@ -613,9 +613,9 @@ L15: r38 = PyDict_New() r39 = 0 r40 = PyDict_Size(t) - r41 = CPyDict_GetItemsIter(t) + r41 = _CPyDict_GetIterUnsafe(t) L16: - r42 = CPyDict_NextItem(r41, r39) + r42 = CPyDict_NextItemUnsafe(r41, r39) r43 = r42[1] r39 = r43 r44 = r42[0] @@ -629,7 +629,7 @@ L17: r48 = PyDict_SetItem(r38, k, v) r49 = r48 >= 0 :: signed L18: - r50 = CPyDict_CheckSize(t, r40) + r50 = CPyDict_CheckSizeUnsafe(t, r40) goto L16 L19: r51 = CPy_NoErrOccurred() @@ -683,7 +683,7 @@ def inner_deco_obj.__call__(__mypyc_self__, args, kwargs): r6, x :: object r7 :: native_int can_listcomp :: list - r8 :: dict + r8 :: dict[exact] r9 :: short_int r10 :: native_int r11 :: object diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 28aff3dcfc45..e46a1466bb81 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1089,7 +1089,8 @@ def f(x): x :: object r0 :: i32 r1 :: bit - r2, rest :: dict + r2 :: dict[exact] + rest :: dict r3 :: str r4 :: object r5 :: str @@ -1118,6 +1119,7 @@ L3: L4: r10 = box(None, 1) return r10 + [case testMatchMappingPatternWithRestPopKeys_python3_10] def f(x): match x: @@ -1137,7 +1139,8 @@ def f(x): r8 :: i32 r9 :: bit r10 :: bool - r11, rest :: dict + r11 :: dict[exact] + rest :: dict r12 :: i32 r13 :: bit r14 :: str @@ -1183,6 +1186,7 @@ L5: L6: r21 = box(None, 1) return r21 + [case testMatchEmptySequencePattern_python3_10] def f(x): match x: diff --git a/mypyc/test-data/irbuild-set.test b/mypyc/test-data/irbuild-set.test index 5586a2bf4cfb..63b5fec940c8 100644 --- a/mypyc/test-data/irbuild-set.test +++ b/mypyc/test-data/irbuild-set.test @@ -161,7 +161,8 @@ L5: def test3(): r0, r1, r2 :: str r3, r4, r5 :: object - r6, tmp_dict :: dict + r6 :: dict[exact] + tmp_dict :: dict r7 :: set r8 :: short_int r9 :: native_int @@ -646,7 +647,7 @@ L0: return r3 def not_precomputed_non_final_name(i): i :: int - r0 :: dict + r0 :: dict[exact] r1 :: str r2 :: object r3 :: int @@ -661,7 +662,7 @@ def not_precomputed_non_final_name(i): L0: r0 = __main__.globals :: static r1 = 'non_const' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = unbox(int, r2) r4 = PySet_New(0) r5 = box(int, r3) @@ -766,7 +767,7 @@ L4: L5: return 1 def not_precomputed(): - r0 :: dict + r0 :: dict[exact] r1 :: str r2 :: object r3 :: int @@ -780,7 +781,7 @@ def not_precomputed(): L0: r0 = __main__.globals :: static r1 = 'non_const' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = unbox(int, r2) r4 = PySet_New(0) r5 = box(int, r3) diff --git a/mypyc/test-data/irbuild-singledispatch.test b/mypyc/test-data/irbuild-singledispatch.test index 1060ee63c57d..8b94f4171942 100644 --- a/mypyc/test-data/irbuild-singledispatch.test +++ b/mypyc/test-data/irbuild-singledispatch.test @@ -14,7 +14,7 @@ L0: return 0 def f_obj.__init__(__mypyc_self__): __mypyc_self__ :: __main__.f_obj - r0, r1 :: dict + r0, r1 :: dict[exact] r2 :: str r3 :: i32 r4 :: bit @@ -31,13 +31,13 @@ def f_obj.__call__(__mypyc_self__, arg): arg :: object r0 :: ptr r1 :: object - r2 :: dict + r2 :: dict[exact] r3, r4 :: object r5 :: bit r6, r7 :: object r8 :: str r9 :: object - r10 :: dict + r10 :: dict[exact] r11 :: object[2] r12 :: object_ptr r13 :: object @@ -124,14 +124,14 @@ L0: return r0 def f(arg): arg :: object - r0 :: dict + r0 :: dict[exact] r1 :: str r2 :: object r3 :: bool L0: r0 = __main__.globals :: static r1 = 'f' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = f_obj.__call__(r2, arg) return r3 def g(arg): @@ -155,7 +155,7 @@ L0: return 1 def f_obj.__init__(__mypyc_self__): __mypyc_self__ :: __main__.f_obj - r0, r1 :: dict + r0, r1 :: dict[exact] r2 :: str r3 :: i32 r4 :: bit @@ -172,13 +172,13 @@ def f_obj.__call__(__mypyc_self__, x): x :: object r0 :: ptr r1 :: object - r2 :: dict + r2 :: dict[exact] r3, r4 :: object r5 :: bit r6, r7 :: object r8 :: str r9 :: object - r10 :: dict + r10 :: dict[exact] r11 :: object[2] r12 :: object_ptr r13 :: object @@ -255,14 +255,14 @@ L0: return r0 def f(x): x :: object - r0 :: dict + r0 :: dict[exact] r1 :: str r2 :: object r3 :: None L0: r0 = __main__.globals :: static r1 = 'f' - r2 = CPyDict_GetItem(r0, r1) + r2 = CPyDict_GetItemUnsafe(r0, r1) r3 = f_obj.__call__(r2, x) return r3 def test(): diff --git a/mypyc/test-data/irbuild-statements.test b/mypyc/test-data/irbuild-statements.test index 48b8e0e318b8..46643136a131 100644 --- a/mypyc/test-data/irbuild-statements.test +++ b/mypyc/test-data/irbuild-statements.test @@ -759,7 +759,8 @@ def delDictMultiple() -> None: def delDict(): r0, r1 :: str r2, r3 :: object - r4, d :: dict + r4 :: dict[exact] + d :: dict r5 :: str r6 :: i32 r7 :: bit @@ -777,7 +778,8 @@ L0: def delDictMultiple(): r0, r1, r2, r3 :: str r4, r5, r6, r7 :: object - r8, d :: dict + r8 :: dict[exact] + d :: dict r9, r10 :: str r11 :: i32 r12 :: bit diff --git a/mypyc/test-data/run-dicts.test b/mypyc/test-data/run-dicts.test index 2b75b32c906e..c39b9cd03a47 100644 --- a/mypyc/test-data/run-dicts.test +++ b/mypyc/test-data/run-dicts.test @@ -84,13 +84,13 @@ update_dict(d, object.__dict__) assert d == dict(object.__dict__) assert u(10) == 10 -assert get_content({1: 2}) == ([1], [2], [(1, 2)]) +assert get_content({1: 2}) == ([1], [2], [(1, 2)]), get_content({1: 2}) od = OrderedDict([(1, 2), (3, 4)]) -assert get_content(od) == ([1, 3], [2, 4], [(1, 2), (3, 4)]) +assert get_content(od) == ([1, 3], [2, 4], [(1, 2), (3, 4)]), get_content(od) od.move_to_end(1) -assert get_content(od) == ([3, 1], [4, 2], [(3, 4), (1, 2)]) -assert get_content_set({1: 2}) == ({1}, {2}, {(1, 2)}) -assert get_content_set(od) == ({1, 3}, {2, 4}, {(1, 2), (3, 4)}) +assert get_content(od) == ([3, 1], [4, 2], [(3, 4), (1, 2)]), get_content(od) +assert get_content_set({1: 2}) == ({1}, {2}, {(1, 2)}), get_content_set({1: 2}) +assert get_content_set(od) == ({1, 3}, {2, 4}, {(1, 2), (3, 4)}), get_content_set(od) [typing fixtures/typing-full.pyi]