@@ -745,51 +745,83 @@ def visitModule(self, mod):
745
745
static PyObject *
746
746
ast_richcompare(PyObject *self, PyObject *other, int op)
747
747
{
748
- int i, len;
749
- PyObject *fields, *key, *a = Py_None, *b = Py_None;
748
+ Py_ssize_t i, numfields = 0;
749
+ PyObject *fields, *key = NULL;
750
+
750
751
/* Check operator */
751
752
if ((op != Py_EQ && op != Py_NE) ||
752
- !PyAST_Check(self) ||
753
- !PyAST_Check(other)) {
753
+ !PyAST_Check(self) || !PyAST_Check(other)) {
754
754
Py_RETURN_NOTIMPLEMENTED;
755
755
}
756
+
756
757
/* Compare types */
757
758
if (Py_TYPE(self) != Py_TYPE(other)) {
758
- if (op == Py_EQ)
759
- Py_RETURN_FALSE;
760
- else
761
- Py_RETURN_TRUE;
759
+ Py_RETURN_RICHCOMPARE(Py_TYPE(self), Py_TYPE(other), op);
760
+ }
761
+
762
+ if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), astmodulestate_global->_fields, &fields) < 0) {
763
+ return NULL;
762
764
}
765
+ if (fields) {
766
+ numfields = PySequence_Size(fields);
767
+ if (numfields == -1) {
768
+ goto fail;
769
+ }
770
+ }
771
+
772
+ PyObject *a, *b;
763
773
/* Compare fields */
764
- fields = PyObject_GetAttrString(self, "_fields");
765
- len = PySequence_Size(fields);
766
- for (i = 0; i < len; ++i) {
774
+ for (i = 0; i < numfields; i++) {
767
775
key = PySequence_GetItem(fields, i);
768
- if (PyObject_HasAttr(self, key))
769
- a = PyObject_GetAttr(self, key);
770
- if (PyObject_HasAttr(other, key))
771
- b = PyObject_GetAttr(other, key);
772
- /* Check filed value type */
773
- if (Py_TYPE(a) != Py_TYPE(b)) {
774
- if (op == Py_EQ) {
775
- Py_RETURN_FALSE;
776
- }
776
+ if (!key) {
777
+ goto fail;
777
778
}
778
- if (op == Py_EQ) {
779
- if (!PyObject_RichCompareBool(a, b, Py_EQ)) {
780
- Py_RETURN_FALSE;
781
- }
779
+ if (!PyObject_HasAttr(self, key) || !PyObject_HasAttr(other, key)) {
780
+ Py_DECREF(key);
781
+ goto unsuccessful;
782
782
}
783
- else if (op == Py_NE) {
784
- if (PyObject_RichCompareBool(a, b, Py_NE)) {
785
- Py_RETURN_TRUE;
786
- }
783
+ Py_DECREF(key);
784
+
785
+ a = PyObject_GetAttr(self, key);
786
+ b = PyObject_GetAttr(other, key);
787
+ if (!a || !b) {
788
+ goto unsuccessful;
787
789
}
790
+
791
+ /* Ensure they belong to the same type */
792
+ if (Py_TYPE(a) != Py_TYPE(b)) {
793
+ goto unsuccessful;
794
+ }
795
+
796
+ if (!PyObject_RichCompareBool(a, b, Py_EQ)) {
797
+ goto unsuccessful;
798
+ }
799
+ Py_DECREF(a);
800
+ Py_DECREF(b);
788
801
}
789
- if (op == Py_EQ)
802
+ Py_DECREF(fields);
803
+
804
+ if (op == Py_EQ) {
790
805
Py_RETURN_TRUE;
791
- else
806
+ }
807
+ else {
792
808
Py_RETURN_FALSE;
809
+ }
810
+
811
+ unsuccessful:
812
+ Py_XDECREF(a);
813
+ Py_XDECREF(b);
814
+ Py_DECREF(fields);
815
+ if (op == Py_EQ) {
816
+ Py_RETURN_FALSE;
817
+ }
818
+ else {
819
+ Py_RETURN_TRUE;
820
+ }
821
+
822
+ fail:
823
+ Py_DECREF(fields);
824
+ return NULL;
793
825
}
794
826
795
827
static PyMemberDef ast_type_members[] = {
0 commit comments