From 6015196c96f48aad04e9a07fe587ccab4e5e671f Mon Sep 17 00:00:00 2001 From: Patrick Hoefler Date: Mon, 24 Oct 2022 22:25:47 +0200 Subject: [PATCH 1/2] REGR: MultiIndex.join does not work for ea dtypes --- doc/source/whatsnew/v1.5.2.rst | 2 +- pandas/core/indexes/base.py | 5 +-- pandas/tests/indexes/multi/test_join.py | 48 +++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/doc/source/whatsnew/v1.5.2.rst b/doc/source/whatsnew/v1.5.2.rst index aaf00804262bb..7a6811c05151e 100644 --- a/doc/source/whatsnew/v1.5.2.rst +++ b/doc/source/whatsnew/v1.5.2.rst @@ -13,7 +13,7 @@ including other versions of pandas. Fixed regressions ~~~~~~~~~~~~~~~~~ -- +- Fixed regression in :meth:`MultiIndex.join` for extension array dtypes (:issue:`49277`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index eacb33171aa56..8c009c20bc3eb 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4675,8 +4675,9 @@ def join( return self._join_non_unique(other, how=how) elif not self.is_unique or not other.is_unique: if self.is_monotonic_increasing and other.is_monotonic_increasing: - if self._can_use_libjoin: + if not is_interval_dtype(self.dtype): # otherwise we will fall through to _join_via_get_indexer + # go through object dtype for ea till engine is supported properly return self._join_monotonic(other, how=how) else: return self._join_non_unique(other, how=how) @@ -5051,7 +5052,7 @@ def _wrap_joined_index(self: _IndexT, joined: ArrayLike, other: _IndexT) -> _Ind return self._constructor(joined, name=name) # type: ignore[return-value] else: name = get_op_result_name(self, other) - return self._constructor._with_infer(joined, name=name) + return self._constructor._with_infer(joined, name=name, dtype=self.dtype) @cache_readonly def _can_use_libjoin(self) -> bool: diff --git a/pandas/tests/indexes/multi/test_join.py b/pandas/tests/indexes/multi/test_join.py index e6bec97aedb38..7000724a6b271 100644 --- a/pandas/tests/indexes/multi/test_join.py +++ b/pandas/tests/indexes/multi/test_join.py @@ -5,6 +5,8 @@ Index, Interval, MultiIndex, + Series, + StringDtype, ) import pandas._testing as tm @@ -158,3 +160,49 @@ def test_join_overlapping_interval_level(): result = idx_1.join(idx_2, how="outer") tm.assert_index_equal(result, expected) + + +def test_join_midx_ea(): + # GH#49277 + midx = MultiIndex.from_arrays( + [Series([1, 1, 3], dtype="Int64"), Series([1, 2, 3], dtype="Int64")], + names=["a", "b"], + ) + midx2 = MultiIndex.from_arrays( + [Series([1], dtype="Int64"), Series([3], dtype="Int64")], names=["a", "c"] + ) + result = midx.join(midx2, how="inner") + expected = MultiIndex.from_arrays( + [ + Series([1, 1], dtype="Int64"), + Series([1, 2], dtype="Int64"), + Series([3, 3], dtype="Int64"), + ], + names=["a", "b", "c"], + ) + tm.assert_index_equal(result, expected) + + +def test_join_midx_string(): + # GH#49277 + midx = MultiIndex.from_arrays( + [ + Series(["a", "a", "c"], dtype=StringDtype()), + Series(["a", "b", "c"], dtype=StringDtype()), + ], + names=["a", "b"], + ) + midx2 = MultiIndex.from_arrays( + [Series(["a"], dtype=StringDtype()), Series(["c"], dtype=StringDtype())], + names=["a", "c"], + ) + result = midx.join(midx2, how="inner") + expected = MultiIndex.from_arrays( + [ + Series(["a", "a"], dtype=StringDtype()), + Series(["a", "b"], dtype=StringDtype()), + Series(["c", "c"], dtype=StringDtype()), + ], + names=["a", "b", "c"], + ) + tm.assert_index_equal(result, expected) From a20177b2723ca676fcf3cbecc8368e8a52d3da78 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 31 Oct 2022 10:39:07 +0100 Subject: [PATCH 2/2] Update base.py --- pandas/core/indexes/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 9762bf49a24ef..7b3119902e1d1 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4638,6 +4638,7 @@ def join( if self.is_monotonic_increasing and other.is_monotonic_increasing: if not is_interval_dtype(self.dtype): # otherwise we will fall through to _join_via_get_indexer + # GH#39133 # go through object dtype for ea till engine is supported properly return self._join_monotonic(other, how=how) else: