diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index fbfee9a1f524c..d9eaaa763fde8 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -34,7 +34,6 @@ Dtype, DtypeObj, ) -from pandas.errors import InvalidIndexError from pandas.util._decorators import ( cache_readonly, doc, @@ -658,8 +657,7 @@ def get_loc(self, key, method=None, tolerance=None): ------- loc : int """ - if not is_scalar(key): - raise InvalidIndexError(key) + self._check_indexing_error(key) orig_key = key if is_valid_na_for_dtype(key, self.dtype): diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 072ab7dff8e5b..232ca9068abc6 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -613,9 +613,7 @@ def get_loc( 0 """ self._check_indexing_method(method) - - if not is_scalar(key): - raise InvalidIndexError(key) + self._check_indexing_error(key) if isinstance(key, Interval): if self.closed != key.closed: diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index c1104b80a0a7a..1e705282bd816 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -27,14 +27,12 @@ Dtype, DtypeObj, ) -from pandas.errors import InvalidIndexError from pandas.util._decorators import doc from pandas.core.dtypes.common import ( is_datetime64_any_dtype, is_float, is_integer, - is_scalar, pandas_dtype, ) from pandas.core.dtypes.dtypes import PeriodDtype @@ -411,9 +409,7 @@ def get_loc(self, key, method=None, tolerance=None): """ orig_key = key - if not is_scalar(key): - raise InvalidIndexError(key) - + self._check_indexing_error(key) if isinstance(key, str): try: diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index 746246172b967..8bef105c783c8 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -385,6 +385,7 @@ def get_loc(self, key, method=None, tolerance=None): return self._range.index(new_key) except ValueError as err: raise KeyError(key) from err + self._check_indexing_error(key) raise KeyError(key) return super().get_loc(key, method=method, tolerance=tolerance) diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py index c60ab06dd08f3..bc21b567dba26 100644 --- a/pandas/core/indexes/timedeltas.py +++ b/pandas/core/indexes/timedeltas.py @@ -13,7 +13,6 @@ DtypeObj, Optional, ) -from pandas.errors import InvalidIndexError from pandas.core.dtypes.common import ( TD64NS_DTYPE, @@ -170,8 +169,7 @@ def get_loc(self, key, method=None, tolerance=None): ------- loc : int, slice, or ndarray[int] """ - if not is_scalar(key): - raise InvalidIndexError(key) + self._check_indexing_error(key) try: key = self._data._validate_scalar(key, unbox=False) diff --git a/pandas/core/series.py b/pandas/core/series.py index 59ea6710ea6cd..3d974cbec4286 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -1072,7 +1072,7 @@ def __setitem__(self, key, value) -> None: # GH#12862 adding a new key to the Series self.loc[key] = value - except TypeError as err: + except (InvalidIndexError, TypeError) as err: if isinstance(key, tuple) and not isinstance(self.index, MultiIndex): raise KeyError( "key of type tuple not found and not a MultiIndex" @@ -1094,8 +1094,7 @@ def __setitem__(self, key, value) -> None: self._maybe_update_cacher() def _set_with_engine(self, key, value) -> None: - # fails with AttributeError for IntervalIndex - loc = self.index._engine.get_loc(key) + loc = self.index.get_loc(key) # error: Argument 1 to "validate_numeric_casting" has incompatible type # "Union[dtype, ExtensionDtype]"; expected "dtype" validate_numeric_casting(self.dtype, value) # type: ignore[arg-type] @@ -1112,6 +1111,9 @@ def _set_with(self, key, value): if is_scalar(key): key = [key] + elif is_iterator(key): + # Without this, the call to infer_dtype will consume the generator + key = list(key) if isinstance(key, Index): key_type = key.inferred_type diff --git a/pandas/tests/indexes/test_indexing.py b/pandas/tests/indexes/test_indexing.py index 379c766b94d6c..5f6d0155ae6cf 100644 --- a/pandas/tests/indexes/test_indexing.py +++ b/pandas/tests/indexes/test_indexing.py @@ -26,6 +26,7 @@ IntervalIndex, MultiIndex, PeriodIndex, + RangeIndex, Series, TimedeltaIndex, UInt64Index, @@ -181,6 +182,27 @@ def test_get_value(self, index): tm.assert_almost_equal(result, values[67]) +class TestGetLoc: + def test_get_loc_non_hashable(self, index): + # MultiIndex and Index raise TypeError, others InvalidIndexError + + with pytest.raises((TypeError, InvalidIndexError), match="slice"): + index.get_loc(slice(0, 1)) + + def test_get_loc_generator(self, index): + + exc = KeyError + if isinstance( + index, + (DatetimeIndex, TimedeltaIndex, PeriodIndex, RangeIndex, IntervalIndex), + ): + # TODO: make these more consistent? + exc = InvalidIndexError + with pytest.raises(exc, match="generator object"): + # MultiIndex specifically checks for generator; others for scalar + index.get_loc(x for x in range(5)) + + class TestGetIndexer: def test_get_indexer_base(self, index): diff --git a/pandas/tests/indexing/test_indexing.py b/pandas/tests/indexing/test_indexing.py index c945bd6b95ee1..fd4690a05da05 100644 --- a/pandas/tests/indexing/test_indexing.py +++ b/pandas/tests/indexing/test_indexing.py @@ -113,15 +113,6 @@ def test_setitem_ndarray_3d(self, index, frame_or_series, indexer_sli): if indexer_sli is tm.iloc: err = ValueError msg = f"Cannot set values with ndim > {obj.ndim}" - elif ( - isinstance(index, pd.IntervalIndex) - and indexer_sli is tm.setitem - and obj.ndim == 1 - ): - err = AttributeError - msg = ( - "'pandas._libs.interval.IntervalTree' object has no attribute 'get_loc'" - ) else: err = ValueError msg = "|".join(