diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 3a0e1b7568c91..e0bcd805bc30c 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -567,6 +567,7 @@ Performance improvements - Performance improvement in :class:`Styler` where render times are more than 50% reduced (:issue:`39972` :issue:`39952`) - Performance improvement in :meth:`core.window.ewm.ExponentialMovingWindow.mean` with ``times`` (:issue:`39784`) - Performance improvement in :meth:`.GroupBy.apply` when requiring the python fallback implementation (:issue:`40176`) +- Performance improvement for concatenation of data with type :class:`CategoricalDtype` (:issue:`40193`) .. --------------------------------------------------------------------------- @@ -795,6 +796,7 @@ ExtensionArray - Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with :class:`ExtensionArray` dtype (:issue:`38729`) - Fixed bug where :meth:`Series.idxmax`, :meth:`Series.idxmin` and ``argmax/min`` fail when the underlying data is :class:`ExtensionArray` (:issue:`32749`, :issue:`33719`, :issue:`36566`) +- Fixed a bug where some properties of subclasses of :class:`PandasExtensionDtype` where improperly cached (:issue:`40329`) - Other diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 2785874878c96..84eede019251b 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -15,6 +15,7 @@ import pytz from pandas._libs.interval import Interval +from pandas._libs.properties import cache_readonly from pandas._libs.tslibs import ( BaseOffset, NaT, @@ -81,7 +82,7 @@ class PandasExtensionDtype(ExtensionDtype): base: DtypeObj | None = None isbuiltin = 0 isnative = 0 - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __str__(self) -> str_type: """ @@ -105,7 +106,7 @@ def __getstate__(self) -> dict[str_type, Any]: @classmethod def reset_cache(cls) -> None: """ clear the cache """ - cls._cache = {} + cls._cache_dtypes = {} class CategoricalDtypeType(type): @@ -177,7 +178,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype): str = "|O08" base = np.dtype("O") _metadata = ("categories", "ordered") - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __init__(self, categories=None, ordered: Ordered = False): self._finalize(categories, ordered, fastpath=False) @@ -355,7 +356,7 @@ def __hash__(self) -> int: else: return -2 # We *do* want to include the real self.ordered here - return int(self._hash_categories(self.categories, self.ordered)) + return int(self._hash_categories) def __eq__(self, other: Any) -> bool: """ @@ -429,14 +430,17 @@ def __repr__(self) -> str_type: data = data.rstrip(", ") return f"CategoricalDtype(categories={data}, ordered={self.ordered})" - @staticmethod - def _hash_categories(categories, ordered: Ordered = True) -> int: + @cache_readonly + def _hash_categories(self) -> int: from pandas.core.util.hashing import ( combine_hash_arrays, hash_array, hash_tuples, ) + categories = self.categories + ordered = self.ordered + if len(categories) and isinstance(categories[0], tuple): # assumes if any individual category is a tuple, then all our. ATM # I don't really want to support just some of the categories being @@ -671,7 +675,7 @@ class DatetimeTZDtype(PandasExtensionDtype): na_value = NaT _metadata = ("unit", "tz") _match = re.compile(r"(datetime64|M8)\[(?P.+), (?P.+)\]") - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __init__(self, unit: str_type | DatetimeTZDtype = "ns", tz=None): if isinstance(unit, DatetimeTZDtype): @@ -837,7 +841,7 @@ class PeriodDtype(dtypes.PeriodDtypeBase, PandasExtensionDtype): num = 102 _metadata = ("freq",) _match = re.compile(r"(P|p)eriod\[(?P.+)\]") - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __new__(cls, freq=None): """ @@ -859,12 +863,12 @@ def __new__(cls, freq=None): freq = cls._parse_dtype_strict(freq) try: - return cls._cache[freq.freqstr] + return cls._cache_dtypes[freq.freqstr] except KeyError: dtype_code = freq._period_dtype_code u = dtypes.PeriodDtypeBase.__new__(cls, dtype_code) u._freq = freq - cls._cache[freq.freqstr] = u + cls._cache_dtypes[freq.freqstr] = u return u def __reduce__(self): @@ -1042,7 +1046,7 @@ class IntervalDtype(PandasExtensionDtype): _match = re.compile( r"(I|i)nterval\[(?P[^,]+)(, (?P(right|left|both|neither)))?\]" ) - _cache: dict[str_type, PandasExtensionDtype] = {} + _cache_dtypes: dict[str_type, PandasExtensionDtype] = {} def __new__(cls, subtype=None, closed: str_type | None = None): from pandas.core.dtypes.common import ( @@ -1099,12 +1103,12 @@ def __new__(cls, subtype=None, closed: str_type | None = None): key = str(subtype) + str(closed) try: - return cls._cache[key] + return cls._cache_dtypes[key] except KeyError: u = object.__new__(cls) u._subtype = subtype u._closed = closed - cls._cache[key] = u + cls._cache_dtypes[key] = u return u @property diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index 51a7969162abf..abb29ce66fd34 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -66,7 +66,7 @@ def test_pickle(self, dtype): # clear the cache type(dtype).reset_cache() - assert not len(dtype._cache) + assert not len(dtype._cache_dtypes) # force back to the cache result = tm.round_trip_pickle(dtype) @@ -74,7 +74,7 @@ def test_pickle(self, dtype): # Because PeriodDtype has a cython class as a base class, # it has different pickle semantics, and its cache is re-populated # on un-pickling. - assert not len(dtype._cache) + assert not len(dtype._cache_dtypes) assert result == dtype @@ -791,14 +791,14 @@ def test_basic_dtype(self): def test_caching(self): IntervalDtype.reset_cache() dtype = IntervalDtype("int64", "right") - assert len(IntervalDtype._cache) == 1 + assert len(IntervalDtype._cache_dtypes) == 1 IntervalDtype("interval") - assert len(IntervalDtype._cache) == 2 + assert len(IntervalDtype._cache_dtypes) == 2 IntervalDtype.reset_cache() tm.round_trip_pickle(dtype) - assert len(IntervalDtype._cache) == 0 + assert len(IntervalDtype._cache_dtypes) == 0 def test_not_string(self): # GH30568: though IntervalDtype has object kind, it cannot be string