|
1 | 1 | import ast |
| 2 | +import functools |
2 | 3 | import inspect |
| 4 | +import types |
3 | 5 | from itertools import dropwhile |
| 6 | +from opcode import opmap |
4 | 7 | from textwrap import dedent |
| 8 | +from typing import Dict |
| 9 | + |
| 10 | +from bytecode import ConcreteBytecode |
| 11 | +from cloudpickle import cloudpickle |
5 | 12 |
|
6 | 13 |
|
7 | 14 | def set_lineno(node, n=1): |
@@ -42,6 +49,99 @@ def bind(func, instance, as_name=None): |
42 | 49 | return bound_method |
43 | 50 |
|
44 | 51 |
|
| 52 | +# based on https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Tools/build/deepfreeze.py#L48 |
| 53 | +def get_localsplus_name_to_idx(code: types.CodeType): |
| 54 | + localsplus = code.co_varnames + code.co_cellvars + code.co_freevars |
| 55 | + return localsplus, {v: i for i, v in enumerate(localsplus)} |
| 56 | + |
| 57 | + |
| 58 | +class _empty_cell_value: |
| 59 | + """Sentinel for empty closures.""" |
| 60 | + |
| 61 | + @classmethod |
| 62 | + def __reduce__(cls): |
| 63 | + return cls.__name__ |
| 64 | + |
| 65 | + |
| 66 | +_empty_cell_value = _empty_cell_value() |
| 67 | + |
| 68 | + |
| 69 | +# based on https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L513 |
| 70 | +def make_empty_cell(): |
| 71 | + if False: |
| 72 | + # trick the compiler into creating an empty cell in our lambda |
| 73 | + cell = None |
| 74 | + raise AssertionError("this route should not be executed") |
| 75 | + |
| 76 | + return (lambda: cell).__closure__[0] |
| 77 | + |
| 78 | + |
| 79 | +def make_cell(value=_empty_cell_value): |
| 80 | + cell = make_empty_cell() |
| 81 | + if value is not _empty_cell_value: |
| 82 | + cell.cell_contents = value |
| 83 | + return cell |
| 84 | + |
| 85 | + |
| 86 | +# based on https://github.com/python/cpython/blob/a4b44d39cd6941cc03590fee7538776728bdfd0a/Lib/test/test_code.py#L197 |
| 87 | +def replace_closure(code, new_closure: Dict): |
| 88 | + COPY_FREE_VARS = opmap["COPY_FREE_VARS"] |
| 89 | + LOAD_DEREF = opmap["LOAD_DEREF"] |
| 90 | + |
| 91 | + # get the orig localplus that will be loaded from by the orig bytecode LOAD_DEREF arg_i |
| 92 | + localsplus, localsplus_name_to_idx = get_localsplus_name_to_idx(code) |
| 93 | + |
| 94 | + # closure vars go into co_freevars |
| 95 | + new_code = code.replace(co_freevars=tuple(new_closure.keys())) |
| 96 | + # closure is a tuple of cells |
| 97 | + closure = tuple( |
| 98 | + make_cell(v) if not isinstance(v, types.CellType) else v |
| 99 | + for v in new_closure.values() |
| 100 | + ) |
| 101 | + |
| 102 | + new_code = ConcreteBytecode.from_code(new_code) |
| 103 | + # update how many closure vars are loaded from frame |
| 104 | + # see https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Python/bytecodes.c#L1571 |
| 105 | + assert new_code[0].opcode == COPY_FREE_VARS |
| 106 | + new_code[0].arg = len(closure) |
| 107 | + |
| 108 | + # map orig localsplus arg_i to new localplus position/arg_i |
| 109 | + new_localsplus = new_code.varnames + new_code.cellvars + new_code.freevars |
| 110 | + new_localsplus_name_to_idx = {v: i for i, v in enumerate(new_localsplus)} |
| 111 | + for c in new_code: |
| 112 | + if c.opcode == LOAD_DEREF and c.arg < len(localsplus): |
| 113 | + c.arg = new_localsplus_name_to_idx[localsplus[c.arg]] |
| 114 | + new_code = new_code.to_code() |
| 115 | + |
| 116 | + return new_code, closure |
| 117 | + |
| 118 | + |
| 119 | +# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard); |
| 120 | +# potentially more complete approach https://stackoverflow.com/a/56901529/9045206 |
| 121 | +def copy_func(f, new_closure: Dict = None): |
| 122 | + if new_closure is not None: |
| 123 | + code, closure = replace_closure(f.__code__, new_closure) |
| 124 | + else: |
| 125 | + code, closure = f.__code__, f.__closure__ |
| 126 | + |
| 127 | + g = types.FunctionType( |
| 128 | + code=code, |
| 129 | + globals=f.__globals__, |
| 130 | + name=f.__name__, |
| 131 | + argdefs=f.__defaults__, |
| 132 | + # see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813 |
| 133 | + # for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars) |
| 134 | + closure=cloudpickle.loads(cloudpickle.dumps(closure)), |
| 135 | + ) |
| 136 | + g.__kwdefaults__ = f.__kwdefaults__ |
| 137 | + g.__dict__.update(f.__dict__) |
| 138 | + g = functools.update_wrapper(g, f) |
| 139 | + |
| 140 | + if inspect.ismethod(f): |
| 141 | + g = bind(g, f.__self__) |
| 142 | + return g |
| 143 | + |
| 144 | + |
45 | 145 | def append_hidden_node(node_body, new_node): |
46 | 146 | last_statement = node_body[-1] |
47 | 147 | new_node = ast.fix_missing_locations( |
|
0 commit comments