Skip to content

Commit

Permalink
support big ints in literals and enums (#1297)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored May 21, 2024
1 parent 727deee commit f04418b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ impl<'a> EitherInt<'a> {
Ok(Self::BigInt(big_int))
}
}

pub fn into_i64(self, py: Python<'a>) -> ValResult<i64> {
match self {
EitherInt::I64(i) => Ok(i),
Expand Down
14 changes: 8 additions & 6 deletions src/validators/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use core::fmt::Debug;
use std::cmp::Ordering;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use pyo3::types::{PyDict, PyInt, PyList};
use pyo3::{intern, PyTraverseError, PyVisit};

use ahash::AHashMap;
Expand Down Expand Up @@ -58,11 +58,13 @@ impl<T: Debug> LiteralLookup<T> {
expected_bool.false_id = Some(id);
}
}
if let Ok(either_int) = k.exact_int() {
let int = either_int
.into_i64(py)
.map_err(|_| py_schema_error_type!("error extracting int {:?}", k))?;
expected_int.insert(int, id);
if k.is_exact_instance_of::<PyInt>() {
if let Ok(int_64) = k.extract::<i64>() {
expected_int.insert(int_64, id);
} else {
// cover the case of an int that's > i64::MAX etc.
expected_py_dict.set_item(k, id)?;
}
} else if let Ok(either_str) = k.exact_str() {
let str = either_str
.as_cow()
Expand Down
15 changes: 14 additions & 1 deletion tests/validators/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import sys
from enum import Enum, IntFlag
from enum import Enum, IntEnum, IntFlag

import pytest

Expand Down Expand Up @@ -331,3 +331,16 @@ class MyFlags(IntFlag):

with pytest.raises(ValidationError):
v.validate_python(None)


def test_big_int():
class ColorEnum(IntEnum):
GREEN = 1 << 63
BLUE = 1 << 64

v = SchemaValidator(
core_schema.with_default_schema(schema=core_schema.enum_schema(ColorEnum, list(ColorEnum.__members__.values())))
)

assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN
assert v.validate_python(1 << 63) is ColorEnum.GREEN
11 changes: 11 additions & 0 deletions tests/validators/test_literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,14 @@ class Foo(str, Enum):
with pytest.raises(ValidationError) as exc_info:
v.validate_python('bar_val')
assert exc_info.value.errors(include_url=False) == err


def test_big_int():
big_int = 2**64 + 1
massive_int = 2**128 + 1
v = SchemaValidator(core_schema.literal_schema([big_int, massive_int]))
assert v.validate_python(big_int) == big_int
assert v.validate_python(massive_int) == massive_int
m = r'Input should be 18446744073709551617 or 340282366920938463463374607431768211457 \[type=literal_error'
with pytest.raises(ValidationError, match=m):
v.validate_python(37)

0 comments on commit f04418b

Please sign in to comment.