From 7237831d64c051b2d6e4d99970f9b6ccf7a7bfce Mon Sep 17 00:00:00 2001 From: Richard Si Date: Thu, 16 Feb 2023 05:05:06 -0500 Subject: [PATCH] [mypyc] (Re-)Support iterating over an Union of dicts (#14713) An optimization to make iterating over dict.keys(), dict.values() and dict.items() faster caused mypyc to crash while compiling a Union of dictionaries. This commit fixes the optimization helpers to properly handle unions. irbuild.Builder.get_dict_base_type() now returns list[Instance] with the union items. In the common case we don't have a union, a single-element list is returned. And get_dict_key_type() and get_dict_value_type() will now build a simplified RUnion as needed. Fixes https://github.com/mypyc/mypyc/issues/965 and probably #14694. --- mypyc/codegen/literals.py | 5 ++- mypyc/irbuild/builder.py | 32 ++++++++++++----- mypyc/test-data/irbuild-dict.test | 58 ++++++++++++++++++++++++++++++- 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/mypyc/codegen/literals.py b/mypyc/codegen/literals.py index 784a8ed27c4e4..05884b754452f 100644 --- a/mypyc/codegen/literals.py +++ b/mypyc/codegen/literals.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, FrozenSet, List, Tuple, Union, cast +from typing import Any, FrozenSet, List, Tuple, Union, cast from typing_extensions import Final # Supported Python literal types. All tuple / frozenset items must have supported @@ -151,8 +151,7 @@ def _encode_collection_values( ... """ - # FIXME: https://github.com/mypyc/mypyc/issues/965 - value_by_index = {index: value for value, index in cast(Dict[Any, int], values).items()} + value_by_index = {index: value for value, index in values.items()} result = [] count = len(values) result.append(str(count)) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index f2a70d4e8691e..f37fae608083b 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -879,23 +879,39 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType: else: return self.type_to_rtype(target_type.args[0]) - def get_dict_base_type(self, expr: Expression) -> Instance: + def get_dict_base_type(self, expr: Expression) -> list[Instance]: """Find dict type of a dict-like expression. This is useful for dict subclasses like SymbolTable. """ target_type = get_proper_type(self.types[expr]) - assert isinstance(target_type, Instance), target_type - dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict") - return map_instance_to_supertype(target_type, dict_base) + if isinstance(target_type, UnionType): + types = [get_proper_type(item) for item in target_type.items] + else: + types = [target_type] + + dict_types = [] + for t in types: + assert isinstance(t, Instance), t + dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict") + dict_types.append(map_instance_to_supertype(t, dict_base)) + return dict_types def get_dict_key_type(self, expr: Expression) -> RType: - dict_base_type = self.get_dict_base_type(expr) - return self.type_to_rtype(dict_base_type.args[0]) + dict_base_types = self.get_dict_base_type(expr) + if len(dict_base_types) == 1: + return self.type_to_rtype(dict_base_types[0].args[0]) + else: + rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types] + return RUnion.make_simplified_union(rtypes) def get_dict_value_type(self, expr: Expression) -> RType: - dict_base_type = self.get_dict_base_type(expr) - return self.type_to_rtype(dict_base_type.args[1]) + dict_base_types = self.get_dict_base_type(expr) + if len(dict_base_types) == 1: + return self.type_to_rtype(dict_base_types[0].args[1]) + else: + rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types] + return RUnion.make_simplified_union(rtypes) def get_dict_item_type(self, expr: Expression) -> RType: key_type = self.get_dict_key_type(expr) diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index 3e2c295637ab5..99643b9451f0b 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -218,13 +218,17 @@ L0: return r2 [case testDictIterationMethods] -from typing import Dict +from typing import Dict, Union def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None: for v in d1.values(): if v in d2: return for k, v in d2.items(): d2[k] += v +def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None: + new = {} + for k, v in d.items(): + new[k] = int(v) [out] def print_dict_methods(d1, d2): d1, d2 :: dict @@ -314,6 +318,58 @@ L11: r34 = CPy_NoErrOccured() L12: return 1 +def union_of_dicts(d): + d, r0, new :: dict + r1 :: short_int + r2 :: native_int + r3 :: short_int + r4 :: object + r5 :: tuple[bool, short_int, object, object] + r6 :: short_int + r7 :: bool + r8, r9 :: object + r10 :: str + r11 :: union[int, str] + k :: str + v :: union[int, str] + r12, r13 :: object + r14 :: int + r15 :: object + r16 :: int32 + r17, r18, r19 :: bit +L0: + r0 = PyDict_New() + new = r0 + r1 = 0 + r2 = PyDict_Size(d) + r3 = r2 << 1 + r4 = CPyDict_GetItemsIter(d) +L1: + r5 = CPyDict_NextItem(r4, r1) + r6 = r5[1] + r1 = r6 + r7 = r5[0] + if r7 goto L2 else goto L4 :: bool +L2: + r8 = r5[2] + r9 = r5[3] + r10 = cast(str, r8) + r11 = cast(union[int, str], r9) + k = r10 + v = r11 + r12 = load_address PyLong_Type + r13 = PyObject_CallFunctionObjArgs(r12, v, 0) + r14 = unbox(int, r13) + r15 = box(int, r14) + r16 = CPyDict_SetItem(new, k, r15) + r17 = r16 >= 0 :: signed +L3: + r18 = CPyDict_CheckSize(d, r3) + goto L1 +L4: + r19 = CPy_NoErrOccured() +L5: + return 1 [case testDictLoadAddress] def f() -> None: