Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] (Re-)Support iterating over an Union of dicts #14713

Merged
merged 3 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,23 +879,39 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
else:
return self.type_to_rtype(target_type.args[0])

def get_dict_base_type(self, expr: Expression) -> Instance:
def get_dict_base_type(self, expr: Expression) -> list[Instance]:
"""Find dict type of a dict-like expression.

This is useful for dict subclasses like SymbolTable.
"""
target_type = get_proper_type(self.types[expr])
assert isinstance(target_type, Instance), target_type
dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict")
return map_instance_to_supertype(target_type, dict_base)
if isinstance(target_type, UnionType):
types = [get_proper_type(item) for item in target_type.items]
else:
types = [target_type]

dict_types = []
for t in types:
assert isinstance(t, Instance), t
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
dict_types.append(map_instance_to_supertype(t, dict_base))
return dict_types

def get_dict_key_type(self, expr: Expression) -> RType:
dict_base_type = self.get_dict_base_type(expr)
return self.type_to_rtype(dict_base_type.args[0])
dict_base_types = self.get_dict_base_type(expr)
if len(dict_base_types) == 1:
return self.type_to_rtype(dict_base_types[0].args[0])
else:
rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types]
return RUnion.make_simplified_union(rtypes)

def get_dict_value_type(self, expr: Expression) -> RType:
dict_base_type = self.get_dict_base_type(expr)
return self.type_to_rtype(dict_base_type.args[1])
dict_base_types = self.get_dict_base_type(expr)
if len(dict_base_types) == 1:
return self.type_to_rtype(dict_base_types[0].args[1])
else:
rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types]
return RUnion.make_simplified_union(rtypes)

def get_dict_item_type(self, expr: Expression) -> RType:
key_type = self.get_dict_key_type(expr)
Expand Down
58 changes: 57 additions & 1 deletion mypyc/test-data/irbuild-dict.test
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,17 @@ L0:
return r2

[case testDictIterationMethods]
from typing import Dict
from typing import Dict, Union
def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None:
for v in d1.values():
if v in d2:
return
for k, v in d2.items():
d2[k] += v
def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None:
new = {}
for k, v in d.items():
new[k] = int(v)
[out]
def print_dict_methods(d1, d2):
d1, d2 :: dict
Expand Down Expand Up @@ -314,6 +318,58 @@ L11:
r34 = CPy_NoErrOccured()
L12:
return 1
def union_of_dicts(d):
d, r0, new :: dict
r1 :: short_int
r2 :: native_int
r3 :: short_int
r4 :: object
r5 :: tuple[bool, short_int, object, object]
r6 :: short_int
r7 :: bool
r8, r9 :: object
r10 :: str
r11 :: union[int, str]
k :: str
v :: union[int, str]
r12, r13 :: object
r14 :: int
r15 :: object
r16 :: int32
r17, r18, r19 :: bit
L0:
r0 = PyDict_New()
new = r0
r1 = 0
r2 = PyDict_Size(d)
r3 = r2 << 1
r4 = CPyDict_GetItemsIter(d)
L1:
r5 = CPyDict_NextItem(r4, r1)
r6 = r5[1]
r1 = r6
r7 = r5[0]
if r7 goto L2 else goto L4 :: bool
L2:
r8 = r5[2]
r9 = r5[3]
r10 = cast(str, r8)
r11 = cast(union[int, str], r9)
k = r10
v = r11
r12 = load_address PyLong_Type
r13 = PyObject_CallFunctionObjArgs(r12, v, 0)
r14 = unbox(int, r13)
r15 = box(int, r14)
r16 = CPyDict_SetItem(new, k, r15)
r17 = r16 >= 0 :: signed
L3:
r18 = CPyDict_CheckSize(d, r3)
goto L1
L4:
r19 = CPy_NoErrOccured()
L5:
return 1

[case testDictLoadAddress]
def f() -> None:
Expand Down