Skip to content

Commit ee6b0a0

Browse files
authored
ENH: Index[bool] (#45061)
1 parent 601170d commit ee6b0a0

23 files changed

+137
-56
lines changed

doc/source/whatsnew/v1.5.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Other enhancements
3737
- :meth:`to_numeric` now preserves float64 arrays when downcasting would generate values not representable in float32 (:issue:`43693`)
3838
- :meth:`Series.reset_index` and :meth:`DataFrame.reset_index` now support the argument ``allow_duplicates`` (:issue:`44410`)
3939
- :meth:`.GroupBy.min` and :meth:`.GroupBy.max` now supports `Numba <https://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`45428`)
40+
- Implemented a ``bool``-dtype :class:`Index`, passing a bool-dtype arraylike to ``pd.Index`` will now retain ``bool`` dtype instead of casting to ``object`` (:issue:`45061`)
4041
-
4142

4243
.. ---------------------------------------------------------------------------

pandas/_libs/index.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class ObjectEngine(IndexEngine): ...
4242
class DatetimeEngine(Int64Engine): ...
4343
class TimedeltaEngine(DatetimeEngine): ...
4444
class PeriodEngine(Int64Engine): ...
45+
class BoolEngine(UInt8Engine): ...
4546

4647
class BaseMultiIndexCodesEngine:
4748
levels: list[np.ndarray]

pandas/_libs/index.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,13 @@ cdef class BaseMultiIndexCodesEngine:
802802
include "index_class_helper.pxi"
803803

804804

805+
cdef class BoolEngine(UInt8Engine):
806+
cdef _check_type(self, object val):
807+
if not util.is_bool_object(val):
808+
raise KeyError(val)
809+
return <uint8_t>val
810+
811+
805812
@cython.internal
806813
@cython.freelist(32)
807814
cdef class SharedEngine:

pandas/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,8 @@ def _create_mi_with_dt64tz_level():
555555
"num_uint8": tm.makeNumericIndex(100, dtype="uint8"),
556556
"num_float64": tm.makeNumericIndex(100, dtype="float64"),
557557
"num_float32": tm.makeNumericIndex(100, dtype="float32"),
558-
"bool": tm.makeBoolIndex(10),
558+
"bool-object": tm.makeBoolIndex(10).astype(object),
559+
"bool-dtype": Index(np.random.randn(10) < 0),
559560
"categorical": tm.makeCategoricalIndex(100),
560561
"interval": tm.makeIntervalIndex(100),
561562
"empty": Index([]),
@@ -630,7 +631,7 @@ def index_flat_unique(request):
630631
key
631632
for key in indices_dict
632633
if not (
633-
key in ["int", "uint", "range", "empty", "repeats"]
634+
key in ["int", "uint", "range", "empty", "repeats", "bool-dtype"]
634635
or key.startswith("num_")
635636
)
636637
and not isinstance(indices_dict[key], MultiIndex)

pandas/core/algorithms.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,6 @@ def _reconstruct_data(
220220
elif is_bool_dtype(dtype):
221221
values = values.astype(dtype, copy=False)
222222

223-
# we only support object dtypes bool Index
224-
if isinstance(original, ABCIndex):
225-
values = values.astype(object, copy=False)
226223
elif dtype is not None:
227224
if is_datetime64_dtype(dtype):
228225
dtype = np.dtype("datetime64[ns]")
@@ -830,7 +827,10 @@ def value_counts(
830827
-------
831828
Series
832829
"""
833-
from pandas.core.series import Series
830+
from pandas import (
831+
Index,
832+
Series,
833+
)
834834

835835
name = getattr(values, "name", None)
836836

@@ -868,7 +868,13 @@ def value_counts(
868868
else:
869869
keys, counts = value_counts_arraylike(values, dropna)
870870

871-
result = Series(counts, index=keys, name=name)
871+
# For backwards compatibility, we let Index do its normal type
872+
# inference, _except_ for if if infers from object to bool.
873+
idx = Index._with_infer(keys)
874+
if idx.dtype == bool and keys.dtype == object:
875+
idx = idx.astype(object)
876+
877+
result = Series(counts, index=idx, name=name)
872878

873879
if sort:
874880
result = result.sort_values(ascending=ascending)

pandas/core/indexes/base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,10 @@ def __new__(
505505
if data.dtype.kind in ["i", "u", "f"]:
506506
# maybe coerce to a sub-class
507507
arr = data
508+
elif data.dtype.kind == "b":
509+
# No special subclass, and Index._ensure_array won't do this
510+
# for us.
511+
arr = np.asarray(data)
508512
else:
509513
arr = com.asarray_tuplesafe(data, dtype=_dtype_obj)
510514

@@ -702,7 +706,7 @@ def _with_infer(cls, *args, **kwargs):
702706
# "Union[ExtensionArray, ndarray[Any, Any]]"; expected
703707
# "ndarray[Any, Any]"
704708
values = lib.maybe_convert_objects(result._values) # type: ignore[arg-type]
705-
if values.dtype.kind in ["i", "u", "f"]:
709+
if values.dtype.kind in ["i", "u", "f", "b"]:
706710
return Index(values, name=result.name)
707711

708712
return result
@@ -872,9 +876,12 @@ def _engine(
872876
):
873877
return libindex.ExtensionEngine(target_values)
874878

879+
target_values = cast(np.ndarray, target_values)
875880
# to avoid a reference cycle, bind `target_values` to a local variable, so
876881
# `self` is not passed into the lambda.
877-
target_values = cast(np.ndarray, target_values)
882+
if target_values.dtype == bool:
883+
return libindex.BoolEngine(target_values)
884+
878885
# error: Argument 1 to "ExtensionEngine" has incompatible type
879886
# "ndarray[Any, Any]"; expected "ExtensionArray"
880887
return self._engine_type(target_values) # type:ignore[arg-type]
@@ -2680,7 +2687,6 @@ def _is_all_dates(self) -> bool:
26802687
"""
26812688
Whether or not the index values only consist of dates.
26822689
"""
2683-
26842690
if needs_i8_conversion(self.dtype):
26852691
return True
26862692
elif self.dtype != _dtype_obj:
@@ -7302,7 +7308,7 @@ def _maybe_cast_data_without_dtype(
73027308
FutureWarning,
73037309
stacklevel=3,
73047310
)
7305-
if result.dtype.kind in ["b", "c"]:
7311+
if result.dtype.kind in ["c"]:
73067312
return subarr
73077313
result = ensure_wrapped_if_datetimelike(result)
73087314
return result

pandas/core/tools/datetimes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,8 @@ def to_datetime(
10761076
result = convert_listlike(arg, format)
10771077
else:
10781078
result = convert_listlike(np.array([arg]), format)[0]
1079+
if isinstance(arg, bool) and isinstance(result, np.bool_):
1080+
result = bool(result) # TODO: avoid this kludge.
10791081

10801082
# error: Incompatible return value type (got "Union[Timestamp, NaTType,
10811083
# Series, Index]", expected "Union[DatetimeIndex, Series, float, str,

pandas/core/util/hashing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _hash_ndarray(
319319

320320
# First, turn whatever array this is into unsigned 64-bit ints, if we can
321321
# manage it.
322-
elif isinstance(dtype, bool):
322+
elif dtype == bool:
323323
vals = vals.astype("u8")
324324
elif issubclass(dtype.type, (np.datetime64, np.timedelta64)):
325325
vals = vals.view("i8").astype("u8", copy=False)

pandas/tests/base/test_value_counts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def test_value_counts_with_nan(dropna, index_or_series):
284284
obj = klass(values)
285285
res = obj.value_counts(dropna=dropna)
286286
if dropna is True:
287-
expected = Series([1], index=[True])
287+
expected = Series([1], index=Index([True], dtype=obj.dtype))
288288
else:
289289
expected = Series([1, 1, 1], index=[True, pd.NA, np.nan])
290290
tm.assert_series_equal(res, expected)

pandas/tests/indexes/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def test_ensure_copied_data(self, index):
216216
# RangeIndex cannot be initialized from data
217217
# MultiIndex and CategoricalIndex are tested separately
218218
return
219+
elif index.dtype == object and index.inferred_type == "boolean":
220+
init_kwargs["dtype"] = index.dtype
219221

220222
index_type = type(index)
221223
result = index_type(index.values, copy=True, **init_kwargs)
@@ -522,6 +524,9 @@ def test_fillna(self, index):
522524
# GH 11343
523525
if len(index) == 0:
524526
return
527+
elif index.dtype == bool:
528+
# can't hold NAs
529+
return
525530
elif isinstance(index, NumericIndex) and is_integer_dtype(index.dtype):
526531
return
527532
elif isinstance(index, MultiIndex):

0 commit comments

Comments
 (0)