diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index a262a45feb9c5..2ba0711de98f9 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -27,6 +27,7 @@ StorageExtensionDtype, register_extension_dtype, ) +from pandas.core.dtypes.dtypes import CategoricalDtypeType if not pa_version_under7p0: import pyarrow as pa @@ -106,7 +107,7 @@ def type(self): return int elif pa.types.is_floating(pa_type): return float - elif pa.types.is_string(pa_type): + elif pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type): return str elif ( pa.types.is_binary(pa_type) @@ -132,6 +133,14 @@ def type(self): return time elif pa.types.is_decimal(pa_type): return Decimal + elif pa.types.is_dictionary(pa_type): + # TODO: Potentially change this & CategoricalDtype.type to + # something more representative of the scalar + return CategoricalDtypeType + elif pa.types.is_list(pa_type) or pa.types.is_large_list(pa_type): + return list + elif pa.types.is_map(pa_type): + return dict elif pa.types.is_null(pa_type): # TODO: None? pd.NA? pa.null? return type(pa_type) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 712fccae83cfe..22920e077123e 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -39,6 +39,7 @@ from pandas.errors import PerformanceWarning from pandas.core.dtypes.common import is_any_int_dtype +from pandas.core.dtypes.dtypes import CategoricalDtypeType import pandas as pd import pandas._testing as tm @@ -1530,9 +1531,23 @@ def test_mode_dropna_false_mode_na(data): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize("arrow_dtype", [pa.binary(), pa.binary(16), pa.large_binary()]) -def test_arrow_dtype_type(arrow_dtype): - assert ArrowDtype(arrow_dtype).type == bytes +@pytest.mark.parametrize( + "arrow_dtype, expected_type", + [ + [pa.binary(), bytes], + [pa.binary(16), bytes], + [pa.large_binary(), bytes], + [pa.large_string(), str], + [pa.list_(pa.int64()), list], + [pa.large_list(pa.int64()), list], + [pa.map_(pa.string(), pa.int64()), dict], + [pa.dictionary(pa.int64(), pa.int64()), CategoricalDtypeType], + ], +) +def test_arrow_dtype_type(arrow_dtype, expected_type): + # GH 51845 + # TODO: Redundant with test_getitem_scalar once arrow_dtype exists in data fixture + assert ArrowDtype(arrow_dtype).type == expected_type def test_is_bool_dtype(): @@ -1925,7 +1940,7 @@ def test_str_get(i, exp): @pytest.mark.xfail( reason="TODO: StringMethods._validate should support Arrow list types", - raises=NotImplementedError, + raises=AttributeError, ) def test_str_join(): ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None])))