diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 476c435712bd..5a495ec0b61c 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -888,7 +888,7 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module emitter.emit_lines('{} = PyModule_Create(&{}module);'.format(module_static, module_prefix), 'if (unlikely({} == NULL))'.format(module_static), - ' return NULL;') + ' goto fail;') emitter.emit_line( 'PyObject *modname = PyObject_GetAttrString((PyObject *){}, "__name__");'.format( module_static)) @@ -896,7 +896,7 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module module_globals = emitter.static_name('globals', module_name) emitter.emit_lines('{} = PyModule_GetDict({});'.format(module_globals, module_static), 'if (unlikely({} == NULL))'.format(module_globals), - ' return NULL;') + ' goto fail;') # HACK: Manually instantiate generated classes here for cl in module.classes: @@ -907,16 +907,19 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module '(PyObject *){t}_template, NULL, modname);' .format(t=type_struct)) emitter.emit_lines('if (unlikely(!{}))'.format(type_struct), - ' return NULL;') + ' goto fail;') emitter.emit_lines('if (CPyGlobalsInit() < 0)', - ' return NULL;') + ' goto fail;') self.generate_top_level_call(module, emitter) emitter.emit_lines('Py_DECREF(modname);') emitter.emit_line('return {};'.format(module_static)) + emitter.emit_lines('fail:', + '{} = NULL;'.format(module_static), + 'return NULL;') emitter.emit_line('}') def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None: @@ -927,7 +930,7 @@ def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None: emitter.emit_lines( 'char result = {}();'.format(emitter.native_function_name(fn.decl)), 'if (result == 2)', - ' return NULL;', + ' goto fail;', ) break diff --git a/mypyc/test-data/run-misc.test b/mypyc/test-data/run-misc.test index 1df800ee6560..dd69e42e9373 100644 --- a/mypyc/test-data/run-misc.test +++ b/mypyc/test-data/run-misc.test @@ -1126,3 +1126,27 @@ C = sys.platform == 'x' and (lambda x: y + x) assert not A assert not B assert not C + +[case testDoesntSegfaultWhenTopLevelFails] +# make the initial import fail +assert False + +class C: + def __init__(self): + self.x = 1 + self.y = 2 +def test() -> None: + a = C() +[file driver.py] +# load native, cause PyInit to be run, create the module but don't finish initializing the globals +try: + import native +except: + pass +try: + # try accessing those globals that were never properly initialized + import native + native.test() +# should fail with AssertionError due to `assert False` in other function +except AssertionError: + pass