Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLN: handle EAs and fast path (no bounds checking) in safe_sort #25696

Merged
merged 17 commits into from
May 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
7356997
BUG: fix usage of na_sentinel with sort=True in factorize()
jorisvandenbossche Mar 7, 2019
e1ab3a4
fix dtype
jorisvandenbossche Mar 11, 2019
a9c880e
Merge remote-tracking branch 'upstream/master' into factorize-na-sent…
jorisvandenbossche Mar 11, 2019
db30797
Merge remote-tracking branch 'upstream/master' into factorize-na-sent…
jorisvandenbossche Mar 12, 2019
ba944eb
Attempt to include it in safe_sort
jorisvandenbossche Mar 12, 2019
c6203cb
Merge remote-tracking branch 'upstream/master' into factorize-na-sent…
jorisvandenbossche Mar 12, 2019
d70b447
Merge remote-tracking branch 'upstream/master' into factorize-na-sent…
jorisvandenbossche Apr 5, 2019
fdf330a
feedback Jeff
jorisvandenbossche Apr 5, 2019
b08ea6d
add tests for safe_sort
jorisvandenbossche Apr 5, 2019
9de26fc
additional test for other na_sentinel in case of out of bound indices
jorisvandenbossche Apr 5, 2019
bcb8c7e
additional test for EA with custom na_sentinel
jorisvandenbossche Apr 5, 2019
13f6706
update factorize test for EAs with custom na_sentinel (which now work…
jorisvandenbossche Apr 5, 2019
8db84e7
Merge remote-tracking branch 'upstream/master' into factorize-na-sent…
jorisvandenbossche Apr 5, 2019
d0cef9e
Merge remote-tracking branch 'upstream/master' into factorize-na-sent…
jorisvandenbossche Apr 11, 2019
5157e89
Merge remote-tracking branch 'upstream/master' into factorize-na-sent…
jorisvandenbossche May 6, 2019
e350641
linting
jorisvandenbossche May 6, 2019
151aa6a
more linting
jorisvandenbossche May 6, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.25.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ Other
^^^^^

- Removed unused C functions from vendored UltraJSON implementation (:issue:`26198`)
- Bug in :func:`factorize` when passing an ``ExtensionArray`` with a custom ``na_sentinel`` (:issue:`25696`).


.. _whatsnew_0.250.contributors:
Expand Down
18 changes: 2 additions & 16 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,22 +617,8 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):

if sort and len(uniques) > 0:
from pandas.core.sorting import safe_sort
if na_sentinel == -1:
# GH-25409 take_1d only works for na_sentinels of -1
try:
order = uniques.argsort()
order2 = order.argsort()
labels = take_1d(order2, labels, fill_value=na_sentinel)
uniques = uniques.take(order)
except TypeError:
# Mixed types, where uniques.argsort fails.
uniques, labels = safe_sort(uniques, labels,
na_sentinel=na_sentinel,
assume_unique=True)
else:
uniques, labels = safe_sort(uniques, labels,
na_sentinel=na_sentinel,
assume_unique=True)
uniques, labels = safe_sort(uniques, labels, na_sentinel=na_sentinel,
assume_unique=True, verify=False)

uniques = _reconstruct_data(uniques, dtype, original)

Expand Down
50 changes: 35 additions & 15 deletions pandas/core/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from pandas.core.dtypes.cast import infer_dtype_from_array
from pandas.core.dtypes.common import (
ensure_int64, ensure_platform_int, is_categorical_dtype, is_list_like)
ensure_int64, ensure_platform_int, is_categorical_dtype,
is_extension_array_dtype, is_list_like)
from pandas.core.dtypes.missing import isna

import pandas.core.algorithms as algorithms
Expand Down Expand Up @@ -403,7 +404,8 @@ def _reorder_by_uniques(uniques, labels):
return uniques, labels


def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False,
verify=True):
"""
Sort ``values`` and reorder corresponding ``labels``.
``values`` should be unique if ``labels`` is not None.
Expand All @@ -424,6 +426,12 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
assume_unique : bool, default False
When True, ``values`` are assumed to be unique, which can speed up
the calculation. Ignored when ``labels`` is None.
verify : bool, default True
Check whether labels are out of bounds for ``values`` and set any
out-of-bounds labels to ``na_sentinel``. If ``verify=False``, it is
assumed there are no out-of-bounds labels. Ignored when ``labels`` is None.

.. versionadded:: 0.25.0

Returns
-------
Expand All @@ -445,8 +453,8 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
# Adjacent string literals are concatenated verbatim, so the first piece
# must end with a space to keep "to" and "safe_sort" apart in the message.
raise TypeError("Only list-like objects are allowed to be passed to "
                "safe_sort as values")

if not isinstance(values, np.ndarray):

if (not isinstance(values, np.ndarray)
and not is_extension_array_dtype(values)):
# don't convert to string types
dtype, _ = infer_dtype_from_array(values)
values = np.asarray(values, dtype=dtype)
Expand All @@ -460,7 +468,8 @@ def sort_mixed(values):
return np.concatenate([nums, np.asarray(strs, dtype=object)])

sorter = None
if lib.infer_dtype(values, skipna=False) == 'mixed-integer':
if (not is_extension_array_dtype(values)
and lib.infer_dtype(values, skipna=False) == 'mixed-integer'):
# unorderable in py3 if mixed str/int
ordered = sort_mixed(values)
else:
Expand Down Expand Up @@ -493,15 +502,26 @@ def sort_mixed(values):
t.map_locations(values)
sorter = ensure_platform_int(t.lookup(ordered))

reverse_indexer = np.empty(len(sorter), dtype=np.int_)
reverse_indexer.put(sorter, np.arange(len(sorter)))

mask = (labels < -len(values)) | (labels >= len(values)) | \
(labels == na_sentinel)

# (Out of bound indices will be masked with `na_sentinel` next, so we may
# deal with them here without performance loss using `mode='wrap'`.)
new_labels = reverse_indexer.take(labels, mode='wrap')
np.putmask(new_labels, mask, na_sentinel)
if na_sentinel == -1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would rather just fix take_1d

# take_1d is faster, but only works for na_sentinels of -1
order2 = sorter.argsort()
new_labels = algorithms.take_1d(order2, labels, fill_value=-1)
if verify:
mask = (labels < -len(values)) | (labels >= len(values))
else:
mask = None
else:
reverse_indexer = np.empty(len(sorter), dtype=np.int_)
reverse_indexer.put(sorter, np.arange(len(sorter)))
# Out of bound indices will be masked with `na_sentinel` next, so we
# may deal with them here without performance loss using `mode='wrap'`
new_labels = reverse_indexer.take(labels, mode='wrap')

mask = labels == na_sentinel
if verify:
mask = mask | (labels < -len(values)) | (labels >= len(values))

if mask is not None:
np.putmask(new_labels, mask, na_sentinel)

return ordered, ensure_platform_int(new_labels)
19 changes: 14 additions & 5 deletions pandas/tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pandas.core.algorithms as algos
from pandas.core.arrays import DatetimeArray
import pandas.core.common as com
from pandas.core.sorting import safe_sort
import pandas.util.testing as tm
from pandas.util.testing import assert_almost_equal

Expand Down Expand Up @@ -325,18 +326,26 @@ def test_parametrized_factorize_na_value(self, data, na_value):

@pytest.mark.parametrize('sort', [True, False])
@pytest.mark.parametrize('na_sentinel', [-1, -10, 100])
def test_factorize_na_sentinel(self, sort, na_sentinel):
data = np.array(['b', 'a', None, 'b'], dtype=object)
@pytest.mark.parametrize('data, uniques', [
(np.array(['b', 'a', None, 'b'], dtype=object),
np.array(['b', 'a'], dtype=object)),
(pd.array([2, 1, np.nan, 2], dtype='Int64'),
pd.array([2, 1], dtype='Int64'))],
ids=['numpy_array', 'extension_array'])
def test_factorize_na_sentinel(self, sort, na_sentinel, data, uniques):
labels, uniques = algos.factorize(data, sort=sort,
na_sentinel=na_sentinel)
if sort:
expected_labels = np.array([1, 0, na_sentinel, 1], dtype=np.intp)
expected_uniques = np.array(['a', 'b'], dtype=object)
expected_uniques = safe_sort(uniques)
else:
expected_labels = np.array([0, 1, na_sentinel, 0], dtype=np.intp)
expected_uniques = np.array(['b', 'a'], dtype=object)
expected_uniques = uniques
tm.assert_numpy_array_equal(labels, expected_labels)
tm.assert_numpy_array_equal(uniques, expected_uniques)
if isinstance(data, np.ndarray):
tm.assert_numpy_array_equal(uniques, expected_uniques)
else:
tm.assert_extension_array_equal(uniques, expected_uniques)


class TestUnique:
Expand Down
53 changes: 41 additions & 12 deletions pandas/tests/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from numpy import nan
import pytest

from pandas import DataFrame, MultiIndex, Series, concat, merge, to_datetime
from pandas import (
DataFrame, MultiIndex, Series, array, concat, merge, to_datetime)
from pandas.core import common as com
from pandas.core.sorting import (
decons_group_index, get_group_index, is_int64_overflow_possible,
Expand Down Expand Up @@ -358,34 +359,43 @@ def test_basic_sort(self):
expected = np.array([])
tm.assert_numpy_array_equal(result, expected)

def test_labels(self):
@pytest.mark.parametrize('verify', [True, False])
def test_labels(self, verify):
values = [3, 1, 2, 0, 4]
expected = np.array([0, 1, 2, 3, 4])

labels = [0, 1, 1, 2, 3, 0, -1, 4]
result, result_labels = safe_sort(values, labels)
result, result_labels = safe_sort(values, labels, verify=verify)
expected_labels = np.array([3, 1, 1, 2, 0, 3, -1, 4], dtype=np.intp)
tm.assert_numpy_array_equal(result, expected)
tm.assert_numpy_array_equal(result_labels, expected_labels)

# na_sentinel
labels = [0, 1, 1, 2, 3, 0, 99, 4]
result, result_labels = safe_sort(values, labels,
na_sentinel=99)
result, result_labels = safe_sort(values, labels, na_sentinel=99,
verify=verify)
expected_labels = np.array([3, 1, 1, 2, 0, 3, 99, 4], dtype=np.intp)
tm.assert_numpy_array_equal(result, expected)
tm.assert_numpy_array_equal(result_labels, expected_labels)

# out of bound indices
labels = [0, 101, 102, 2, 3, 0, 99, 4]
result, result_labels = safe_sort(values, labels)
expected_labels = np.array([3, -1, -1, 2, 0, 3, -1, 4], dtype=np.intp)
labels = []
result, result_labels = safe_sort(values, labels, verify=verify)
expected_labels = np.array([], dtype=np.intp)
tm.assert_numpy_array_equal(result, expected)
tm.assert_numpy_array_equal(result_labels, expected_labels)

labels = []
result, result_labels = safe_sort(values, labels)
expected_labels = np.array([], dtype=np.intp)
@pytest.mark.parametrize('na_sentinel', [-1, 99])
def test_labels_out_of_bound(self, na_sentinel):
values = [3, 1, 2, 0, 4]
expected = np.array([0, 1, 2, 3, 4])

# out of bound indices
labels = [0, 101, 102, 2, 3, 0, 99, 4]
result, result_labels = safe_sort(
values, labels, na_sentinel=na_sentinel)
expected_labels = np.array(
[3, na_sentinel, na_sentinel, 2, 0, 3, na_sentinel, 4],
dtype=np.intp)
tm.assert_numpy_array_equal(result, expected)
tm.assert_numpy_array_equal(result_labels, expected_labels)

Expand Down Expand Up @@ -430,3 +440,22 @@ def test_exceptions(self):
with pytest.raises(ValueError,
match="values should be unique"):
safe_sort(values=[0, 1, 2, 1], labels=[0, 1])

def test_extension_array(self):
    """safe_sort on an Int64 extension array yields a sorted Int64 array."""
    # NOTE(review): a variant with a missing value was left disabled here
    # (input [1, 3, NaN, 2], expected [1, 2, 3, NaN]); presumably sorting
    # with NA present is not covered yet -- confirm before enabling it.
    unsorted = array([1, 3, 2], dtype='Int64')
    sorted_result = safe_sort(unsorted)
    wanted = array([1, 2, 3], dtype='Int64')
    tm.assert_extension_array_equal(sorted_result, wanted)

@pytest.mark.parametrize('verify', [True, False])
@pytest.mark.parametrize('na_sentinel', [-1, 99])
def test_extension_array_labels(self, verify, na_sentinel):
    """Labels are remapped to the sorted order of an Int64 extension array.

    The label equal to ``na_sentinel`` must come back unchanged, both with
    and without bounds verification.
    """
    values = array([1, 3, 2], dtype='Int64')
    labels_in = [0, 1, na_sentinel, 2]

    sorted_values, remapped = safe_sort(
        values, labels_in, na_sentinel=na_sentinel, verify=verify)

    # Sorted order is [1, 2, 3]: old position 1 (value 3) moves to 2 and
    # old position 2 (value 2) moves to 1.
    wanted_values = array([1, 2, 3], dtype='Int64')
    wanted_labels = np.array([0, 2, na_sentinel, 1], dtype=np.intp)
    tm.assert_extension_array_equal(sorted_values, wanted_values)
    tm.assert_numpy_array_equal(remapped, wanted_labels)