Skip to content

Commit b2cf9d1

Browse files
authored
[mypyc] Optimize __(a)enter__/__(a)exit__ paths for native case (#14530)
Closes mypyc/mypyc#904 Directly calls enter and exit handlers in the case that the context manager is implemented natively. Unfortunately the implementation becomes a bit more complicated because there are two different places where we call exit in different ways, and they both need to support the native and non-native cases.
1 parent cf2e404 commit b2cf9d1

File tree

4 files changed

+190
-16
lines changed

4 files changed

+190
-16
lines changed

mypyc/irbuild/statement.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
Integer,
5151
LoadAddress,
5252
LoadErrorValue,
53+
MethodCall,
5354
RaiseStandardError,
5455
Register,
5556
Return,
@@ -61,6 +62,7 @@
6162
RInstance,
6263
exc_rtuple,
6364
is_tagged,
65+
none_rprimitive,
6466
object_pointer_rprimitive,
6567
object_rprimitive,
6668
)
@@ -657,14 +659,45 @@ def transform_with(
657659
al = "a" if is_async else ""
658660

659661
mgr_v = builder.accept(expr)
660-
typ = builder.call_c(type_op, [mgr_v], line)
661-
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
662-
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)
662+
is_native = isinstance(mgr_v.type, RInstance)
663+
if is_native:
664+
value = builder.add(MethodCall(mgr_v, f"__{al}enter__", args=[], line=line))
665+
exit_ = None
666+
else:
667+
typ = builder.call_c(type_op, [mgr_v], line)
668+
exit_ = builder.maybe_spill(builder.py_get_attr(typ, f"__{al}exit__", line))
669+
value = builder.py_call(builder.py_get_attr(typ, f"__{al}enter__", line), [mgr_v], line)
670+
663671
mgr = builder.maybe_spill(mgr_v)
664672
exc = builder.maybe_spill_assignable(builder.true())
665673
if is_async:
666674
value = emit_await(builder, value, line)
667675

676+
def maybe_natively_call_exit(exc_info: bool) -> Value:
677+
if exc_info:
678+
args = get_sys_exc_info(builder)
679+
else:
680+
none = builder.none_object()
681+
args = [none, none, none]
682+
683+
if is_native:
684+
assert isinstance(mgr_v.type, RInstance)
685+
exit_val = builder.gen_method_call(
686+
builder.read(mgr),
687+
f"__{al}exit__",
688+
arg_values=args,
689+
line=line,
690+
result_type=none_rprimitive,
691+
)
692+
else:
693+
assert exit_ is not None
694+
exit_val = builder.py_call(builder.read(exit_), [builder.read(mgr)] + args, line)
695+
696+
if is_async:
697+
return emit_await(builder, exit_val, line)
698+
else:
699+
return exit_val
700+
668701
def try_body() -> None:
669702
if target:
670703
builder.assign(builder.get_assignment_target(target), value, line)
@@ -673,13 +706,7 @@ def try_body() -> None:
673706
def except_body() -> None:
674707
builder.assign(exc, builder.false(), line)
675708
out_block, reraise_block = BasicBlock(), BasicBlock()
676-
exit_val = builder.py_call(
677-
builder.read(exit_), [builder.read(mgr)] + get_sys_exc_info(builder), line
678-
)
679-
if is_async:
680-
exit_val = emit_await(builder, exit_val, line)
681-
682-
builder.add_bool_branch(exit_val, out_block, reraise_block)
709+
builder.add_bool_branch(maybe_natively_call_exit(exc_info=True), out_block, reraise_block)
683710
builder.activate_block(reraise_block)
684711
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
685712
builder.add(Unreachable())
@@ -689,13 +716,8 @@ def finally_body() -> None:
689716
out_block, exit_block = BasicBlock(), BasicBlock()
690717
builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL))
691718
builder.activate_block(exit_block)
692-
none = builder.none_object()
693-
exit_val = builder.py_call(
694-
builder.read(exit_), [builder.read(mgr), none, none, none], line
695-
)
696-
if is_async:
697-
emit_await(builder, exit_val, line)
698719

720+
maybe_natively_call_exit(exc_info=False)
699721
builder.goto_and_activate(out_block)
700722

701723
transform_try_finally_stmt(

mypyc/test-data/irbuild-try.test

+105
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,108 @@ L19:
416416
L20:
417417
return 1
418418

419+
[case testWithNativeSimple]
420+
class DummyContext:
421+
def __enter__(self) -> None:
422+
pass
423+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
424+
pass
425+
426+
def foo(x: DummyContext) -> None:
427+
with x:
428+
print('hello')
429+
[out]
430+
def DummyContext.__enter__(self):
431+
self :: __main__.DummyContext
432+
L0:
433+
return 1
434+
def DummyContext.__exit__(self, exc_type, exc_val, exc_tb):
435+
self :: __main__.DummyContext
436+
exc_type, exc_val, exc_tb :: object
437+
L0:
438+
return 1
439+
def foo(x):
440+
x :: __main__.DummyContext
441+
r0 :: None
442+
r1 :: bool
443+
r2 :: str
444+
r3 :: object
445+
r4 :: str
446+
r5, r6 :: object
447+
r7, r8 :: tuple[object, object, object]
448+
r9, r10, r11 :: object
449+
r12 :: None
450+
r13 :: object
451+
r14 :: int32
452+
r15 :: bit
453+
r16 :: bool
454+
r17 :: bit
455+
r18, r19, r20 :: tuple[object, object, object]
456+
r21 :: object
457+
r22 :: None
458+
r23 :: bit
459+
L0:
460+
r0 = x.__enter__()
461+
r1 = 1
462+
L1:
463+
L2:
464+
r2 = 'hello'
465+
r3 = builtins :: module
466+
r4 = 'print'
467+
r5 = CPyObject_GetAttr(r3, r4)
468+
r6 = PyObject_CallFunctionObjArgs(r5, r2, 0)
469+
goto L8
470+
L3: (handler for L2)
471+
r7 = CPy_CatchError()
472+
r1 = 0
473+
r8 = CPy_GetExcInfo()
474+
r9 = r8[0]
475+
r10 = r8[1]
476+
r11 = r8[2]
477+
r12 = x.__exit__(r9, r10, r11)
478+
r13 = box(None, r12)
479+
r14 = PyObject_IsTrue(r13)
480+
r15 = r14 >= 0 :: signed
481+
r16 = truncate r14: int32 to builtins.bool
482+
if r16 goto L5 else goto L4 :: bool
483+
L4:
484+
CPy_Reraise()
485+
unreachable
486+
L5:
487+
L6:
488+
CPy_RestoreExcInfo(r7)
489+
goto L8
490+
L7: (handler for L3, L4, L5)
491+
CPy_RestoreExcInfo(r7)
492+
r17 = CPy_KeepPropagating()
493+
unreachable
494+
L8:
495+
L9:
496+
L10:
497+
r18 = <error> :: tuple[object, object, object]
498+
r19 = r18
499+
goto L12
500+
L11: (handler for L1, L6, L7, L8)
501+
r20 = CPy_CatchError()
502+
r19 = r20
503+
L12:
504+
if r1 goto L13 else goto L14 :: bool
505+
L13:
506+
r21 = load_address _Py_NoneStruct
507+
r22 = x.__exit__(r21, r21, r21)
508+
L14:
509+
if is_error(r19) goto L16 else goto L15
510+
L15:
511+
CPy_Reraise()
512+
unreachable
513+
L16:
514+
goto L20
515+
L17: (handler for L12, L13, L14, L15)
516+
if is_error(r19) goto L19 else goto L18
517+
L18:
518+
CPy_RestoreExcInfo(r19)
519+
L19:
520+
r23 = CPy_KeepPropagating()
521+
unreachable
522+
L20:
523+
return 1

mypyc/test-data/run-generators.test

+17
Original file line numberDiff line numberDiff line change
@@ -662,3 +662,20 @@ def list_comp() -> List[int]:
662662
[file driver.py]
663663
from native import list_comp
664664
assert list_comp() == [5]
665+
666+
[case testWithNative]
667+
class DummyContext:
668+
def __init__(self) -> None:
669+
self.x = 0
670+
671+
def __enter__(self) -> None:
672+
self.x += 1
673+
674+
def __exit__(self, exc_type, exc_value, exc_tb) -> None:
675+
self.x -= 1
676+
677+
def test_basic() -> None:
678+
context = DummyContext()
679+
with context:
680+
assert context.x == 1
681+
assert context.x == 0

mypyc/test-data/run-misc.test

+30
Original file line numberDiff line numberDiff line change
@@ -1116,3 +1116,33 @@ i = b"foo"
11161116

11171117
def test_redefinition() -> None:
11181118
assert i == b"foo"
1119+
1120+
[case testWithNative]
1121+
class DummyContext:
1122+
def __init__(self):
1123+
self.c = 0
1124+
def __enter__(self) -> None:
1125+
self.c += 1
1126+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
1127+
self.c -= 1
1128+
1129+
def test_dummy_context() -> None:
1130+
c = DummyContext()
1131+
with c:
1132+
assert c.c == 1
1133+
assert c.c == 0
1134+
1135+
[case testWithNativeVarArgs]
1136+
class DummyContext:
1137+
def __init__(self):
1138+
self.c = 0
1139+
def __enter__(self) -> None:
1140+
self.c += 1
1141+
def __exit__(self, *args: object) -> None:
1142+
self.c -= 1
1143+
1144+
def test_dummy_context() -> None:
1145+
c = DummyContext()
1146+
with c:
1147+
assert c.c == 1
1148+
assert c.c == 0

0 commit comments

Comments
 (0)