diff --git a/cyberpandas/ip_array.py b/cyberpandas/ip_array.py index cafb7ea..ab1c85a 100644 --- a/cyberpandas/ip_array.py +++ b/cyberpandas/ip_array.py @@ -286,11 +286,19 @@ def to_bytes(self): # ------------------------------------------------------------------------ def __eq__(self, other): - # TDOO: scalar ipaddress - if not isinstance(other, IPArray): - return NotImplemented - mask = self.isna() | other.isna() - result = self.data == other.data + if isinstance(other, (ipaddress.IPv4Address, ipaddress.IPv6Address)): + other = int(other) + hi, lo = unpack(pack(other)) + other = np.array([(hi, lo)], dtype=self.dtype._record_type) + mask = self.isna() + elif isinstance(other, IPArray): + mask = self.isna() | other.isna() + other = other.data + else: + msg = ("Invalid type comparison. Can't compare IPArray to " + "type '{}'.") + raise TypeError(msg.format(other)) + result = self.data == other result[mask] = False return result diff --git a/tests/ip/test_ip.py b/tests/ip/test_ip.py index 6bb9d5d..fa1c1b8 100644 --- a/tests/ip/test_ip.py +++ b/tests/ip/test_ip.py @@ -112,6 +112,27 @@ def test_equality(): v1.equals("a") +@pytest.mark.parametrize('other', [ + 1, '192.168.1.1', b'1' +]) +def test_ops_other(other): + arr = ip.IPArray([1, 2, 3]) + + with pytest.raises(TypeError): + arr == other + + +@pytest.mark.parametrize('other', [ + ipaddress.IPv4Address(1), + ipaddress.IPv6Address(1), +]) +def test_equality_ipaddress(other): + arr = ip.IPArray([0, 1, 2**64 + 1]) + result = arr == other + expected = np.array([False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize('op', [ operator.lt, operator.le,