Skip to content

Commit

Permalink
feat: literal type support
Browse files Browse the repository at this point in the history
  • Loading branch information
moriyoshi committed Mar 19, 2024
1 parent 59a83d4 commit 0ef2796
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 23 deletions.
29 changes: 29 additions & 0 deletions src/jsonie/tests/test_to_jsonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import pytest

from ..exceptions import ToJsonicConverterError


@pytest.mark.parametrize(
("expected", "input"),
Expand Down Expand Up @@ -580,3 +582,30 @@ def test_pytyped_jsonic_data_to_jsonic_namedtuple(expected, input):
from ..to_jsonic import ToJsonicConverter

assert ToJsonicConverter()(input[0], input[1]) == expected


if hasattr(typing, "Literal"):
Literals = typing.Literal["A", "B", "C"]
UnionedLiterals = typing.Union[typing.Literal["A"], typing.Literal["B", "C"]]

@pytest.mark.parametrize(
("expected", "input"),
[
("A", (Literals, "A")),
("B", (Literals, "B")),
("C", (Literals, "C")),
(ToJsonicConverterError, (Literals, "D")),
("A", (UnionedLiterals, "A")),
("B", (UnionedLiterals, "B")),
("C", (UnionedLiterals, "C")),
(ToJsonicConverterError, (UnionedLiterals, "D")),
],
)
def test_literal(expected, input):
from ..to_jsonic import ToJsonicConverter

if isinstance(expected, type) and issubclass(expected, BaseException):
with pytest.raises(expected):
ToJsonicConverter()(input[0], input[1])
else:
assert ToJsonicConverter()(input[0], input[1]) == expected
51 changes: 32 additions & 19 deletions src/jsonie/to_jsonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@


class CustomConverter(typing.Protocol):
@abc.abstractmethod
def resolve_name(self, typ: JsonicType) -> str:
... # pragma: nocover

Expand All @@ -52,6 +53,7 @@ def __call__(
class CustomConverterFuncAdapter:
convert: CustomConverterConvertFunc # type: ignore

@abc.abstractmethod
def resolve_name(self, typ: JsonicType) -> str:
... # pragma: nocover

Expand Down Expand Up @@ -79,8 +81,7 @@ def resolve(
tctx: "TraversalContext",
typ: JsonicType,
name: str,
) -> typing.Optional[str]:
...
) -> typing.Optional[str]: ...

@abc.abstractmethod
def reverse_resolve(
Expand All @@ -89,8 +90,7 @@ def reverse_resolve(
tctx: "TraversalContext",
typ: JsonicType,
name: str,
) -> typing.Optional[str]:
...
) -> typing.Optional[str]: ...


NameMapperResolveFunc = typing.Callable[
Expand All @@ -105,17 +105,15 @@ def resolve(
tctx: "TraversalContext",
typ: JsonicType,
name: str,
) -> typing.Optional[str]:
...
) -> typing.Optional[str]: ...

def reverse_resolve(
self,
converter: "ToJsonicConverter",
tctx: "TraversalContext",
typ: JsonicType,
name: str,
) -> typing.Optional[str]:
...
) -> typing.Optional[str]: ...

def __init__(self, resolve: NameMapperResolveFunc, reverse_resolve: NameMapperResolveFunc):
self.resolve = resolve # type: ignore
Expand Down Expand Up @@ -151,12 +149,10 @@ def reverse_resolve(
class ConverterContext(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def stopped(self) -> bool:
...
def stopped(self) -> bool: ...

@abc.abstractmethod
def validation_error_occurred(self, error: ToJsonicConverterError) -> None:
...
def validation_error_occurred(self, error: ToJsonicConverterError) -> None: ...


class DefaultConverterContext(ConverterContext):
Expand Down Expand Up @@ -242,8 +238,7 @@ def is_namedtuple(typ: typing.Type[T]) -> bool:
class NamedTupleType(typing.Protocol):
_fields: typing.Sequence[str]

def _make(self, values: typing.Iterable[str]) -> typing.Tuple:
... # pragma: nocover
def _make(self, values: typing.Iterable[str]) -> typing.Tuple: ... # pragma: nocover


class ToJsonicConverter:
Expand Down Expand Up @@ -286,6 +281,9 @@ def type_repr(self, typ: typing_compat.GenericAlias) -> str: # type: ignore
return custom_converter.resolve_name(typ)
if typing_compat.is_union_type(typ):
return f"any of {english_enumerate((self.type_repr(t) for t in typing_compat.get_args(typ)), conj=', or ')}"
elif typing_compat.is_literal_type(typ): # type: ignore
args = typing_compat.get_args(typ)
return f"any of literal {english_enumerate((str(v) for v in args), conj=', or ')}"
elif typing_compat.is_generic_type(typ): # type: ignore
origin = typing_compat.get_origin(typ)
if issubclass(
Expand Down Expand Up @@ -573,6 +571,19 @@ def _convert_with_generic_type(self, tctx: TraversalContext, typ: typing_compat.
)
return (None, math.inf)

def _convert_with_literal_type(self, tctx: TraversalContext, typ: typing_compat.GenericAlias, value: JSONValue) -> typing.Tuple[JsonicValue, float]: # type: ignore
possible_literals = {v for v in typing_compat.get_args(typ) if isinstance(v, (bool, int, float, str))}
if value in possible_literals:
return (value, 1.0)
else:
tctx.ctx.validation_error_occurred(
ToJsonicConverterError(
tctx.pointer,
f"value is ({json.dumps(value)}) where any of {english_enumerate((str(v) for v in possible_literals), conj=', or ')} expected",
)
)
return (None, math.inf)

def _convert_with_union(self, tctx: TraversalContext, typ: typing_compat.GenericAlias, value: JSONValue) -> typing.Iterable[typing.Tuple[JsonicValue, float]]: # type: ignore
args = typing_compat.get_args(typ)
if len(args) == 2 and None.__class__ in args:
Expand Down Expand Up @@ -602,7 +613,6 @@ def _convert_with_typeddict(self, tctx: TraversalContext, typ: typing._TypedDict
entries: typing.List[typing.Tuple[str, JsonicValue]] = []
confidence = 1.0
for n, vtyp in self._get_type_hints(tctx, typ).items():
v: JsonicValue
name_mapper = self._lookup_name_mapper(typ)
k: str
if name_mapper is None:
Expand Down Expand Up @@ -644,6 +654,7 @@ def _convert_with_dataclass(
)
)
return (None, math.inf)
assert dataclasses.is_dataclass(typ)
attrs: typing.MutableMapping[str, JsonicValue] = {}
confidence = 1.0
for field in dataclasses.fields(typ):
Expand Down Expand Up @@ -813,6 +824,8 @@ def _convert_inner(
)
return (None, math.inf)
return candidates[0]
elif typing_compat.is_literal_type(typ):
return self._convert_with_literal_type(tctx, typ, value)
elif typing_compat.is_generic_type(typ):
return self._convert_with_generic_type(tctx, typ, value)
elif isinstance(typ, typing._SpecialForm):
Expand Down Expand Up @@ -841,8 +854,9 @@ def convert(self, ctx: ConverterContext, typ: typing.Union[typing_compat.Generic
... # pragma: nocover

@typing.overload
def convert(self, ctx: ConverterContext, typ: typing.Type[T], value: JSONValue) -> T:
... # pragma: nocover
def convert(
self, ctx: ConverterContext, typ: typing.Type[T], value: JSONValue
) -> T: ... # pragma: nocover

def convert(self, ctx: ConverterContext, typ: JsonicType, value: JSONValue) -> typing.Any:
pair = self._convert(TraversalContext(ctx, JSONPointer(), typ), typ, value)
Expand All @@ -853,8 +867,7 @@ def __call__(self, typ: typing.Union[typing_compat.GenericAlias, typing._Special
... # pragma: nocover

@typing.overload
def __call__(self, typ: typing.Type[T], value: JSONValue) -> T:
... # pragma: nocover
def __call__(self, typ: typing.Type[T], value: JSONValue) -> T: ... # pragma: nocover

def __call__(self, typ: JsonicType, value: JSONValue) -> typing.Any:
return self.convert(DefaultConverterContext(), typ, value)
Expand Down
11 changes: 11 additions & 0 deletions src/jsonie/typing_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,14 @@ def is_generic_type(typ: typing.Any) -> bool:

def is_genuine_type(typ: typing.Union[GenericAlias, UnionType, typing.Type]) -> bool:
return not isinstance(typ, _generic_alias_types) and not isinstance(typ, typing._SpecialForm) # type: ignore


if hasattr(typing, "Literal"):

def is_literal_type(typ: GenericAlias) -> bool:
return get_origin(typ) is typing.Literal

else:

def is_literal_type(typ: GenericAlias) -> bool:
return False
6 changes: 2 additions & 4 deletions src/jsonie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ def cause(ex: Tex, cause: Exception) -> Tex:


class DateConstructorProtocol(typing.Protocol):
def __call__(self, year: int, month: int, day: int):
... # pragma: nocover
def __call__(self, year: int, month: int, day: int): ... # pragma: nocover


Td = typing.TypeVar("Td") # Td implements DateConstructorProtocol
Expand All @@ -38,8 +37,7 @@ def __call__(
second: int,
microsecond: int,
tzinfo: typing.Optional[datetime.tzinfo],
):
...
): ...


Tdt = typing.TypeVar("Tdt") # Tdt implements DateTimeConstructorProtocol
Expand Down

0 comments on commit 0ef2796

Please sign in to comment.