Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parametrized NA sentinel for factorize #20473

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pandas/_libs/hashtable.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ cdef class Factorizer:
return self.count

def factorize(self, ndarray[object] values, sort=False, na_sentinel=-1,
check_null=True):
na_value=None):
"""
Factorize values with nans replaced by na_sentinel
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
Expand All @@ -81,7 +81,7 @@ cdef class Factorizer:
uniques.extend(self.uniques.to_array())
self.uniques = uniques
labels = self.table.get_labels(values, self.uniques,
self.count, na_sentinel, check_null)
self.count, na_sentinel, na_value)
mask = (labels == na_sentinel)
# sort on
if sort:
Expand Down Expand Up @@ -114,7 +114,7 @@ cdef class Int64Factorizer:
return self.count

def factorize(self, int64_t[:] values, sort=False,
na_sentinel=-1, check_null=True):
na_sentinel=-1, na_value=None):
"""
Factorize values with nans replaced by na_sentinel
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
Expand All @@ -126,7 +126,7 @@ cdef class Int64Factorizer:
self.uniques = uniques
labels = self.table.get_labels(values, self.uniques,
self.count, na_sentinel,
check_null)
na_value=na_value)

# sort on
if sort:
Expand Down
48 changes: 33 additions & 15 deletions pandas/_libs/hashtable_class_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,13 @@ cdef class HashTable:

{{py:

# name, dtype, null_condition, float_group
dtypes = [('Float64', 'float64', 'val != val', True),
('UInt64', 'uint64', 'False', False),
('Int64', 'int64', 'val == iNaT', False)]
# name, dtype, float_group, default_na_value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when is float_group used? seems superfluous?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to be used in unique.

dtypes = [('Float64', 'float64', True, 'nan'),
('UInt64', 'uint64', False, 0),
('Int64', 'int64', False, 'iNaT')]

def get_dispatch(dtypes):
for (name, dtype, null_condition, float_group) in dtypes:
for (name, dtype, float_group, default_na_value) in dtypes:
unique_template = """\
cdef:
Py_ssize_t i, n = len(values)
Expand Down Expand Up @@ -298,13 +298,13 @@ def get_dispatch(dtypes):
return uniques.to_array()
"""

unique_template = unique_template.format(name=name, dtype=dtype, null_condition=null_condition, float_group=float_group)
unique_template = unique_template.format(name=name, dtype=dtype, float_group=float_group)

yield (name, dtype, null_condition, float_group, unique_template)
yield (name, dtype, float_group, default_na_value, unique_template)
}}


{{for name, dtype, null_condition, float_group, unique_template in get_dispatch(dtypes)}}
{{for name, dtype, float_group, default_na_value, unique_template in get_dispatch(dtypes)}}

cdef class {{name}}HashTable(HashTable):

Expand Down Expand Up @@ -408,24 +408,36 @@ cdef class {{name}}HashTable(HashTable):
@cython.boundscheck(False)
def get_labels(self, {{dtype}}_t[:] values, {{name}}Vector uniques,
Py_ssize_t count_prior, Py_ssize_t na_sentinel,
bint check_null=True):
object na_value=None):
cdef:
Py_ssize_t i, n = len(values)
int64_t[:] labels
Py_ssize_t idx, count = count_prior
int ret = 0
{{dtype}}_t val
{{dtype}}_t val, na_value2
khiter_t k
{{name}}VectorData *ud
bint use_na_value

labels = np.empty(n, dtype=np.int64)
ud = uniques.data
use_na_value = na_value is not None

if use_na_value:
# We need this na_value2 because we want to allow users
# to *optionally* specify an NA sentinel *of the correct* type.
# We use None, to make it optional, which requires `object` type
# for the parameter. To please the compiler, we use na_value2,
# which is only used if it's *specified*.
na_value2 = <{{dtype}}_t>na_value
else:
na_value2 = {{default_na_value}}

with nogil:
for i in range(n):
val = values[i]

if check_null and {{null_condition}}:
if val != val or (use_na_value and val == na_value2):
labels[i] = na_sentinel
continue

Expand Down Expand Up @@ -695,7 +707,7 @@ cdef class StringHashTable(HashTable):
@cython.boundscheck(False)
def get_labels(self, ndarray[object] values, ObjectVector uniques,
Py_ssize_t count_prior, int64_t na_sentinel,
bint check_null=1):
object na_value=None):
cdef:
Py_ssize_t i, n = len(values)
int64_t[:] labels
Expand All @@ -706,18 +718,21 @@ cdef class StringHashTable(HashTable):
char *v
char **vecs
khiter_t k
bint use_na_value

# these by-definition *must* be strings
labels = np.zeros(n, dtype=np.int64)
uindexer = np.empty(n, dtype=np.int64)
use_na_value = na_value is not None

# pre-filter out missing
# and assign pointers
vecs = <char **> malloc(n * sizeof(char *))
for i in range(n):
val = values[i]

if PyUnicode_Check(val) or PyString_Check(val):
if ((PyUnicode_Check(val) or PyString_Check(val)) and
not (use_na_value and val == na_value)):
v = util.get_c_string(val)
vecs[i] = v
else:
Expand Down Expand Up @@ -868,22 +883,25 @@ cdef class PyObjectHashTable(HashTable):

def get_labels(self, ndarray[object] values, ObjectVector uniques,
Py_ssize_t count_prior, int64_t na_sentinel,
bint check_null=True):
object na_value=None):
cdef:
Py_ssize_t i, n = len(values)
int64_t[:] labels
Py_ssize_t idx, count = count_prior
int ret = 0
object val
khiter_t k
bint use_na_value

labels = np.empty(n, dtype=np.int64)
use_na_value = na_value is not None

for i in range(n):
val = values[i]
hash(val)

if check_null and val != val or val is None:
if ((val != val or val is None) or
(use_na_value and val == na_value)):
labels[i] = na_sentinel
continue

Expand Down
29 changes: 21 additions & 8 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
_ensure_float64, _ensure_uint64,
_ensure_int64)
from pandas.compat.numpy import _np_version_under1p10
from pandas.core.dtypes.missing import isna
from pandas.core.dtypes.missing import isna, na_value_for_dtype

from pandas.core import common as com
from pandas._libs import algos, lib, hashtable as htable
Expand Down Expand Up @@ -435,19 +435,23 @@ def isin(comps, values):
return f(comps, values)


def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None):
def _factorize_array(values, na_sentinel=-1, size_hint=None,
na_value=None):
"""Factorize an array-like to labels and uniques.

This doesn't do any coercion of types or unboxing before factorization.

Parameters
----------
values : ndarray
check_nulls : bool
Whether to check for nulls in the hashtable's 'get_labels' method.
na_sentinel : int, default -1
size_hint : int, optional
Passsed through to the hashtable's 'get_labels' method
na_value : object, optional
A value in `values` to consider missing. Note: only use this
parameter when you know that you don't have any values pandas would
consider missing in the array (NaN for float data, iNaT for
datetimes, etc.).

Returns
-------
Expand All @@ -457,7 +461,8 @@ def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None):

table = hash_klass(size_hint or len(values))
uniques = vec_klass()
labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls)
labels = table.get_labels(values, uniques, 0, na_sentinel,
na_value=na_value)

labels = _ensure_platform_int(labels)
uniques = uniques.to_array()
Expand Down Expand Up @@ -508,10 +513,18 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
dtype = original.dtype
else:
values, dtype, _ = _ensure_data(values)
check_nulls = not is_integer_dtype(original)
labels, uniques = _factorize_array(values, check_nulls,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use pandas.core.dtypes.missing.na_value_for_dtype (or maybe add a kwarg to return the underlying value). just want to try to keep this logic in 1 place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why na_value_for_dtype(PeriodDtype) is nan? It should be NaT, right?

if (is_datetime64_any_dtype(original) or
is_timedelta64_dtype(original) or
is_period_dtype(original)):
na_value = na_value_for_dtype(original.dtype)
else:
na_value = None

labels, uniques = _factorize_array(values,
na_sentinel=na_sentinel,
size_hint=size_hint)
size_hint=size_hint,
na_value=na_value)

if sort and len(uniques) > 0:
from pandas.core.sorting import safe_sort
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pandas import compat
from pandas.compat import u, lzip
from pandas._libs import lib, algos as libalgos
from pandas._libs.tslib import iNaT

from pandas.core.dtypes.generic import (
ABCSeries, ABCIndexClass, ABCCategoricalIndex)
Expand Down Expand Up @@ -2163,11 +2162,10 @@ def factorize(self, na_sentinel=-1):
from pandas.core.algorithms import _factorize_array

codes = self.codes.astype('int64')
codes[codes == -1] = iNaT
# We set missing codes, normally -1, to iNaT so that the
# Int64HashTable treats them as missing values.
labels, uniques = _factorize_array(codes, check_nulls=True,
na_sentinel=na_sentinel)
labels, uniques = _factorize_array(codes, na_sentinel=na_sentinel,
na_value=-1)
uniques = self._constructor(self.categories.take(uniques),
categories=self.categories,
ordered=self.ordered)
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
is_datetimelike_v_numeric, is_float_dtype,
is_datetime64_dtype, is_datetime64tz_dtype,
is_timedelta64_dtype, is_interval_dtype,
is_period_dtype,
is_complex_dtype,
is_string_like_dtype, is_bool_dtype,
is_integer_dtype, is_dtype_equal,
Expand Down Expand Up @@ -393,7 +394,7 @@ def na_value_for_dtype(dtype, compat=True):
dtype = pandas_dtype(dtype)

if (is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype) or
is_timedelta64_dtype(dtype)):
is_timedelta64_dtype(dtype) or is_period_dtype(dtype)):
return NaT
elif is_float_dtype(dtype):
return np.nan
Expand Down
41 changes: 23 additions & 18 deletions pandas/tests/dtypes/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from pandas import (NaT, Float64Index, Series,
DatetimeIndex, TimedeltaIndex, date_range)
from pandas.core.dtypes.common import is_scalar
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype, PeriodDtype, IntervalDtype)
from pandas.core.dtypes.missing import (
array_equivalent, isna, notna, isnull, notnull,
na_value_for_dtype)
Expand Down Expand Up @@ -311,23 +312,27 @@ def test_array_equivalent_str():
np.array(['A', 'X'], dtype=dtype))


def test_na_value_for_dtype():
for dtype in [np.dtype('M8[ns]'), np.dtype('m8[ns]'),
DatetimeTZDtype('datetime64[ns, US/Eastern]')]:
assert na_value_for_dtype(dtype) is NaT

for dtype in ['u1', 'u2', 'u4', 'u8',
'i1', 'i2', 'i4', 'i8']:
assert na_value_for_dtype(np.dtype(dtype)) == 0

for dtype in ['bool']:
assert na_value_for_dtype(np.dtype(dtype)) is False

for dtype in ['f2', 'f4', 'f8']:
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))

for dtype in ['O']:
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))
@pytest.mark.parametrize('dtype, na_value', [
# Datetime-like
(np.dtype("M8[ns]"), NaT),
(np.dtype("m8[ns]"), NaT),
(DatetimeTZDtype('datetime64[ns, US/Eastern]'), NaT),
(PeriodDtype("M"), NaT),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

# Integer
('u1', 0), ('u2', 0), ('u4', 0), ('u8', 0),
('i1', 0), ('i2', 0), ('i4', 0), ('i8', 0),
# Bool
('bool', False),
# Float
('f2', np.nan), ('f4', np.nan), ('f8', np.nan),
# Object
('O', np.nan),
# Interval
(IntervalDtype(), np.nan),
])
def test_na_value_for_dtype(dtype, na_value):
result = na_value_for_dtype(dtype)
assert result is na_value


class TestNAObj(object):
Expand Down
30 changes: 30 additions & 0 deletions pandas/tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,36 @@ def test_deprecate_order(self):
with tm.assert_produces_warning(False):
algos.factorize(data)

@pytest.mark.parametrize('data', [
np.array([0, 1, 0], dtype='u8'),
np.array([-2**63, 1, -2**63], dtype='i8'),
np.array(['__nan__', 'foo', '__nan__'], dtype='object'),
])
def test_parametrized_factorize_na_value_default(self, data):
# arrays that include the NA default for that type, but isn't used.
l, u = algos.factorize(data)
expected_uniques = data[[0, 1]]
expected_labels = np.array([0, 1, 0], dtype='i8')
tm.assert_numpy_array_equal(l, expected_labels)
tm.assert_numpy_array_equal(u, expected_uniques)

@pytest.mark.parametrize('data, na_value', [
(np.array([0, 1, 0, 2], dtype='u8'), 0),
(np.array([1, 0, 1, 2], dtype='u8'), 1),
(np.array([-2**63, 1, -2**63, 0], dtype='i8'), -2**63),
(np.array([1, -2**63, 1, 0], dtype='i8'), 1),
(np.array(['a', '', 'a', 'b'], dtype=object), 'a'),
(np.array([(), ('a', 1), (), ('a', 2)], dtype=object), ()),
(np.array([('a', 1), (), ('a', 1), ('a', 2)], dtype=object),
('a', 1)),
])
def test_parametrized_factorize_na_value(self, data, na_value):
l, u = algos._factorize_array(data, na_value=na_value)
expected_uniques = data[[1, 3]]
expected_labels = np.array([-1, 0, -1, 1], dtype='i8')
tm.assert_numpy_array_equal(l, expected_labels)
tm.assert_numpy_array_equal(u, expected_uniques)


class TestUnique(object):

Expand Down