Skip to content

Commit

Permalink
Structure code so by choosing {"type": Literal} we can pick the Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed May 10, 2024
1 parent d54db1c commit 0c5094b
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions databind/src/databind/json/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,20 +763,39 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) ->
def convert(self, ctx: Context) -> t.Any:
datatype = ctx.datatype
union: t.Optional[Union]
literal_types: list[TypeHint] = []

if isinstance(datatype, UnionTypeHint):
if datatype.has_none_type():
raise NotImplementedError("unable to handle Union type with None in it")

literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)]
literal_values = [
literal_value
for literal_type in datatype
if isinstance(literal_type, LiteralTypeHint)
for literal_value in literal_type.values
]

literal_type: t.Optional[TypeHint] = None
if literal_values:
literal_type = t.Literal[tuple(literal_values)] # type: ignore

non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)]
if not all(isinstance(a, ClassTypeHint) for a in non_literal_types):
raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}")

members = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types}
if len(members) != len(non_literal_types):
raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}")
raise ConversionError(
self, ctx, f"members of plain Union cannot have overlapping type names: {datatype}"
)

if literal_type is not None:
if "Literal" in members:
raise ConversionError(
self, ctx, f"members of plain Union with a Literal cannot have type name Literal: {datatype}"
)
members["Literal"] = literal_type

union = Union(members, Union.BEST_MATCH)
elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)):
union = ctx.get_setting(Union)
Expand All @@ -794,11 +813,6 @@ def convert(self, ctx: Context) -> t.Any:
return ctx.spawn(ctx.value, member_type, None).convert()
except ConversionError as exc:
errors.append((exc.origin, exc))
for literal_type in literal_types:
try:
return ctx.spawn(ctx.value, literal_type, None).convert()
except ConversionError as exc:
errors.append((exc.origin, exc))
raise ConversionError(
self,
ctx,
Expand Down

0 comments on commit 0c5094b

Please sign in to comment.