From 2a3f35e369ef421bccd85a0b70800900582c1a31 Mon Sep 17 00:00:00 2001 From: Jared Hance Date: Wed, 4 Aug 2021 11:38:40 -0700 Subject: [PATCH 1/3] Implement bytes equality optimizations --- mypyc/irbuild/ll_builder.py | 14 +++++++++++++- mypyc/lib-rt/bytes_ops.c | 24 ++++++++++++++++++++++++ mypyc/primitives/bytes_ops.py | 11 +++++++++-- mypyc/test-data/irbuild-bytes.test | 25 +++++++++++++++++++++++++ mypyc/test-data/run-bytes.test | 2 ++ 5 files changed, 73 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index e5d46bf5edd4..c39e2b3ae931 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -33,7 +33,8 @@ is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject, none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive, pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive, - object_pointer_rprimitive, c_size_t_rprimitive, dict_rprimitive, bytes_rprimitive + object_pointer_rprimitive, c_size_t_rprimitive, dict_rprimitive, bytes_rprimitive, + is_bytes_rprimitive ) from mypyc.ir.func_ir import FuncDecl, FuncSignature from mypyc.ir.class_ir import ClassIR, all_concrete_classes @@ -45,6 +46,7 @@ method_call_ops, CFunctionDescription, function_ops, binary_ops, unary_ops, ERR_NEG_INT ) +from mypyc.primitives.bytes_ops import bytes_compare from mypyc.primitives.list_ops import ( list_extend_op, new_list_op, list_build_op ) @@ -855,8 +857,12 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: # Special case various ops if op in ('is', 'is not'): return self.translate_is_op(lreg, rreg, op, line) + # TODO: modify 'str' to use same interface as 'compare_bytes' as it avoids + # call to PyErr_Occurred() if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and op in ('==', '!='): return self.compare_strings(lreg, rreg, op, line) + if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ('==', '!='): + return self.compare_bytes(lreg, rreg, op, line) if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping: return self.compare_tagged(lreg, rreg, op, line) if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in ( @@ -997,6 +1003,12 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: return self.add(ComparisonOp(compare_result, Integer(0, c_int_rprimitive), op_type, line)) + def compare_bytes(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: + compare_result = self.call_c(bytes_compare, [lhs, rhs], line) + op_type = ComparisonOp.EQ if op == '==' else ComparisonOp.NEQ + return self.add(ComparisonOp(compare_result, + Integer(1, c_int_rprimitive), op_type, line)) + def compare_tuples(self, lhs: Value, rhs: Value, diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 99771bdf926e..496d3e14f870 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -4,3 +4,27 @@ #include #include "CPy.h" + +// Returns -1 on error, 0 on inequality, 1 on equality. +// +// Falls back to PyObject_RichCompareBool. +int CPyBytes_Compare(PyObject *left, PyObject *right) { + if (PyBytes_CheckExact(left) && PyBytes_CheckExact(right)) { + if (left == right) { + return 1; + } + + Py_ssize_t len = Py_SIZE(left); + if (Py_SIZE(right) != len) { + return 0; + } + PyBytesObject *left_b = (PyBytesObject *)left; + PyBytesObject *right_b = (PyBytesObject *)right; + if (left_b->ob_sval[0] != right_b->ob_sval[0]) { + return 0; + } + + return memcmp(left_b->ob_sval, right_b->ob_sval, len) == 0; + } + return PyObject_RichCompareBool(left, right, Py_EQ); +} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index a5963d2ad8fa..2d07a00d6ec1 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -3,9 +3,9 @@ from mypyc.ir.ops import ERR_MAGIC from mypyc.ir.rtypes import ( object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive, - str_rprimitive, RUnion + str_rprimitive, c_int_rprimitive, RUnion ) -from mypyc.primitives.registry import load_address_op, function_op +from mypyc.primitives.registry import load_address_op, function_op, custom_op, ERR_NEG_INT # Get the 'bytes' type object. @@ -29,3 +29,10 @@ return_type=bytes_rprimitive, c_function_name='PyByteArray_FromObject', error_kind=ERR_MAGIC) + +# bytes ==/!= (return -1/0/1) +bytes_compare = custom_op( + arg_types=[bytes_rprimitive, bytes_rprimitive], + return_type=c_int_rprimitive, + c_function_name='CPyBytes_Compare', + error_kind=ERR_NEG_INT) diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 479f97872f5e..cbb1ba130cc7 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -62,3 +62,28 @@ L0: c = r6 return 1 +[case testBytesEquality] +def eq(x: bytes, y: bytes) -> bool: + return x == y + +def neq(x: bytes, y: bytes) -> bool: + return x != y +[out] +def eq(x, y): + x, y :: bytes + r0 :: int32 + r1, r2 :: bit +L0: + r0 = CPyBytes_Compare(x, y) + r1 = r0 >= 0 :: signed + r2 = r0 == 1 + return r2 +def neq(x, y): + x, y :: bytes + r0 :: int32 + r1, r2 :: bit +L0: + r0 = CPyBytes_Compare(x, y) + r1 = r0 >= 0 :: signed + r2 = r0 != 1 + return r2 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index 17ebe6085fec..982cbec6680e 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -17,6 +17,8 @@ assert f(b'123') == b'123' assert f(b'\x07 \x0b " \t \x7f \xf0') == b'\x07 \x0b " \t \x7f \xf0' assert eq(b'123', b'123') assert not eq(b'123', b'1234') +assert not eq(b'123', b'124') +assert not eq(b'123', b'223') assert neq(b'123', b'1234') try: f('x') From d777ac4a1c5ef9c0b6ff23011f22bf50a53cfec6 Mon Sep 17 00:00:00 2001 From: Jared Hance Date: Wed, 4 Aug 2021 11:43:47 -0700 Subject: [PATCH 2/3] Add comment --- mypyc/lib-rt/bytes_ops.c | 1 + 1 file changed, 1 insertion(+) diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 496d3e14f870..35ba12e6b7eb 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -14,6 +14,7 @@ int CPyBytes_Compare(PyObject *left, PyObject *right) { return 1; } + // Adapted from cpython internal implementation of bytes_compare. Py_ssize_t len = Py_SIZE(left); if (Py_SIZE(right) != len) { return 0; From 256840e564d443d791abb49c0747245fb9e357af Mon Sep 17 00:00:00 2001 From: Jared Hance Date: Mon, 9 Aug 2021 12:41:29 -0700 Subject: [PATCH 3/3] Fix header --- mypyc/lib-rt/CPy.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index a5603765727c..975b2007f698 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -401,6 +401,10 @@ Py_ssize_t CPyStr_Size_size_t(PyObject *str); +int CPyBytes_Compare(PyObject *left, PyObject *right); + + + // Set operations