@@ -250,13 +250,13 @@ cdef class HashTable:
250
250
251
251
{{py:
252
252
253
- # name, dtype, null_condition, float_group
254
- dtypes = [('Float64', 'float64', 'val != val', True ),
255
- ('UInt64', 'uint64', ' False', False ),
256
- ('Int64', 'int64', 'val == iNaT', False )]
253
+ # name, dtype, float_group, default_na_value
254
+ dtypes = [('Float64', 'float64', True, 'nan' ),
255
+ ('UInt64', 'uint64', False, 0 ),
256
+ ('Int64', 'int64', False, ' iNaT')]
257
257
258
258
def get_dispatch(dtypes):
259
- for (name, dtype, null_condition, float_group ) in dtypes:
259
+ for (name, dtype, float_group, default_na_value ) in dtypes:
260
260
unique_template = """\
261
261
cdef:
262
262
Py_ssize_t i, n = len(values)
@@ -298,13 +298,13 @@ def get_dispatch(dtypes):
298
298
return uniques.to_array()
299
299
"""
300
300
301
- unique_template = unique_template.format(name=name, dtype=dtype, null_condition=null_condition, float_group=float_group)
301
+ unique_template = unique_template.format(name=name, dtype=dtype, float_group=float_group)
302
302
303
- yield (name, dtype, null_condition, float_group , unique_template)
303
+ yield (name, dtype, float_group, default_na_value , unique_template)
304
304
}}
305
305
306
306
307
- {{for name, dtype, null_condition, float_group , unique_template in get_dispatch(dtypes)}}
307
+ {{for name, dtype, float_group, default_na_value , unique_template in get_dispatch(dtypes)}}
308
308
309
309
cdef class {{name}}HashTable(HashTable):
310
310
@@ -408,24 +408,36 @@ cdef class {{name}}HashTable(HashTable):
408
408
@cython.boundscheck(False)
409
409
def get_labels(self, {{dtype}}_t[:] values, {{name}}Vector uniques,
410
410
Py_ssize_t count_prior, Py_ssize_t na_sentinel,
411
- bint check_null=True ):
411
+ object na_value=None ):
412
412
cdef:
413
413
Py_ssize_t i, n = len(values)
414
414
int64_t[:] labels
415
415
Py_ssize_t idx, count = count_prior
416
416
int ret = 0
417
- {{dtype}}_t val
417
+ {{dtype}}_t val, na_value2
418
418
khiter_t k
419
419
{{name}}VectorData *ud
420
+ bint use_na_value
420
421
421
422
labels = np.empty(n, dtype=np.int64)
422
423
ud = uniques.data
424
+ use_na_value = na_value is not None
425
+
426
+ if use_na_value:
427
+ # We need this na_value2 because we want to allow users
428
+ # to *optionally* specify an NA sentinel *of the correct* type.
429
+ # We use None, to make it optional, which requires `object` type
430
+ # for the parameter. To please the compiler, we use na_value2,
431
+ # which is only used if it's *specified*.
432
+ na_value2 = <{{dtype}}_t>na_value
433
+ else:
434
+ na_value2 = {{default_na_value}}
423
435
424
436
with nogil:
425
437
for i in range(n):
426
438
val = values[i]
427
439
428
- if check_null and {{null_condition}} :
440
+ if val != val or (use_na_value and val == na_value2) :
429
441
labels[i] = na_sentinel
430
442
continue
431
443
@@ -695,7 +707,7 @@ cdef class StringHashTable(HashTable):
695
707
@cython.boundscheck(False)
696
708
def get_labels(self, ndarray[object] values, ObjectVector uniques,
697
709
Py_ssize_t count_prior, int64_t na_sentinel,
698
- bint check_null=1 ):
710
+ object na_value=None ):
699
711
cdef:
700
712
Py_ssize_t i, n = len(values)
701
713
int64_t[:] labels
@@ -706,18 +718,21 @@ cdef class StringHashTable(HashTable):
706
718
char *v
707
719
char **vecs
708
720
khiter_t k
721
+ bint use_na_value
709
722
710
723
# these by-definition *must* be strings
711
724
labels = np.zeros(n, dtype=np.int64)
712
725
uindexer = np.empty(n, dtype=np.int64)
726
+ use_na_value = na_value is not None
713
727
714
728
# pre-filter out missing
715
729
# and assign pointers
716
730
vecs = <char **> malloc(n * sizeof(char *))
717
731
for i in range(n):
718
732
val = values[i]
719
733
720
- if PyUnicode_Check(val) or PyString_Check(val):
734
+ if ((PyUnicode_Check(val) or PyString_Check(val)) and
735
+ not (use_na_value and val == na_value)):
721
736
v = util.get_c_string(val)
722
737
vecs[i] = v
723
738
else:
@@ -868,22 +883,25 @@ cdef class PyObjectHashTable(HashTable):
868
883
869
884
def get_labels(self, ndarray[object] values, ObjectVector uniques,
870
885
Py_ssize_t count_prior, int64_t na_sentinel,
871
- bint check_null=True ):
886
+ object na_value=None ):
872
887
cdef:
873
888
Py_ssize_t i, n = len(values)
874
889
int64_t[:] labels
875
890
Py_ssize_t idx, count = count_prior
876
891
int ret = 0
877
892
object val
878
893
khiter_t k
894
+ bint use_na_value
879
895
880
896
labels = np.empty(n, dtype=np.int64)
897
+ use_na_value = na_value is not None
881
898
882
899
for i in range(n):
883
900
val = values[i]
884
901
hash(val)
885
902
886
- if check_null and val != val or val is None:
903
+ if ((val != val or val is None) or
904
+ (use_na_value and val == na_value)):
887
905
labels[i] = na_sentinel
888
906
continue
889
907
0 commit comments