def compare_bytes(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
    """Build the IR for an equality test between two bytes values.

    Args:
        lhs: Left operand.
        rhs: Right operand.
        op: Source-level operator; '==' selects an equality check, any
            other value is treated as '!='.
        line: Line number attached to the generated ops.

    Returns:
        The value produced by the final ComparisonOp.
    """
    # The C helper reports equality as 1 (0 means "not equal"), so the
    # generated comparison tests the raw result against the constant 1.
    raw_result = self.call_c(bytes_compare, [lhs, rhs], line)
    if op == '==':
        comparison_kind = ComparisonOp.EQ
    else:
        comparison_kind = ComparisonOp.NEQ
    equal_marker = Integer(1, c_int_rprimitive)
    return self.add(ComparisonOp(raw_result, equal_marker, comparison_kind, line))
*sep, PyObject *iter); +int CPyBytes_Compare(PyObject *left, PyObject *right); + + + // Set operations diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index b80fc3102f55..bece7c13c957 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -5,6 +5,31 @@ #include #include "CPy.h" +// Returns -1 on error, 0 on inequality, 1 on equality. +// +// Falls back to PyObject_RichCompareBool. +int CPyBytes_Compare(PyObject *left, PyObject *right) { + if (PyBytes_CheckExact(left) && PyBytes_CheckExact(right)) { + if (left == right) { + return 1; + } + + // Adapted from cpython internal implementation of bytes_compare. + Py_ssize_t len = Py_SIZE(left); + if (Py_SIZE(right) != len) { + return 0; + } + PyBytesObject *left_b = (PyBytesObject *)left; + PyBytesObject *right_b = (PyBytesObject *)right; + if (left_b->ob_sval[0] != right_b->ob_sval[0]) { + return 0; + } + + return memcmp(left_b->ob_sval, right_b->ob_sval, len) == 0; + } + return PyObject_RichCompareBool(left, right, Py_EQ); +} + CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index) { if (CPyTagged_CheckShort(index)) { Py_ssize_t n = CPyTagged_ShortAsSsize_t(index); diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index ea6c7ec0f72b..6ddb5e38111c 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -3,10 +3,11 @@ from mypyc.ir.ops import ERR_MAGIC from mypyc.ir.rtypes import ( object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive, - str_rprimitive, RUnion, int_rprimitive, c_pyssize_t_rprimitive + str_rprimitive, c_int_rprimitive, RUnion, c_pyssize_t_rprimitive, + int_rprimitive, ) from mypyc.primitives.registry import ( - load_address_op, function_op, method_op, binary_op, custom_op + load_address_op, function_op, method_op, binary_op, custom_op, ERR_NEG_INT ) # Get the 'bytes' type object. 
# Primitive used for both bytes '==' and bytes '!='.
# The C function CPyBytes_Compare yields -1/0/1; the error kind is
# ERR_NEG_INT to go with the -1 case.
bytes_compare = custom_op(
    c_function_name='CPyBytes_Compare',
    arg_types=[bytes_rprimitive] * 2,
    return_type=c_int_rprimitive,
    error_kind=ERR_NEG_INT,
)