Skip to content

Commit 1b1b99b

Browse files
TomAugspurgerjreback
authored andcommitted
Parametrized NA sentinel for factorize (#20473)
1 parent 2179302 commit 1b1b99b

File tree

7 files changed

+115
-50
lines changed

7 files changed

+115
-50
lines changed

pandas/_libs/hashtable.pyx

+4-4
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ cdef class Factorizer:
7070
return self.count
7171

7272
def factorize(self, ndarray[object] values, sort=False, na_sentinel=-1,
73-
check_null=True):
73+
na_value=None):
7474
"""
7575
Factorize values with nans replaced by na_sentinel
7676
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
@@ -81,7 +81,7 @@ cdef class Factorizer:
8181
uniques.extend(self.uniques.to_array())
8282
self.uniques = uniques
8383
labels = self.table.get_labels(values, self.uniques,
84-
self.count, na_sentinel, check_null)
84+
self.count, na_sentinel, na_value)
8585
mask = (labels == na_sentinel)
8686
# sort on
8787
if sort:
@@ -114,7 +114,7 @@ cdef class Int64Factorizer:
114114
return self.count
115115

116116
def factorize(self, int64_t[:] values, sort=False,
117-
na_sentinel=-1, check_null=True):
117+
na_sentinel=-1, na_value=None):
118118
"""
119119
Factorize values with nans replaced by na_sentinel
120120
>>> factorize(np.array([1,2,np.nan], dtype='O'), na_sentinel=20)
@@ -126,7 +126,7 @@ cdef class Int64Factorizer:
126126
self.uniques = uniques
127127
labels = self.table.get_labels(values, self.uniques,
128128
self.count, na_sentinel,
129-
check_null)
129+
na_value=na_value)
130130

131131
# sort on
132132
if sort:

pandas/_libs/hashtable_class_helper.pxi.in

+33-15
Original file line numberDiff line numberDiff line change
@@ -250,13 +250,13 @@ cdef class HashTable:
250250

251251
{{py:
252252

253-
# name, dtype, null_condition, float_group
254-
dtypes = [('Float64', 'float64', 'val != val', True),
255-
('UInt64', 'uint64', 'False', False),
256-
('Int64', 'int64', 'val == iNaT', False)]
253+
# name, dtype, float_group, default_na_value
254+
dtypes = [('Float64', 'float64', True, 'nan'),
255+
('UInt64', 'uint64', False, 0),
256+
('Int64', 'int64', False, 'iNaT')]
257257

258258
def get_dispatch(dtypes):
259-
for (name, dtype, null_condition, float_group) in dtypes:
259+
for (name, dtype, float_group, default_na_value) in dtypes:
260260
unique_template = """\
261261
cdef:
262262
Py_ssize_t i, n = len(values)
@@ -298,13 +298,13 @@ def get_dispatch(dtypes):
298298
return uniques.to_array()
299299
"""
300300

301-
unique_template = unique_template.format(name=name, dtype=dtype, null_condition=null_condition, float_group=float_group)
301+
unique_template = unique_template.format(name=name, dtype=dtype, float_group=float_group)
302302

303-
yield (name, dtype, null_condition, float_group, unique_template)
303+
yield (name, dtype, float_group, default_na_value, unique_template)
304304
}}
305305

306306

307-
{{for name, dtype, null_condition, float_group, unique_template in get_dispatch(dtypes)}}
307+
{{for name, dtype, float_group, default_na_value, unique_template in get_dispatch(dtypes)}}
308308

309309
cdef class {{name}}HashTable(HashTable):
310310

@@ -408,24 +408,36 @@ cdef class {{name}}HashTable(HashTable):
408408
@cython.boundscheck(False)
409409
def get_labels(self, {{dtype}}_t[:] values, {{name}}Vector uniques,
410410
Py_ssize_t count_prior, Py_ssize_t na_sentinel,
411-
bint check_null=True):
411+
object na_value=None):
412412
cdef:
413413
Py_ssize_t i, n = len(values)
414414
int64_t[:] labels
415415
Py_ssize_t idx, count = count_prior
416416
int ret = 0
417-
{{dtype}}_t val
417+
{{dtype}}_t val, na_value2
418418
khiter_t k
419419
{{name}}VectorData *ud
420+
bint use_na_value
420421

421422
labels = np.empty(n, dtype=np.int64)
422423
ud = uniques.data
424+
use_na_value = na_value is not None
425+
426+
if use_na_value:
427+
# We need this na_value2 because we want to allow users
428+
# to *optionally* specify an NA sentinel *of the correct* type.
429+
# We use None, to make it optional, which requires `object` type
430+
# for the parameter. To please the compiler, we use na_value2,
431+
# which is only used if it's *specified*.
432+
na_value2 = <{{dtype}}_t>na_value
433+
else:
434+
na_value2 = {{default_na_value}}
423435

424436
with nogil:
425437
for i in range(n):
426438
val = values[i]
427439

428-
if check_null and {{null_condition}}:
440+
if val != val or (use_na_value and val == na_value2):
429441
labels[i] = na_sentinel
430442
continue
431443

@@ -695,7 +707,7 @@ cdef class StringHashTable(HashTable):
695707
@cython.boundscheck(False)
696708
def get_labels(self, ndarray[object] values, ObjectVector uniques,
697709
Py_ssize_t count_prior, int64_t na_sentinel,
698-
bint check_null=1):
710+
object na_value=None):
699711
cdef:
700712
Py_ssize_t i, n = len(values)
701713
int64_t[:] labels
@@ -706,18 +718,21 @@ cdef class StringHashTable(HashTable):
706718
char *v
707719
char **vecs
708720
khiter_t k
721+
bint use_na_value
709722

710723
# these by-definition *must* be strings
711724
labels = np.zeros(n, dtype=np.int64)
712725
uindexer = np.empty(n, dtype=np.int64)
726+
use_na_value = na_value is not None
713727

714728
# pre-filter out missing
715729
# and assign pointers
716730
vecs = <char **> malloc(n * sizeof(char *))
717731
for i in range(n):
718732
val = values[i]
719733

720-
if PyUnicode_Check(val) or PyString_Check(val):
734+
if ((PyUnicode_Check(val) or PyString_Check(val)) and
735+
not (use_na_value and val == na_value)):
721736
v = util.get_c_string(val)
722737
vecs[i] = v
723738
else:
@@ -868,22 +883,25 @@ cdef class PyObjectHashTable(HashTable):
868883

869884
def get_labels(self, ndarray[object] values, ObjectVector uniques,
870885
Py_ssize_t count_prior, int64_t na_sentinel,
871-
bint check_null=True):
886+
object na_value=None):
872887
cdef:
873888
Py_ssize_t i, n = len(values)
874889
int64_t[:] labels
875890
Py_ssize_t idx, count = count_prior
876891
int ret = 0
877892
object val
878893
khiter_t k
894+
bint use_na_value
879895

880896
labels = np.empty(n, dtype=np.int64)
897+
use_na_value = na_value is not None
881898

882899
for i in range(n):
883900
val = values[i]
884901
hash(val)
885902

886-
if check_null and val != val or val is None:
903+
if ((val != val or val is None) or
904+
(use_na_value and val == na_value)):
887905
labels[i] = na_sentinel
888906
continue
889907

pandas/core/algorithms.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
_ensure_float64, _ensure_uint64,
3030
_ensure_int64)
3131
from pandas.compat.numpy import _np_version_under1p10
32-
from pandas.core.dtypes.missing import isna
32+
from pandas.core.dtypes.missing import isna, na_value_for_dtype
3333

3434
from pandas.core import common as com
3535
from pandas._libs import algos, lib, hashtable as htable
@@ -435,19 +435,23 @@ def isin(comps, values):
435435
return f(comps, values)
436436

437437

438-
def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None):
438+
def _factorize_array(values, na_sentinel=-1, size_hint=None,
439+
na_value=None):
439440
"""Factorize an array-like to labels and uniques.
440441
441442
This doesn't do any coercion of types or unboxing before factorization.
442443
443444
Parameters
444445
----------
445446
values : ndarray
446-
check_nulls : bool
447-
Whether to check for nulls in the hashtable's 'get_labels' method.
448447
na_sentinel : int, default -1
449448
size_hint : int, optional
450449
Passsed through to the hashtable's 'get_labels' method
450+
na_value : object, optional
451+
A value in `values` to consider missing. Note: only use this
452+
parameter when you know that you don't have any values pandas would
453+
consider missing in the array (NaN for float data, iNaT for
454+
datetimes, etc.).
451455
452456
Returns
453457
-------
@@ -457,7 +461,8 @@ def _factorize_array(values, check_nulls, na_sentinel=-1, size_hint=None):
457461

458462
table = hash_klass(size_hint or len(values))
459463
uniques = vec_klass()
460-
labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls)
464+
labels = table.get_labels(values, uniques, 0, na_sentinel,
465+
na_value=na_value)
461466

462467
labels = _ensure_platform_int(labels)
463468
uniques = uniques.to_array()
@@ -508,10 +513,18 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
508513
dtype = original.dtype
509514
else:
510515
values, dtype, _ = _ensure_data(values)
511-
check_nulls = not is_integer_dtype(original)
512-
labels, uniques = _factorize_array(values, check_nulls,
516+
517+
if (is_datetime64_any_dtype(original) or
518+
is_timedelta64_dtype(original) or
519+
is_period_dtype(original)):
520+
na_value = na_value_for_dtype(original.dtype)
521+
else:
522+
na_value = None
523+
524+
labels, uniques = _factorize_array(values,
513525
na_sentinel=na_sentinel,
514-
size_hint=size_hint)
526+
size_hint=size_hint,
527+
na_value=na_value)
515528

516529
if sort and len(uniques) > 0:
517530
from pandas.core.sorting import safe_sort

pandas/core/arrays/categorical.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pandas import compat
88
from pandas.compat import u, lzip
99
from pandas._libs import lib, algos as libalgos
10-
from pandas._libs.tslib import iNaT
1110

1211
from pandas.core.dtypes.generic import (
1312
ABCSeries, ABCIndexClass, ABCCategoricalIndex)
@@ -2163,11 +2162,10 @@ def factorize(self, na_sentinel=-1):
21632162
from pandas.core.algorithms import _factorize_array
21642163

21652164
codes = self.codes.astype('int64')
2166-
codes[codes == -1] = iNaT
21672165
# We set missing codes, normally -1, to iNaT so that the
21682166
# Int64HashTable treats them as missing values.
2169-
labels, uniques = _factorize_array(codes, check_nulls=True,
2170-
na_sentinel=na_sentinel)
2167+
labels, uniques = _factorize_array(codes, na_sentinel=na_sentinel,
2168+
na_value=-1)
21712169
uniques = self._constructor(self.categories.take(uniques),
21722170
categories=self.categories,
21732171
ordered=self.ordered)

pandas/core/dtypes/missing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
is_datetimelike_v_numeric, is_float_dtype,
1212
is_datetime64_dtype, is_datetime64tz_dtype,
1313
is_timedelta64_dtype, is_interval_dtype,
14+
is_period_dtype,
1415
is_complex_dtype,
1516
is_string_like_dtype, is_bool_dtype,
1617
is_integer_dtype, is_dtype_equal,
@@ -502,7 +503,7 @@ def na_value_for_dtype(dtype, compat=True):
502503
dtype = pandas_dtype(dtype)
503504

504505
if (is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype) or
505-
is_timedelta64_dtype(dtype)):
506+
is_timedelta64_dtype(dtype) or is_period_dtype(dtype)):
506507
return NaT
507508
elif is_float_dtype(dtype):
508509
return np.nan

pandas/tests/dtypes/test_missing.py

+23-18
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from pandas import (NaT, Float64Index, Series,
1616
DatetimeIndex, TimedeltaIndex, date_range)
1717
from pandas.core.dtypes.common import is_scalar
18-
from pandas.core.dtypes.dtypes import DatetimeTZDtype
18+
from pandas.core.dtypes.dtypes import (
19+
DatetimeTZDtype, PeriodDtype, IntervalDtype)
1920
from pandas.core.dtypes.missing import (
2021
array_equivalent, isna, notna, isnull, notnull,
2122
na_value_for_dtype)
@@ -311,23 +312,27 @@ def test_array_equivalent_str():
311312
np.array(['A', 'X'], dtype=dtype))
312313

313314

314-
def test_na_value_for_dtype():
315-
for dtype in [np.dtype('M8[ns]'), np.dtype('m8[ns]'),
316-
DatetimeTZDtype('datetime64[ns, US/Eastern]')]:
317-
assert na_value_for_dtype(dtype) is NaT
318-
319-
for dtype in ['u1', 'u2', 'u4', 'u8',
320-
'i1', 'i2', 'i4', 'i8']:
321-
assert na_value_for_dtype(np.dtype(dtype)) == 0
322-
323-
for dtype in ['bool']:
324-
assert na_value_for_dtype(np.dtype(dtype)) is False
325-
326-
for dtype in ['f2', 'f4', 'f8']:
327-
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))
328-
329-
for dtype in ['O']:
330-
assert np.isnan(na_value_for_dtype(np.dtype(dtype)))
315+
@pytest.mark.parametrize('dtype, na_value', [
316+
# Datetime-like
317+
(np.dtype("M8[ns]"), NaT),
318+
(np.dtype("m8[ns]"), NaT),
319+
(DatetimeTZDtype('datetime64[ns, US/Eastern]'), NaT),
320+
(PeriodDtype("M"), NaT),
321+
# Integer
322+
('u1', 0), ('u2', 0), ('u4', 0), ('u8', 0),
323+
('i1', 0), ('i2', 0), ('i4', 0), ('i8', 0),
324+
# Bool
325+
('bool', False),
326+
# Float
327+
('f2', np.nan), ('f4', np.nan), ('f8', np.nan),
328+
# Object
329+
('O', np.nan),
330+
# Interval
331+
(IntervalDtype(), np.nan),
332+
])
333+
def test_na_value_for_dtype(dtype, na_value):
334+
result = na_value_for_dtype(dtype)
335+
assert result is na_value
331336

332337

333338
class TestNAObj(object):

pandas/tests/test_algos.py

+30
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,36 @@ def test_deprecate_order(self):
257257
with tm.assert_produces_warning(False):
258258
algos.factorize(data)
259259

260+
@pytest.mark.parametrize('data', [
261+
np.array([0, 1, 0], dtype='u8'),
262+
np.array([-2**63, 1, -2**63], dtype='i8'),
263+
np.array(['__nan__', 'foo', '__nan__'], dtype='object'),
264+
])
265+
def test_parametrized_factorize_na_value_default(self, data):
266+
# arrays that include the NA default for that type, but isn't used.
267+
l, u = algos.factorize(data)
268+
expected_uniques = data[[0, 1]]
269+
expected_labels = np.array([0, 1, 0], dtype='i8')
270+
tm.assert_numpy_array_equal(l, expected_labels)
271+
tm.assert_numpy_array_equal(u, expected_uniques)
272+
273+
@pytest.mark.parametrize('data, na_value', [
274+
(np.array([0, 1, 0, 2], dtype='u8'), 0),
275+
(np.array([1, 0, 1, 2], dtype='u8'), 1),
276+
(np.array([-2**63, 1, -2**63, 0], dtype='i8'), -2**63),
277+
(np.array([1, -2**63, 1, 0], dtype='i8'), 1),
278+
(np.array(['a', '', 'a', 'b'], dtype=object), 'a'),
279+
(np.array([(), ('a', 1), (), ('a', 2)], dtype=object), ()),
280+
(np.array([('a', 1), (), ('a', 1), ('a', 2)], dtype=object),
281+
('a', 1)),
282+
])
283+
def test_parametrized_factorize_na_value(self, data, na_value):
284+
l, u = algos._factorize_array(data, na_value=na_value)
285+
expected_uniques = data[[1, 3]]
286+
expected_labels = np.array([-1, 0, -1, 1], dtype='i8')
287+
tm.assert_numpy_array_equal(l, expected_labels)
288+
tm.assert_numpy_array_equal(u, expected_uniques)
289+
260290

261291
class TestUnique(object):
262292

0 commit comments

Comments
 (0)