diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 09cb024cbd95c..1aa88e56689c2 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -552,7 +552,7 @@ Groupby/resample/rolling - Bug in :meth:`Rolling.median` and :meth:`Rolling.quantile` returned wrong values for :class:`BaseIndexer` subclasses with non-monotonic starting or ending points for windows (:issue:`37153`) - Bug in :meth:`DataFrame.groupby` dropped ``nan`` groups from result with ``dropna=False`` when grouping over a single column (:issue:`35646`, :issue:`35542`) - Bug in :meth:`DataFrameGroupBy.head`, :meth:`DataFrameGroupBy.tail`, :meth:`SeriesGroupBy.head`, and :meth:`SeriesGroupBy.tail` would raise when used with ``axis=1`` (:issue:`9772`) - +- Bug in :meth:`DataFrameGroupBy.transform` would raise when used with ``axis=1`` and a transformation kernel (e.g. "shift") (:issue:`36308`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 6f86819303537..3395b9d36fd0c 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -1675,11 +1675,16 @@ def _wrap_transformed_output( DataFrame """ indexed_output = {key.position: val for key, val in output.items()} - columns = Index(key.label for key in output) - columns.name = self.obj.columns.name - result = self.obj._constructor(indexed_output) - result.columns = columns + + if self.axis == 1: + result = result.T + result.columns = self.obj.columns + else: + columns = Index(key.label for key in output) + columns.name = self.obj.columns.name + result.columns = columns + result.index = self.obj.index return result diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e3fceb9bf0a06..ec96a0d502d3f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2365,7 +2365,7 @@ def cumcount(self, ascending: bool = True): dtype: int64 """ with group_selection_context(self): - index = self._selected_obj.index + index = self._selected_obj._get_axis(self.axis) cumcounts = self._cumcount_array(ascending=ascending) return self._obj_1d_constructor(cumcounts, index) @@ -2706,8 +2706,8 @@ def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None, axis=0 fill_method = "pad" limit = 0 filled = getattr(self, fill_method)(limit=limit) - fill_grp = filled.groupby(self.grouper.codes) - shifted = fill_grp.shift(periods=periods, freq=freq) + fill_grp = filled.groupby(self.grouper.codes, axis=self.axis) + shifted = fill_grp.shift(periods=periods, freq=freq, axis=self.axis) return (filled / shifted) - 1 @Substitution(name="groupby") diff --git a/pandas/tests/frame/apply/test_frame_transform.py b/pandas/tests/frame/apply/test_frame_transform.py index d141fa8682c10..f2a66b553b366 100644 --- a/pandas/tests/frame/apply/test_frame_transform.py +++ b/pandas/tests/frame/apply/test_frame_transform.py @@ -27,8 +27,6 @@ def test_transform_groupby_kernel(axis, float_frame, op): pytest.xfail("DataFrame.cumcount does not exist") if op == "tshift": pytest.xfail("Only works on time index and is deprecated") - if axis == 1 or axis == "columns": - pytest.xfail("GH 36308: groupby.transform with axis=1 is broken") args = [0.0] if op == "fillna" else [] if axis == 0 or axis == "index": diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index d7426a5e3b42e..b4e023f569844 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -158,7 +158,25 @@ def test_transform_broadcast(tsframe, ts): assert_fp_equal(res.xs(idx), agged[idx]) -def test_transform_axis(tsframe): +def test_transform_axis_1(transformation_func): + # GH 36308 + if transformation_func == "tshift": + pytest.xfail("tshift is deprecated") + args = ("ffill",) if transformation_func == "fillna" else tuple() + + df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) + result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args) + expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T + + if transformation_func == "diff": + # Result contains nans, so transpose coerces to float + expected["b"] = expected["b"].astype("int64") + + # cumcount returns Series; the rest are DataFrame + tm.assert_equal(result, expected) + + +def test_transform_axis_ts(tsframe): # make sure that we are setting the axes # correctly when on axis=0 or 1