diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 4d05a5a799eac..baa056550624f 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -65,35 +65,71 @@ def data_for_grouping(): return DecimalArray([b, b, na, na, a, a, b, c]) -class TestDtype(base.BaseDtypeTests): - pass +class TestDecimalArray(base.ExtensionTests): + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | None: + return None + def _supports_reduction(self, obj, op_name: str) -> bool: + return True -class TestInterface(base.BaseInterfaceTests): - pass + def check_reduce(self, s, op_name, skipna): + if op_name == "count": + return super().check_reduce(s, op_name, skipna) + else: + result = getattr(s, op_name)(skipna=skipna) + expected = getattr(np.asarray(s), op_name)() + tm.assert_almost_equal(result, expected) + def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request): + if all_numeric_reductions in ["kurt", "skew", "sem", "median"]: + mark = pytest.mark.xfail(raises=NotImplementedError) + request.node.add_marker(mark) + super().test_reduce_series_numeric(data, all_numeric_reductions, skipna) -class TestConstructors(base.BaseConstructorsTests): - pass + def test_reduce_frame(self, data, all_numeric_reductions, skipna, request): + op_name = all_numeric_reductions + if op_name in ["skew", "median"]: + mark = pytest.mark.xfail(raises=NotImplementedError) + request.node.add_marker(mark) + return super().test_reduce_frame(data, all_numeric_reductions, skipna) -class TestReshaping(base.BaseReshapingTests): - pass + def test_compare_scalar(self, data, comparison_op): + ser = pd.Series(data) + self._compare_other(ser, data, comparison_op, 0.5) + def test_compare_array(self, data, comparison_op): + ser = pd.Series(data) -class TestGetitem(base.BaseGetitemTests): - def test_take_na_value_other_decimal(self): - arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")]) - result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0")) - expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")]) - tm.assert_extension_array_equal(result, expected) + alter = np.random.default_rng(2).choice([-1, 0, 1], len(data)) + # Randomly double, halve or keep same value + other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter] + self._compare_other(ser, data, comparison_op, other) + def test_arith_series_with_array(self, data, all_arithmetic_operators): + op_name = all_arithmetic_operators + ser = pd.Series(data) + + context = decimal.getcontext() + divbyzerotrap = context.traps[decimal.DivisionByZero] + invalidoptrap = context.traps[decimal.InvalidOperation] + context.traps[decimal.DivisionByZero] = 0 + context.traps[decimal.InvalidOperation] = 0 -class TestIndex(base.BaseIndexTests): - pass + # Decimal supports ops with int, but not float + other = pd.Series([int(d * 100) for d in data]) + self.check_opname(ser, op_name, other) + + if "mod" not in op_name: + self.check_opname(ser, op_name, ser * 2) + self.check_opname(ser, op_name, 0) + self.check_opname(ser, op_name, 5) + context.traps[decimal.DivisionByZero] = divbyzerotrap + context.traps[decimal.InvalidOperation] = invalidoptrap -class TestMissing(base.BaseMissingTests): def test_fillna_frame(self, data_missing): msg = "ExtensionArray.fillna added a 'copy' keyword" with tm.assert_produces_warning( @@ -141,59 +177,6 @@ def test_fillna_series_method(self, data_missing, fillna_method): ): super().test_fillna_series_method(data_missing, fillna_method) - -class Reduce: - def _supports_reduction(self, obj, op_name: str) -> bool: - return True - - def check_reduce(self, s, op_name, skipna): - if op_name == "count": - return super().check_reduce(s, op_name, skipna) - else: - result = getattr(s, op_name)(skipna=skipna) - expected = getattr(np.asarray(s), op_name)() - tm.assert_almost_equal(result, expected) - - def test_reduction_without_keepdims(self): - # GH52788 - # test _reduce without keepdims - - class DecimalArray2(DecimalArray): - def _reduce(self, name: str, *, skipna: bool = True, **kwargs): - # no keepdims in signature - return super()._reduce(name, skipna=skipna) - - arr = DecimalArray2([decimal.Decimal(2) for _ in range(100)]) - - ser = pd.Series(arr) - result = ser.agg("sum") - expected = decimal.Decimal(200) - assert result == expected - - df = pd.DataFrame({"a": arr, "b": arr}) - with tm.assert_produces_warning(FutureWarning): - result = df.agg("sum") - expected = pd.Series({"a": 200, "b": 200}, dtype=object) - tm.assert_series_equal(result, expected) - - -class TestReduce(Reduce, base.BaseReduceTests): - def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request): - if all_numeric_reductions in ["kurt", "skew", "sem", "median"]: - mark = pytest.mark.xfail(raises=NotImplementedError) - request.node.add_marker(mark) - super().test_reduce_series_numeric(data, all_numeric_reductions, skipna) - - def test_reduce_frame(self, data, all_numeric_reductions, skipna, request): - op_name = all_numeric_reductions - if op_name in ["skew", "median"]: - mark = pytest.mark.xfail(raises=NotImplementedError) - request.node.add_marker(mark) - - return super().test_reduce_frame(data, all_numeric_reductions, skipna) - - -class TestMethods(base.BaseMethodsTests): def test_fillna_copy_frame(self, data_missing, using_copy_on_write): warn = FutureWarning if not using_copy_on_write else None msg = "ExtensionArray.fillna added a 'copy' keyword" @@ -226,20 +209,6 @@ def test_value_counts(self, all_data, dropna, request): tm.assert_series_equal(result, expected) - -class TestCasting(base.BaseCastingTests): - pass - - -class TestGroupby(base.BaseGroupbyTests): - pass - - -class TestSetitem(base.BaseSetitemTests): - pass - - -class TestPrinting(base.BasePrintingTests): def test_series_repr(self, data): # Overriding this base test to explicitly test that # the custom _formatter is used @@ -247,6 +216,24 @@ def test_series_repr(self, data): assert data.dtype.name in repr(ser) assert "Decimal: " in repr(ser) + @pytest.mark.xfail( + reason="Looks like the test (incorrectly) implicitly assumes int/bool dtype" + ) + def test_invert(self, data): + super().test_invert(data) + + @pytest.mark.xfail(reason="Inconsistent array-vs-scalar behavior") + @pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs]) + def test_unary_ufunc_dunder_equivalence(self, data, ufunc): + super().test_unary_ufunc_dunder_equivalence(data, ufunc) + + +def test_take_na_value_other_decimal(): + arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")]) + result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0")) + expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")]) + tm.assert_extension_array_equal(result, expected) + def test_series_constructor_coerce_data_to_extension_dtype(): dtype = DecimalDtype() @@ -305,53 +292,6 @@ def test_astype_dispatches(frame): assert result.dtype.context.prec == ctx.prec -class TestArithmeticOps(base.BaseArithmeticOpsTests): - series_scalar_exc = None - frame_scalar_exc = None - series_array_exc = None - - def _get_expected_exception( - self, op_name: str, obj, other - ) -> type[Exception] | None: - return None - - def test_arith_series_with_array(self, data, all_arithmetic_operators): - op_name = all_arithmetic_operators - s = pd.Series(data) - - context = decimal.getcontext() - divbyzerotrap = context.traps[decimal.DivisionByZero] - invalidoptrap = context.traps[decimal.InvalidOperation] - context.traps[decimal.DivisionByZero] = 0 - context.traps[decimal.InvalidOperation] = 0 - - # Decimal supports ops with int, but not float - other = pd.Series([int(d * 100) for d in data]) - self.check_opname(s, op_name, other) - - if "mod" not in op_name: - self.check_opname(s, op_name, s * 2) - - self.check_opname(s, op_name, 0) - self.check_opname(s, op_name, 5) - context.traps[decimal.DivisionByZero] = divbyzerotrap - context.traps[decimal.InvalidOperation] = invalidoptrap - - -class TestComparisonOps(base.BaseComparisonOpsTests): - def test_compare_scalar(self, data, comparison_op): - s = pd.Series(data) - self._compare_other(s, data, comparison_op, 0.5) - - def test_compare_array(self, data, comparison_op): - s = pd.Series(data) - - alter = np.random.default_rng(2).choice([-1, 0, 1], len(data)) - # Randomly double, halve or keep same value - other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter] - self._compare_other(s, data, comparison_op, other) - - class DecimalArrayWithoutFromSequence(DecimalArray): """Helper class for testing error handling in _from_sequence."""