From acfd1b89b640ebcc8722bc44ad1063db1c21c192 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 3 Apr 2024 11:56:26 -0700 Subject: [PATCH 1/2] _with_infer includes RangeIndex --- pandas/core/indexes/base.py | 21 ++++++++++++++++--- .../arrays/categorical/test_constructors.py | 5 +++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 73e564f95cf65..6eea9e5892593 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -667,12 +667,25 @@ def _simple_new( return result @classmethod - def _with_infer(cls, *args, **kwargs): + def _with_infer( + cls, + data=None, + dtype=None, + copy: bool = False, + name=None, + tupleize_cols: bool = True, + ): """ Constructor that uses the 1.0.x behavior inferring numeric dtypes for ndarray[object] inputs. """ - result = cls(*args, **kwargs) + result = cls( + data=maybe_sequence_to_range(data), + dtype=dtype, + copy=copy, + name=name, + tupleize_cols=tupleize_cols, + ) if result.dtype == _dtype_obj and not result._is_multi: # error: Argument 1 to "maybe_convert_objects" has incompatible type @@ -7136,7 +7149,9 @@ def maybe_sequence_to_range(sequence) -> Any | range: """ if isinstance(sequence, (range, ExtensionArray)): return sequence - elif len(sequence) == 1 or lib.infer_dtype(sequence, skipna=False) != "integer": + elif isinstance(sequence, abc.Generator): + sequence = list(sequence) + if len(sequence) == 1 or lib.infer_dtype(sequence, skipna=False) != "integer": return sequence elif isinstance(sequence, (ABCSeries, Index)) and not ( isinstance(sequence.dtype, np.dtype) and sequence.dtype.kind == "i" diff --git a/pandas/tests/arrays/categorical/test_constructors.py b/pandas/tests/arrays/categorical/test_constructors.py index 857b14e2a2558..fe1bee0cdc166 100644 --- a/pandas/tests/arrays/categorical/test_constructors.py +++ b/pandas/tests/arrays/categorical/test_constructors.py @@ -794,3 +794,8 @@ def test_range_values_preserves_rangeindex_categories(self, values, categories): result = Categorical(values=values, categories=categories).categories expected = RangeIndex(range(5)) tm.assert_index_equal(result, expected, exact=True) + + def test_categoricaldtype_numeric_object_to_rangeindex_categories(self): + result = CategoricalDtype(np.array([1, 2], dtype=object)).categories + expected = RangeIndex(1, 3) + tm.assert_index_equal(result, expected, exact=True) From 94a27dadd4358eb4a801fa7a38f38bef6f5122e9 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 3 Apr 2024 14:50:22 -0700 Subject: [PATCH 2/2] Add groupby test --- pandas/tests/groupby/aggregate/test_aggregate.py | 14 ++++++++++++-- pandas/tests/groupby/test_groupby.py | 13 ++++++++++--- pandas/tests/groupby/transform/test_transform.py | 10 +++++++++- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index 2b9df1b7079da..822dfbc620538 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -19,6 +19,7 @@ DataFrame, Index, MultiIndex, + RangeIndex, Series, concat, to_datetime, @@ -517,7 +518,7 @@ def test_callable_result_dtype_frame( df["c"] = df["c"].astype(input_dtype) op = getattr(df.groupby(keys)[["c"]], method) result = op(lambda x: x.astype(result_dtype).iloc[0]) - expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index + expected_index = RangeIndex(0, 1) if method == "transform" else agg_index expected = DataFrame({"c": [df["c"].iloc[0]]}, index=expected_index).astype( result_dtype ) @@ -541,7 +542,7 @@ def test_callable_result_dtype_series(keys, agg_index, input, dtype, method): df = DataFrame({"a": [1], "b": [2], "c": [input]}) op = getattr(df.groupby(keys)["c"], method) result = op(lambda x: x.astype(dtype).iloc[0]) - expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index + expected_index = RangeIndex(0, 1) if method == "transform" else agg_index expected = Series([df["c"].iloc[0]], index=expected_index, name="c").astype(dtype) tm.assert_series_equal(result, expected) @@ -1663,3 +1664,12 @@ def func(x): msg = "length must not be 0" with pytest.raises(ValueError, match=msg): df.groupby("A", observed=False).agg(func) + + +def test_agg_groups_returns_rangeindex(): + df = DataFrame({"group": [1, 1, 2], "value": [1, 2, 3]}) + result = df.groupby("group").agg(max) + expected = DataFrame( + [2, 3], index=RangeIndex(1, 3, name="group"), columns=["value"] + ) + tm.assert_frame_equal(result, expected, check_index_type=True) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index be8f5d73fe7e8..4e5da6494a5a8 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -46,9 +46,7 @@ def test_groupby_nonobject_dtype(multiindex_dataframe_random_data): result = grouped.sum() expected = multiindex_dataframe_random_data.groupby(key.astype("O")).sum() - assert result.index.dtype == np.int8 - assert expected.index.dtype == np.int64 - tm.assert_frame_equal(result, expected, check_index_type=False) + tm.assert_frame_equal(result, expected, check_index_type=True) def test_groupby_nonobject_dtype_mixed(): @@ -2955,3 +2953,12 @@ def test_groupby_dropna_with_nunique_unique(): ) tm.assert_frame_equal(result, expected) + + +def test_groupby_groups_returns_rangeindex(): + df = DataFrame({"group": [1, 1, 2], "value": [1, 2, 3]}) + result = df.groupby("group").max() + expected = DataFrame( + [2, 3], index=RangeIndex(1, 3, name="group"), columns=["value"] + ) + tm.assert_frame_equal(result, expected, check_index_type=True) diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index 245fb9c7babd7..2a08e1111232f 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -13,6 +13,7 @@ DataFrame, Index, MultiIndex, + RangeIndex, Series, Timestamp, concat, @@ -290,7 +291,7 @@ def test_transform_casting(): ), "DATETIME": pd.to_datetime([f"2014-10-08 {time}" for time in times]), }, - index=pd.RangeIndex(11, name="idx"), + index=RangeIndex(11, name="idx"), ) result = df.groupby("ID3")["DATETIME"].transform(lambda x: x.diff()) @@ -1535,3 +1536,10 @@ def test_transform_sum_one_column_with_matching_labels_and_missing_labels(): result = df.groupby(series, as_index=False).transform("sum") expected = DataFrame({"X": [-93203.0, -93203.0, np.nan]}) tm.assert_frame_equal(result, expected) + + +def test_transform_groups_returns_rangeindex(): + df = DataFrame({"group": [1, 1, 2], "value": [1, 2, 3]}) + result = df.groupby("group").transform(lambda x: x + 1) + expected = DataFrame([2, 3, 4], index=RangeIndex(0, 3), columns=["value"]) + tm.assert_frame_equal(result, expected, check_index_type=True)