Skip to content

[mypyc] Refactor: use new-style primitives for function ops #18211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
IntOp,
LoadStatic,
Op,
PrimitiveDescription,
RaiseStandardError,
Register,
SetAttr,
Expand Down Expand Up @@ -381,6 +382,9 @@ def load_module(self, name: str) -> Value:
def call_c(self, desc: CFunctionDescription, args: list[Value], line: int) -> Value:
return self.builder.call_c(desc, args, line)

def primitive_op(self, desc: PrimitiveDescription, args: list[Value], line: int) -> Value:
return self.builder.primitive_op(desc, args, line)

def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
return self.builder.int_op(type, lhs, rhs, op, line)

Expand Down Expand Up @@ -691,7 +695,7 @@ def assign(self, target: Register | AssignmentTarget, rvalue_reg: Value, line: i
else:
key = self.load_str(target.attr)
boxed_reg = self.builder.box(rvalue_reg)
self.call_c(py_setattr_op, [target.obj, key, boxed_reg], line)
self.primitive_op(py_setattr_op, [target.obj, key, boxed_reg], line)
elif isinstance(target, AssignmentTargetIndex):
target_reg2 = self.gen_method_call(
target.base, "__setitem__", [target.index, rvalue_reg], None, line
Expand Down Expand Up @@ -768,7 +772,7 @@ def process_iterator_tuple_assignment_helper(
def process_iterator_tuple_assignment(
self, target: AssignmentTargetTuple, rvalue_reg: Value, line: int
) -> None:
iterator = self.call_c(iter_op, [rvalue_reg], line)
iterator = self.primitive_op(iter_op, [rvalue_reg], line)

# This may be the whole lvalue list if there is no starred value
split_idx = target.star_idx if target.star_idx is not None else len(target.items)
Expand All @@ -794,7 +798,7 @@ def process_iterator_tuple_assignment(
# Assign the starred value and all values after it
if target.star_idx is not None:
post_star_vals = target.items[split_idx + 1 :]
iter_list = self.call_c(to_list, [iterator], line)
iter_list = self.primitive_op(to_list, [iterator], line)
iter_list_len = self.builtin_len(iter_list, line)
post_star_len = Integer(len(post_star_vals))
condition = self.binary_op(post_star_len, iter_list_len, "<=", line)
Expand Down Expand Up @@ -1051,9 +1055,9 @@ def call_refexpr_with_args(
# Handle data-driven special-cased primitive call ops.
if callee.fullname and expr.arg_kinds == [ARG_POS] * len(arg_values):
fullname = get_call_target_fullname(callee)
call_c_ops_candidates = function_ops.get(fullname, [])
target = self.builder.matching_call_c(
call_c_ops_candidates, arg_values, expr.line, self.node_type(expr)
primitive_candidates = function_ops.get(fullname, [])
target = self.builder.matching_primitive_op(
primitive_candidates, arg_values, expr.line, self.node_type(expr)
)
if target:
return target
Expand Down
8 changes: 4 additions & 4 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
return
typ = self.builder.load_native_type_object(self.cdef.fullname)
value = self.builder.accept(stmt.rvalue)
self.builder.call_c(
self.builder.primitive_op(
py_setattr_op, [typ, self.builder.load_str(lvalue.name), value], stmt.line
)
if self.builder.non_function_scope() and stmt.is_final_def:
Expand Down Expand Up @@ -452,7 +452,7 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value:
)
)
# Populate a '__mypyc_attrs__' field containing the list of attrs
builder.call_c(
builder.primitive_op(
py_setattr_op,
[
tp,
Expand Down Expand Up @@ -483,7 +483,7 @@ def make_generic_base_class(
for tv, type_param in zip(tvs, type_args):
if type_param.kind == TYPE_VAR_TUPLE_KIND:
# Evaluate *Ts for a TypeVarTuple
it = builder.call_c(iter_op, [tv], line)
it = builder.primitive_op(iter_op, [tv], line)
tv = builder.call_c(next_op, [it], line)
args.append(tv)

Expand Down Expand Up @@ -603,7 +603,7 @@ def setup_non_ext_dict(
This class dictionary is passed to the metaclass constructor.
"""
# Check if the metaclass defines a __prepare__ method, and if so, call it.
has_prepare = builder.call_c(
has_prepare = builder.primitive_op(
py_hasattr_op, [metaclass, builder.load_str("__prepare__")], cdef.line
)

Expand Down
8 changes: 4 additions & 4 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: Supe
# Grab first argument
vself: Value = builder.self()
if decl.kind == FUNC_CLASSMETHOD:
vself = builder.call_c(type_op, [vself], expr.line)
vself = builder.primitive_op(type_op, [vself], expr.line)
elif builder.fn_info.is_generator:
# For generator classes, the self target is the 6th value
# in the symbol table (which is an ordered dict). This is sort
Expand Down Expand Up @@ -953,7 +953,7 @@ def transform_tuple_expr(builder: IRBuilder, expr: TupleExpr) -> Value:
def _visit_tuple_display(builder: IRBuilder, expr: TupleExpr) -> Value:
"""Create a list, then turn it into a tuple."""
val_as_list = _visit_list_display(builder, expr.items, expr.line)
return builder.call_c(list_tuple_op, [val_as_list], expr.line)
return builder.primitive_op(list_tuple_op, [val_as_list], expr.line)


def transform_dict_expr(builder: IRBuilder, expr: DictExpr) -> Value:
Expand Down Expand Up @@ -1045,12 +1045,12 @@ def get_arg(arg: Expression | None) -> Value:
return builder.accept(arg)

args = [get_arg(expr.begin_index), get_arg(expr.end_index), get_arg(expr.stride)]
return builder.call_c(new_slice_op, args, expr.line)
return builder.primitive_op(new_slice_op, args, expr.line)


def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value:
builder.warning("Treating generator comprehension as list", o.line)
return builder.call_c(iter_op, [translate_list_comprehension(builder, o)], o.line)
return builder.primitive_op(iter_op, [translate_list_comprehension(builder, o)], o.line)


def transform_assignment_expr(builder: IRBuilder, o: AssignmentExpr) -> Value:
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def init(self, expr_reg: Value, target_type: RType) -> None:
# for the for-loop. If we are inside of a generator function, spill these into the
# environment class.
builder = self.builder
iter_reg = builder.call_c(iter_op, [expr_reg], self.line)
iter_reg = builder.primitive_op(iter_op, [expr_reg], self.line)
builder.maybe_spill(expr_reg)
self.iter_target = builder.maybe_spill(iter_reg)
self.target_type = target_type
Expand Down
6 changes: 3 additions & 3 deletions mypyc/irbuild/format_str_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ def convert_format_expr_to_str(
if is_str_rprimitive(node_type):
var_str = builder.accept(x)
elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
var_str = builder.call_c(int_to_str_op, [builder.accept(x)], line)
var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line)
else:
var_str = builder.call_c(str_op, [builder.accept(x)], line)
var_str = builder.primitive_op(str_op, [builder.accept(x)], line)
elif format_op == FormatOp.INT:
if is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
var_str = builder.call_c(int_to_str_op, [builder.accept(x)], line)
var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line)
else:
return None
else:
Expand Down
8 changes: 5 additions & 3 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,9 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None

# Set the callable object representing the decorated method as an attribute of the
# extension class.
builder.call_c(py_setattr_op, [typ, builder.load_str(name), decorated_func], fdef.line)
builder.primitive_op(
py_setattr_op, [typ, builder.load_str(name), decorated_func], fdef.line
)

if fdef.is_property:
# If there is a property setter, it will be processed after the getter,
Expand Down Expand Up @@ -973,7 +975,7 @@ def generate_singledispatch_callable_class_ctor(builder: IRBuilder) -> None:
cache_dict = builder.call_c(dict_new_op, [], line)
dispatch_cache_str = builder.load_str("dispatch_cache")
# use the py_setattr_op instead of SetAttr so that it also gets added to our __dict__
builder.call_c(py_setattr_op, [builder.self(), dispatch_cache_str, cache_dict], line)
builder.primitive_op(py_setattr_op, [builder.self(), dispatch_cache_str, cache_dict], line)
# the generated C code seems to expect that __init__ returns a char, so just return 1
builder.add(Return(Integer(1, bool_rprimitive, line), line))

Expand Down Expand Up @@ -1016,7 +1018,7 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None:
registry_dict = builder.builder.make_dict([(loaded_object_type, main_func_obj)], line)

dispatch_func_obj = builder.load_global_str(fitem.name, line)
builder.call_c(
builder.primitive_op(
py_setattr_op, [dispatch_func_obj, builder.load_str("registry"), registry_dict], line
)

Expand Down
20 changes: 11 additions & 9 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def py_get_attr(self, obj: Value, attr: str, line: int) -> Value:
Prefer get_attr() which generates optimized code for native classes.
"""
key = self.load_str(attr)
return self.call_c(py_getattr_op, [obj, key], line)
return self.primitive_op(py_getattr_op, [obj, key], line)

# isinstance() checks

Expand Down Expand Up @@ -656,7 +656,9 @@ def isinstance_native(self, obj: Value, class_ir: ClassIR, line: int) -> Value:
"""
concrete = all_concrete_classes(class_ir)
if concrete is None or len(concrete) > FAST_ISINSTANCE_MAX_SUBCLASSES + 1:
return self.call_c(fast_isinstance_op, [obj, self.get_native_type(class_ir)], line)
return self.primitive_op(
fast_isinstance_op, [obj, self.get_native_type(class_ir)], line
)
if not concrete:
# There can't be any concrete instance that matches this.
return self.false()
Expand Down Expand Up @@ -857,7 +859,7 @@ def _construct_varargs(
if star_result is None:
star_result = self.new_tuple(star_values, line)
else:
star_result = self.call_c(list_tuple_op, [star_result], line)
star_result = self.primitive_op(list_tuple_op, [star_result], line)
if has_star2 and star2_result is None:
star2_result = self._create_dict(star2_keys, star2_values, line)

Expand Down Expand Up @@ -1515,7 +1517,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
# Cast to bool if necessary since most types uses comparison returning a object type
# See generic_ops.py for more information
if not is_bool_rprimitive(compare.type):
compare = self.call_c(bool_op, [compare], line)
compare = self.primitive_op(bool_op, [compare], line)
if i < len(lhs.type.types) - 1:
branch = Branch(compare, early_stop, check_blocks[i + 1], Branch.BOOL)
else:
Expand All @@ -1534,7 +1536,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
def translate_instance_contains(self, inst: Value, item: Value, op: str, line: int) -> Value:
res = self.gen_method_call(inst, "__contains__", [item], None, line)
if not is_bool_rprimitive(res.type):
res = self.call_c(bool_op, [res], line)
res = self.primitive_op(bool_op, [res], line)
if op == "not in":
res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), "^", line)
return res
Expand Down Expand Up @@ -1667,7 +1669,7 @@ def new_list_op(self, values: list[Value], line: int) -> Value:
return result_list

def new_set_op(self, values: list[Value], line: int) -> Value:
return self.call_c(new_set_op, values, line)
return self.primitive_op(new_set_op, values, line)

def setup_rarray(
self, item_type: RType, values: Sequence[Value], *, object_ptr: bool = False
Expand Down Expand Up @@ -1775,7 +1777,7 @@ def bool_value(self, value: Value) -> Value:
self.goto(end)
self.activate_block(end)
else:
result = self.call_c(bool_op, [value], value.line)
result = self.primitive_op(bool_op, [value], value.line)
return result

def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> None:
Expand Down Expand Up @@ -2065,7 +2067,7 @@ def float_mod(self, lhs: Value, rhs: Value, line: int) -> Value:
self.activate_block(copysign)
# If the remainder is zero, CPython ensures the result has the
# same sign as the denominator.
adj = self.call_c(copysign_op, [Float(0.0), rhs], line)
adj = self.primitive_op(copysign_op, [Float(0.0), rhs], line)
self.add(Assign(res, adj))
self.add(Goto(done))
self.activate_block(done)
Expand Down Expand Up @@ -2260,7 +2262,7 @@ def new_tuple_with_length(self, length: Value, line: int) -> Value:
return self.call_c(new_tuple_with_length_op, [length], line)

def int_to_float(self, n: Value, line: int) -> Value:
return self.call_c(int_to_float_op, [n], line)
return self.primitive_op(int_to_float_op, [n], line)

# Internal helpers

Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None:
else slow_isinstance_op
)

cond = self.builder.call_c(
cond = self.builder.primitive_op(
isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
)

Expand Down Expand Up @@ -246,7 +246,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line)
rest = self.builder.primitive_op(dict_copy, [self.subject], pattern.rest.line)

target = self.builder.get_assignment_target(pattern.rest)

Expand Down
10 changes: 7 additions & 3 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
LoadLiteral,
LoadStatic,
MethodCall,
PrimitiveDescription,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -757,7 +758,7 @@ def transform_with(
value = builder.add(MethodCall(mgr_v, f"__{al}enter__", args=[], line=line))
exit_ = None
else:
typ = builder.call_c(type_op, [mgr_v], line)
typ = builder.primitive_op(type_op, [mgr_v], line)
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)

Expand Down Expand Up @@ -876,7 +877,7 @@ def transform_del_item(builder: IRBuilder, target: AssignmentTarget, line: int)
line,
)
key = builder.load_str(target.attr)
builder.call_c(py_delattr_op, [target.obj, key], line)
builder.primitive_op(py_delattr_op, [target.obj, key], line)
elif isinstance(target, AssignmentTargetRegister):
# Delete a local by assigning an error value to it, which will
# prompt the insertion of uninit checks.
Expand Down Expand Up @@ -924,7 +925,10 @@ def emit_yield_from_or_await(
received_reg = Register(object_rprimitive)

get_op = coro_op if is_await else iter_op
iter_val = builder.call_c(get_op, [val], line)
if isinstance(get_op, PrimitiveDescription):
iter_val = builder.primitive_op(get_op, [val], line)
else:
iter_val = builder.call_c(get_op, [val], line)

iter_reg = builder.maybe_spill_assignable(iter_val)

Expand Down
32 changes: 16 additions & 16 deletions mypyc/primitives/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ class LoadAddressDescription(NamedTuple):
src: str # name of the target to load


# CallC op for method call(such as 'str.join')
# CallC op for method call (such as 'str.join')
method_call_ops: dict[str, list[CFunctionDescription]] = {}

# CallC op for top level function call(such as 'builtins.list')
function_ops: dict[str, list[CFunctionDescription]] = {}
# Primitive ops for top level function call (such as 'builtins.list')
function_ops: dict[str, list[PrimitiveDescription]] = {}

# CallC op for binary ops
# Primitive ops for binary operations
binary_ops: dict[str, list[PrimitiveDescription]] = {}

# CallC op for unary ops
Expand Down Expand Up @@ -161,8 +161,8 @@ def function_op(
steals: StealsDescription = False,
is_borrowed: bool = False,
priority: int = 1,
) -> CFunctionDescription:
"""Define a c function call op that replaces a function call.
) -> PrimitiveDescription:
"""Define a C function call op that replaces a function call.

This will be automatically generated by matching against the AST.

Expand All @@ -175,19 +175,19 @@ def function_op(
if extra_int_constants is None:
extra_int_constants = []
ops = function_ops.setdefault(name, [])
desc = CFunctionDescription(
desc = PrimitiveDescription(
name,
arg_types,
return_type,
var_arg_type,
truncated_type,
c_function_name,
error_kind,
steals,
is_borrowed,
ordering,
extra_int_constants,
priority,
var_arg_type=var_arg_type,
truncated_type=truncated_type,
c_function_name=c_function_name,
error_kind=error_kind,
steals=steals,
is_borrowed=is_borrowed,
ordering=ordering,
extra_int_constants=extra_int_constants,
priority=priority,
is_pure=False,
)
ops.append(desc)
Expand Down
8 changes: 2 additions & 6 deletions mypyc/test/test_cheader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,14 @@ def check_name(name: str) -> None:
rf"\b{name}\b", header
), f'"{name}" is used in mypyc.primitives but not declared in CPy.h'

for old_values in [
registry.method_call_ops.values(),
registry.function_ops.values(),
registry.unary_ops.values(),
]:
for old_values in [registry.method_call_ops.values(), registry.unary_ops.values()]:
for old_ops in old_values:
if isinstance(old_ops, CFunctionDescription):
old_ops = [old_ops]
for old_op in old_ops:
check_name(old_op.c_function_name)

for values in [registry.binary_ops.values()]:
for values in [registry.binary_ops.values(), registry.function_ops.values()]:
for ops in values:
if isinstance(ops, PrimitiveDescription):
ops = [ops]
Expand Down