Skip to content

Make WeakSet safe against concurrent mutation across threads while iterating #123089

Closed
@kumaraditya303

Description

@kumaraditya303

Bug report

Currently, iterating over a WeakSet is not safe against concurrent additions of weakrefs by other threads. This leads to spurious RuntimeErrors being raised, as the underlying set might be resized while it is being iterated. The existing _IterationGuard doesn't protect against this case.

To solve this, I propose that we stop relying on _IterationGuard and instead iterate over a copy of the underlying set. This makes iteration safe against concurrent mutations from other threads, as copying a set is an atomic operation. The following patch implements this, and all the existing tests pass.

Patch
From 4f81fb16d9259e5b86c40d791a377aad1cff4dfb Mon Sep 17 00:00:00 2001
From: Kumar Aditya <[email protected]>
Date: Sat, 17 Aug 2024 00:57:47 +0530
Subject: [PATCH] avoid _IterationGuard for WeakSet, use copy of the set
 instead

---
 Lib/_weakrefset.py | 53 +++++++++-------------------------------------
 1 file changed, 10 insertions(+), 43 deletions(-)

diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py
index 489eec714e0..2071755d71d 100644
--- a/Lib/_weakrefset.py
+++ b/Lib/_weakrefset.py
@@ -36,41 +36,26 @@ def __exit__(self, e, t, b):
 class WeakSet:
     def __init__(self, data=None):
         self.data = set()
+
         def _remove(item, selfref=ref(self)):
             self = selfref()
             if self is not None:
-                if self._iterating:
-                    self._pending_removals.append(item)
-                else:
-                    self.data.discard(item)
+                self.data.discard(item)
+
         self._remove = _remove
-        # A list of keys to be removed
-        self._pending_removals = []
-        self._iterating = set()
         if data is not None:
             self.update(data)
 
-    def _commit_removals(self):
-        pop = self._pending_removals.pop
-        discard = self.data.discard
-        while True:
-            try:
-                item = pop()
-            except IndexError:
-                return
-            discard(item)
-
     def __iter__(self):
-        with _IterationGuard(self):
-            for itemref in self.data:
-                item = itemref()
-                if item is not None:
-                    # Caveat: the iterator will keep a strong reference to
-                    # `item` until it is resumed or closed.
-                    yield item
+        for itemref in self.data.copy():
+            item = itemref()
+            if item is not None:
+                # Caveat: the iterator will keep a strong reference to
+                # `item` until it is resumed or closed.
+                yield item
 
     def __len__(self):
-        return len(self.data) - len(self._pending_removals)
+        return len(self.data)
 
     def __contains__(self, item):
         try:
@@ -83,21 +68,15 @@ def __reduce__(self):
         return self.__class__, (list(self),), self.__getstate__()
 
     def add(self, item):
-        if self._pending_removals:
-            self._commit_removals()
         self.data.add(ref(item, self._remove))
 
     def clear(self):
-        if self._pending_removals:
-            self._commit_removals()
         self.data.clear()
 
     def copy(self):
         return self.__class__(self)
 
     def pop(self):
-        if self._pending_removals:
-            self._commit_removals()
         while True:
             try:
                 itemref = self.data.pop()
@@ -108,18 +87,12 @@ def pop(self):
                 return item
 
     def remove(self, item):
-        if self._pending_removals:
-            self._commit_removals()
         self.data.remove(ref(item))
 
     def discard(self, item):
-        if self._pending_removals:
-            self._commit_removals()
         self.data.discard(ref(item))
 
     def update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
         for element in other:
             self.add(element)
 
@@ -136,8 +109,6 @@ def difference(self, other):
     def difference_update(self, other):
         self.__isub__(other)
     def __isub__(self, other):
-        if self._pending_removals:
-            self._commit_removals()
         if self is other:
             self.data.clear()
         else:
@@ -151,8 +122,6 @@ def intersection(self, other):
     def intersection_update(self, other):
         self.__iand__(other)
     def __iand__(self, other):
-        if self._pending_removals:
-            self._commit_removals()
         self.data.intersection_update(ref(item) for item in other)
         return self
 
@@ -184,8 +153,6 @@ def symmetric_difference(self, other):
     def symmetric_difference_update(self, other):
         self.__ixor__(other)
     def __ixor__(self, other):
-        if self._pending_removals:
-            self._commit_removals()
         if self is other:
             self.data.clear()
         else:
-- 
2.45.2.windows.1

Linked PRs

Metadata

Metadata

Assignees

No one assigned

    Labels

    3.14 (bugs and security fixes), type-bug (An unexpected behavior, bug, or error)

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions