diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 2f59527cb8a450..05becb64a3f1b2 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -6,6 +6,8 @@ import weakref from test import support +from test.support import findfile + def to_tuple(t): if t is None or isinstance(t, (str, int, complex)): @@ -432,6 +434,80 @@ def test_empty_yield_from(self): self.assertIn("field value is required", str(cm.exception)) +class ASTCompareTest(unittest.TestCase): + def setUp(self): + import imp + imp.reload(ast) + + def test_normal_compare(self): + self.assertEqual(ast.parse('x = 10'), ast.parse('x = 10')) + self.assertNotEqual(ast.parse('x = 10'), ast.parse('')) + self.assertNotEqual(ast.parse('x = 10'), ast.parse('x')) + self.assertNotEqual(ast.parse('x = 10;y = 20'), ast.parse('class C:pass')) + + def test_literals_compare(self): + self.assertEqual(ast.Num(), ast.Num()) + self.assertEqual(ast.Num(-20), ast.Num(-20)) + self.assertEqual(ast.Num(10), ast.Num(10)) + self.assertEqual(ast.Num(2048), ast.Num(2048)) + self.assertEqual(ast.Str(), ast.Str()) + self.assertEqual(ast.Str("ABCD"), ast.Str("ABCD")) + self.assertEqual(ast.Str("中文字"), ast.Str("中文字")) + + self.assertNotEqual(ast.Num(10), ast.Num(20)) + self.assertNotEqual(ast.Num(-10), ast.Num(10)) + self.assertNotEqual(ast.Str("AAAA"), ast.Str("BBBB")) + self.assertNotEqual(ast.Str("一二三"), ast.Str("中文字")) + + self.assertNotEqual(ast.Num(10), ast.Num()) + self.assertNotEqual(ast.Str("AB"), ast.Str()) + + def test_operator_compare(self): + self.assertEqual(ast.Add(), ast.Add()) + self.assertEqual(ast.Sub(), ast.Sub()) + + self.assertNotEqual(ast.Add(), ast.Sub()) + self.assertNotEqual(ast.Add(), ast.Num()) + + def test_complex_ast(self): + fps = [findfile('test_asyncgen.py'), + findfile('test_generators.py'), + findfile('test_unicode.py')] + + for fp in fps: + with open(fp) as f: + try: + source = f.read() + except UnicodeDecodeError: + continue + + a = ast.parse(source) + b = ast.parse(source) + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) + self.assertFalse(a != b) + + def test_exec_compare(self): + for source in exec_tests: + a = ast.parse(source, mode='exec') + b = ast.parse(source, mode='exec') + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) + self.assertFalse(a != b) + + def test_single_compare(self): + for source in single_tests: + a = ast.parse(source, mode='single') + b = ast.parse(source, mode='single') + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) + self.assertFalse(a != b) + + def test_eval_compare(self): + for source in eval_tests: + a = ast.parse(source, mode='eval') + b = ast.parse(source, mode='eval') + self.assertEqual(a, b, "%s != %s" % (ast.dump(a), ast.dump(b))) + self.assertFalse(a != b) + + class ASTHelpers_Test(unittest.TestCase): def test_parse(self): diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 2759b2fe9c4b33..b5bf1283054b7c 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -607,6 +607,63 @@ ast_type_reduce(PyObject *self, PyObject *unused) return Py_BuildValue("O()", Py_TYPE(self)); } +static PyObject * +ast_richcompare(PyObject *self, PyObject *other, int op) +{ + int i, len; + PyObject *fields, *key, *a = Py_None, *b = Py_None; + + /* Check operator */ + if ((op != Py_EQ && op != Py_NE) || + !PyAST_Check(self) || + !PyAST_Check(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + + /* Compare types */ + if (Py_TYPE(self) != Py_TYPE(other)) { + if (op == Py_EQ) + Py_RETURN_FALSE; + else + Py_RETURN_TRUE; + } + + /* Compare fields */ + fields = PyObject_GetAttrString(self, "_fields"); + len = PySequence_Size(fields); + for (i = 0; i < len; ++i) { + key = PySequence_GetItem(fields, i); + + if (PyObject_HasAttr(self, key)) + a = PyObject_GetAttr(self, key); + if (PyObject_HasAttr(other, key)) + b = PyObject_GetAttr(other, key); + + + if (Py_TYPE(a) != Py_TYPE(b)) { + if (op == Py_EQ) { + Py_RETURN_FALSE; + } + } + + if (op == Py_EQ) { + if (!PyObject_RichCompareBool(a, b, Py_EQ)) { + Py_RETURN_FALSE; + } + } + else if (op == Py_NE) { + if (PyObject_RichCompareBool(a, b, Py_NE)) { + Py_RETURN_TRUE; + } + } + } + + if (op == Py_EQ) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; +} + static PyMethodDef ast_type_methods[] = { {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, {NULL} @@ -641,7 +698,7 @@ static PyTypeObject AST_type = { 0, /* tp_doc */ (traverseproc)ast_traverse, /* tp_traverse */ (inquiry)ast_clear, /* tp_clear */ - 0, /* tp_richcompare */ + ast_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */