From f3457b49f140d00faf412e5a53c52f42c3dfcd0e Mon Sep 17 00:00:00 2001 From: Louie Lu Date: Mon, 1 May 2017 22:46:54 +0800 Subject: [PATCH 1/7] bpo-15987: Add AST nodes eq & ne compare methods --- Lib/ast.py | 37 ++++++++++++++++++++++++++++ Lib/test/test_ast.py | 57 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/Lib/ast.py b/Lib/ast.py index 070c2bee7f9dee..7174a4700b11e6 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -24,9 +24,46 @@ :copyright: Copyright 2008 by Armin Ronacher. :license: Python License. """ +import _ast from _ast import * +# Add EQ and NE compare to _ast Nodes +_NODES = [vars(_ast)[k] for k in filter( + lambda x: not x.startswith('_') and x not in ('AST' 'PyCF_ONLY_AST'), vars(_ast))] + + +def ast_eq_compare(a, b): + if type(a) != type(b): + return False + if type(a) not in _NODES: + return a == b + + ret = True + for field in a._fields: + af = vars(a)[field] + bf = vars(b)[field] + if isinstance(af, list): + if len(af) != len(bf): + return False + for i, j in zip(af, bf): + ret &= ast_eq_compare(i, j) + elif type(af) in _NODES: + ret &= ast_eq_compare(af, bf) + elif af != bf: + return False + return ret + + +def ast_ne_compare(a, b): + return not ast_eq_compare(a, b) + + +for n in _NODES: + n.__eq__ = ast_eq_compare + n.__ne__ = ast_ne_compare + + def parse(source, filename='', mode='exec'): """ Parse the source into an AST node. diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 2f59527cb8a450..cf9c5de883a1dd 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -6,6 +6,8 @@ import weakref from test import support +from test.support import findfile + def to_tuple(t): if t is None or isinstance(t, (str, int, complex)): @@ -432,6 +434,61 @@ def test_empty_yield_from(self): self.assertIn("field value is required", str(cm.exception)) +class ASTCompareTest(unittest.TestCase): + def test_literals_compare(self): + self.assertEqual(ast.Num(-20), ast.Num(-20)) + self.assertEqual(ast.Num(10), ast.Num(10)) + self.assertEqual(ast.Num(2048), ast.Num(2048)) + self.assertEqual(ast.Str("ABCD"), ast.Str("ABCD")) + self.assertEqual(ast.Str("中文字"), ast.Str("中文字")) + + self.assertNotEqual(ast.Num(10), ast.Num(20)) + self.assertNotEqual(ast.Num(-10), ast.Num(10)) + self.assertNotEqual(ast.Str("AAAA"), ast.Str("BBBB")) + self.assertNotEqual(ast.Str("一二三"), ast.Str("中文字")) + + def test_operator_compare(self): + self.assertEqual(ast.Add(), ast.Add()) + self.assertEqual(ast.Sub(), ast.Sub()) + + self.assertNotEqual(ast.Add(), ast.Sub()) + self.assertNotEqual(ast.Add(), ast.Num()) + + def test_complex_ast(self): + fps = [findfile('test_asyncgen.py'), + findfile('test_generators.py'), + findfile('test_unicode.py')] + + for fp in fps: + with open(fp) as f: + source = f.read() + a = ast.parse(source) + b = ast.parse(source) + self.assertEqual(a, b) + self.assertFalse(a != b) + + def test_exec_compare(self): + for source in exec_tests: + a = ast.parse(source, mode='exec') + b = ast.parse(source, mode='exec') + self.assertEqual(a, b) + self.assertFalse(a != b) + + def test_single_compare(self): + for source in single_tests: + a = ast.parse(source, mode='single') + b = ast.parse(source, mode='single') + self.assertEqual(a, b) + self.assertFalse(a != b) + + def test_eval_compare(self): + for source in eval_tests: + a = ast.parse(source, mode='eval') + b = ast.parse(source, mode='eval') + self.assertEqual(a, b) + self.assertFalse(a != b) + + class ASTHelpers_Test(unittest.TestCase): def test_parse(self): From 84241a383e3f721ee178aa82844d12d2a5829b54 Mon Sep 17 00:00:00 2001 From: Louie Lu Date: Mon, 1 May 2017 23:01:57 +0800 Subject: [PATCH 2/7] Skip complex ast compare when get unicode decode error --- Lib/test/test_ast.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index cf9c5de883a1dd..6bb608b4bf025f 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -461,7 +461,11 @@ def test_complex_ast(self): for fp in fps: with open(fp) as f: - source = f.read() + try: + source = f.read() + except UnicodeDecodeError: + continue + a = ast.parse(source) b = ast.parse(source) self.assertEqual(a, b) From b47c6fd9b20181d7464ae81aaa1d95181dca1164 Mon Sep 17 00:00:00 2001 From: Louie Lu Date: Tue, 2 May 2017 12:22:30 +0800 Subject: [PATCH 3/7] Add base type richcompare methods --- Lib/ast.py | 37 -------------------------------- Lib/test/test_ast.py | 6 ++++++ Python/Python-ast.c | 51 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 38 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index 7174a4700b11e6..070c2bee7f9dee 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -24,46 +24,9 @@ :copyright: Copyright 2008 by Armin Ronacher. :license: Python License. """ -import _ast from _ast import * -# Add EQ and NE compare to _ast Nodes -_NODES = [vars(_ast)[k] for k in filter( - lambda x: not x.startswith('_') and x not in ('AST' 'PyCF_ONLY_AST'), vars(_ast))] - - -def ast_eq_compare(a, b): - if type(a) != type(b): - return False - if type(a) not in _NODES: - return a == b - - ret = True - for field in a._fields: - af = vars(a)[field] - bf = vars(b)[field] - if isinstance(af, list): - if len(af) != len(bf): - return False - for i, j in zip(af, bf): - ret &= ast_eq_compare(i, j) - elif type(af) in _NODES: - ret &= ast_eq_compare(af, bf) - elif af != bf: - return False - return ret - - -def ast_ne_compare(a, b): - return not ast_eq_compare(a, b) - - -for n in _NODES: - n.__eq__ = ast_eq_compare - n.__ne__ = ast_ne_compare - - def parse(source, filename='', mode='exec'): """ Parse the source into an AST node. diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 6bb608b4bf025f..4782fa1aaa78ab 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -435,6 +435,12 @@ def test_empty_yield_from(self): class ASTCompareTest(unittest.TestCase): + def test_normal_compare(self): + self.assertEqual(ast.parse('x = 10'), ast.parse('x = 10')) + self.assertNotEqual(ast.parse('x = 10'), ast.parse('')) + self.assertNotEqual(ast.parse('x = 10'), ast.parse('x')) + self.assertNotEqual(ast.parse('x = 10;y = 20'), ast.parse('class C:pass')) + def test_literals_compare(self): self.assertEqual(ast.Num(-20), ast.Num(-20)) self.assertEqual(ast.Num(10), ast.Num(10)) diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 2759b2fe9c4b33..4ce256a1dc5619 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -617,6 +617,55 @@ static PyGetSetDef ast_type_getsets[] = { {NULL} }; +static PyObject * +ast_richcompare(PyObject *self, PyObject *other, int op) +{ + int i, len; + PyObject *fields, *key, *a, *b; + + /* Check operator */ + if ((op != Py_EQ && op != Py_NE) || + !PyAST_Check(self) || + !PyAST_Check(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + + if (Py_TYPE(self) != Py_TYPE(other)) { + if (op == Py_EQ) + Py_RETURN_FALSE; + else + Py_RETURN_TRUE; + } + + fields = PyObject_GetAttrString(self, "_fields"); + len = PySequence_Size(fields); + for (i = 0; i < len; ++i) { + key = PySequence_GetItem(fields, i); + a = PyObject_GetAttr(self, key); + b = PyObject_GetAttr(other, key); + if (Py_TYPE(a) != Py_TYPE(b)) { + if (op == Py_EQ) + Py_RETURN_FALSE; + } + + if (op == Py_EQ) { + if (!PyObject_RichCompareBool(a, b, Py_EQ)) { + Py_RETURN_FALSE; + } + } + else if (op == Py_NE) { + if (PyObject_RichCompareBool(a, b, Py_NE)) { + Py_RETURN_TRUE; + } + } + } + + if (op == Py_EQ) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; +} + static PyTypeObject AST_type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) "_ast.AST", @@ -641,7 +690,7 @@ static PyTypeObject AST_type = { 0, /* tp_doc */ (traverseproc)ast_traverse, /* tp_traverse */ (inquiry)ast_clear, /* tp_clear */ - 0, /* tp_richcompare */ + ast_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ From f2c3655c1991040ede6ff0c7f8b63a34b7139efd Mon Sep 17 00:00:00 2001 From: Louie Lu Date: Tue, 2 May 2017 13:41:58 +0800 Subject: [PATCH 4/7] Add comments --- Python/Python-ast.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 4ce256a1dc5619..2512823085862f 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -630,6 +630,7 @@ ast_richcompare(PyObject *self, PyObject *other, int op) Py_RETURN_NOTIMPLEMENTED; } + /* Compare types */ if (Py_TYPE(self) != Py_TYPE(other)) { if (op == Py_EQ) Py_RETURN_FALSE; @@ -637,6 +638,7 @@ ast_richcompare(PyObject *self, PyObject *other, int op) Py_RETURN_TRUE; } + /* Compare fields */ fields = PyObject_GetAttrString(self, "_fields"); len = PySequence_Size(fields); for (i = 0; i < len; ++i) { From aa92d325e42e23e5009118709b35ac997971e6e8 Mon Sep 17 00:00:00 2001 From: Louie Lu Date: Tue, 2 May 2017 14:31:59 +0800 Subject: [PATCH 5/7] Add corner case and fix richcompare --- Lib/test/test_ast.py | 5 +++++ Python/Python-ast.c | 34 ++++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 4782fa1aaa78ab..3026da7085f60f 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -442,9 +442,11 @@ def test_normal_compare(self): self.assertNotEqual(ast.parse('x = 10;y = 20'), ast.parse('class C:pass')) def test_literals_compare(self): + self.assertEqual(ast.Num(), ast.Num()) self.assertEqual(ast.Num(-20), ast.Num(-20)) self.assertEqual(ast.Num(10), ast.Num(10)) self.assertEqual(ast.Num(2048), ast.Num(2048)) + self.assertEqual(ast.Str(), ast.Str()) self.assertEqual(ast.Str("ABCD"), ast.Str("ABCD")) self.assertEqual(ast.Str("中文字"), ast.Str("中文字")) @@ -453,6 +455,9 @@ def test_literals_compare(self): self.assertNotEqual(ast.Str("AAAA"), ast.Str("BBBB")) self.assertNotEqual(ast.Str("一二三"), ast.Str("中文字")) + self.assertNotEqual(ast.Num(10), ast.Num()) + self.assertNotEqual(ast.Str("AB"), ast.Str()) + def test_operator_compare(self): self.assertEqual(ast.Add(), ast.Add()) self.assertEqual(ast.Sub(), ast.Sub()) diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 2512823085862f..b5bf1283054b7c 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -607,21 +607,11 @@ ast_type_reduce(PyObject *self, PyObject *unused) return Py_BuildValue("O()", Py_TYPE(self)); } -static PyMethodDef ast_type_methods[] = { - {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, - {NULL} -}; - -static PyGetSetDef ast_type_getsets[] = { - {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict}, - {NULL} -}; - static PyObject * ast_richcompare(PyObject *self, PyObject *other, int op) { int i, len; - PyObject *fields, *key, *a, *b; + PyObject *fields, *key, *a = Py_None, *b = Py_None; /* Check operator */ if ((op != Py_EQ && op != Py_NE) || @@ -643,11 +633,17 @@ ast_richcompare(PyObject *self, PyObject *other, int op) len = PySequence_Size(fields); for (i = 0; i < len; ++i) { key = PySequence_GetItem(fields, i); - a = PyObject_GetAttr(self, key); - b = PyObject_GetAttr(other, key); + + if (PyObject_HasAttr(self, key)) + a = PyObject_GetAttr(self, key); + if (PyObject_HasAttr(other, key)) + b = PyObject_GetAttr(other, key); + + if (Py_TYPE(a) != Py_TYPE(b)) { - if (op == Py_EQ) + if (op == Py_EQ) { Py_RETURN_FALSE; + } } if (op == Py_EQ) { @@ -668,6 +664,16 @@ ast_richcompare(PyObject *self, PyObject *other, int op) Py_RETURN_FALSE; } +static PyMethodDef ast_type_methods[] = { + {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, + {NULL} +}; + +static PyGetSetDef ast_type_getsets[] = { + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict}, + {NULL} +}; + static PyTypeObject AST_type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) "_ast.AST", From 886fc478ca14f5fda770107d70f82b4b8203c9f2 Mon Sep 17 00:00:00 2001 From: Louie Lu Date: Tue, 2 May 2017 15:32:41 +0800 Subject: [PATCH 6/7] Add more descriptive assertEqual for debugging --- Lib/test/test_ast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 3026da7085f60f..6b65e66a84d3bd 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -479,28 +479,28 @@ def test_complex_ast(self): a = ast.parse(source) b = ast.parse(source) - self.assertEqual(a, b) + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) self.assertFalse(a != b) def test_exec_compare(self): for source in exec_tests: a = ast.parse(source, mode='exec') b = ast.parse(source, mode='exec') - self.assertEqual(a, b) + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) self.assertFalse(a != b) def test_single_compare(self): for source in single_tests: a = ast.parse(source, mode='single') b = ast.parse(source, mode='single') - self.assertEqual(a, b) + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) self.assertFalse(a != b) def test_eval_compare(self): for source in eval_tests: a = ast.parse(source, mode='eval') b = ast.parse(source, mode='eval') - self.assertEqual(a, b) + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) self.assertFalse(a != b) From 5ec28b25a64b2395b484f31af54b2491e6165fe4 Mon Sep 17 00:00:00 2001 From: Louie Lu Date: Tue, 2 May 2017 15:52:34 +0800 Subject: [PATCH 7/7] Debugging with travis --- Lib/test/test_ast.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 6b65e66a84d3bd..05becb64a3f1b2 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -435,6 +435,10 @@ def test_empty_yield_from(self): class ASTCompareTest(unittest.TestCase): + def setUp(self): + import imp + imp.reload(ast) + def test_normal_compare(self): self.assertEqual(ast.parse('x = 10'), ast.parse('x = 10')) self.assertNotEqual(ast.parse('x = 10'), ast.parse(''))