if not _DOCUMENTING:
    # Python types that have a dedicated PySeries constructor. The key order is
    # also the order in which `py_type_to_constructor` tries these types as
    # base classes when it is given a type that is not itself a key.
    _PY_TYPE_TO_CONSTRUCTOR = {
        float: PySeries.new_opt_f64,
        bool: PySeries.new_opt_bool,
        int: PySeries.new_opt_i64,
        str: PySeries.new_str,
        bytes: PySeries.new_binary,
        PyDecimal: PySeries.new_decimal,
    }


def py_type_to_constructor(py_type: type[Any]) -> Callable[..., PySeries]:
    """
    Get the right PySeries constructor for the given Python type.

    Parameters
    ----------
    py_type
        The Python class to find a constructor for.

    Returns
    -------
    Callable
        The constructor registered for `py_type` itself if there is one;
        otherwise the constructor of the first registered type that `py_type`
        is a subclass of; otherwise `PySeries.new_object`.
    """
    # Exact match: no need to look at the class hierarchy.
    if py_type in _PY_TYPE_TO_CONSTRUCTOR:
        return _PY_TYPE_TO_CONSTRUCTOR[py_type]

    # Subclass match: the first registered base type (in mapping order) wins.
    for registered_type, constructor in _PY_TYPE_TO_CONSTRUCTOR.items():
        if issubclass(py_type, registered_type):
            return constructor

    # Nothing matched, so use the generic object constructor.
    return PySeries.new_object
def test_init_from_subclassed_types() -> None:
    """Check that Series built from subclass instances match the base-type Series."""
    import codecs

    # Part 1: a closer look at one hand-written `str` subclass whose repr is
    # the rot13 encoding of its value.
    class SuperSecretString(str):
        def __new__(cls, value: str) -> Self:
            return super().__new__(cls, value)

        def __repr__(self) -> str:
            return codecs.encode(self, "rot_13")

    plain = "windmolen"
    secret = SuperSecretString(plain)

    assert secret == plain
    assert isinstance(secret, str)
    assert repr(secret) == "jvaqzbyra"
    assert_series_equal(pl.Series([plain, plain]), pl.Series([secret, secret]))

    # Part 2: the same comparison for a trivial subclass of each basic type.
    cases = (
        (int, 42),
        (float, 5.5),
        (bytes, b"value"),
        (str, "value"),
    )
    for BaseType, value in cases:

        class SubclassedType(BaseType):  # type: ignore[misc,valid-type]
            def __new__(cls, value: Any) -> Self:
                return super().__new__(cls, value)  # type: ignore[no-any-return]

        from_base = pl.Series([value]).to_list()
        from_subclass = pl.Series([SubclassedType(value)]).to_list()
        assert from_base == from_subclass