diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 6360d052fa6d..93393ec545a0 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -349,6 +349,7 @@ PyObject *CPyDict_Keys(PyObject *dict); PyObject *CPyDict_Values(PyObject *dict); PyObject *CPyDict_Items(PyObject *dict); char CPyDict_Clear(PyObject *dict); +PyObject *CPyDict_Copy(PyObject *dict); PyObject *CPyDict_GetKeysIter(PyObject *dict); PyObject *CPyDict_GetItemsIter(PyObject *dict); PyObject *CPyDict_GetValuesIter(PyObject *dict); diff --git a/mypyc/lib-rt/dict_ops.c b/mypyc/lib-rt/dict_ops.c index fca096468ddd..1e635ada2c00 100644 --- a/mypyc/lib-rt/dict_ops.c +++ b/mypyc/lib-rt/dict_ops.c @@ -238,6 +238,13 @@ char CPyDict_Clear(PyObject *dict) { return 1; } +PyObject *CPyDict_Copy(PyObject *dict) { + if (PyDict_CheckExact(dict)) { + return PyDict_Copy(dict); + } + return PyObject_CallMethod(dict, "copy", NULL); +} + PyObject *CPyDict_GetKeysIter(PyObject *dict) { if (PyDict_CheckExact(dict)) { // Return dict itself to indicate we can use fast path instead. diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index e0278182d5fe..ff9b6482a782 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -150,6 +150,14 @@ c_function_name='CPyDict_Clear', error_kind=ERR_FALSE) +# dict.copy() +method_op( + name='copy', + arg_types=[dict_rprimitive], + return_type=dict_rprimitive, + c_function_name='CPyDict_Copy', + error_kind=ERR_MAGIC) + # list(dict.keys()) dict_keys_op = custom_op( arg_types=[dict_rprimitive], diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index bb85a41fbe24..ec31ee02fd0a 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -173,6 +173,7 @@ def keys(self) -> Iterable[K]: pass def values(self) -> Iterable[V]: pass def items(self) -> Iterable[Tuple[K, V]]: pass def clear(self) -> None: pass + def copy(self) -> Dict[K, V]: pass class set(Generic[T]): def __init__(self, i: Optional[Iterable[T]] = None) -> None: pass diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index 5a127ffdfbdb..1225750b74df 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -324,3 +324,14 @@ def f(d): L0: r0 = CPyDict_Clear(d) return 1 + +[case testDictCopy] +from typing import Dict +def f(d: Dict[int, int]) -> Dict[int, int]: + return d.copy() +[out] +def f(d): + d, r0 :: dict +L0: + r0 = CPyDict_Copy(d) + return r0 diff --git a/mypyc/test-data/run-dicts.test b/mypyc/test-data/run-dicts.test index 1955d514d43a..329d186874f4 100644 --- a/mypyc/test-data/run-dicts.test +++ b/mypyc/test-data/run-dicts.test @@ -206,3 +206,14 @@ def test_dict_clear() -> None: dd['a'] = 1 dd.clear() assert dd == {} + +def test_dict_copy() -> None: + d: Dict[str, int] = {} + assert d.copy() == d + d = {'a': 1, 'b': 2} + assert d.copy() == d + assert d.copy() is not d + dd: Dict[str, int] = defaultdict(int) + dd['a'] = 1 + assert dd.copy() == dd + assert isinstance(dd.copy(), defaultdict)