Skip to content

Commit 3766f18

Browse files
authored
bpo-35378: Fix multiprocessing.Pool references (GH-11627)
Changes in this commit: 1. Use a _strong_ reference between the Pool and associated iterators 2. Rework PR #8450 to eliminate a cycle in the Pool. There is no test in this commit because any test that automatically tests this behaviour needs to eliminate the pool before joining the pool to check that the pool object is garbage collected/does not hang. But doing this will potentially leak threads and processes (see https://bugs.python.org/issue35413).
1 parent 4b250fc commit 3766f18

File tree

3 files changed

+80
-39
lines changed

3 files changed

+80
-39
lines changed

‎Lib/multiprocessing/pool.py‎

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,9 @@ class Pool(object):
151151
'''
152152
_wrap_exception = True
153153

154-
def Process(self, *args, **kwds):
155-
return self._ctx.Process(*args, **kwds)
154+
@staticmethod
155+
def Process(ctx, *args, **kwds):
156+
return ctx.Process(*args, **kwds)
156157

157158
def __init__(self, processes=None, initializer=None, initargs=(),
158159
maxtasksperchild=None, context=None):
@@ -190,7 +191,10 @@ def __init__(self, processes=None, initializer=None, initargs=(),
190191

191192
self._worker_handler = threading.Thread(
192193
target=Pool._handle_workers,
193-
args=(self, )
194+
args=(self._cache, self._taskqueue, self._ctx, self.Process,
195+
self._processes, self._pool, self._inqueue, self._outqueue,
196+
self._initializer, self._initargs, self._maxtasksperchild,
197+
self._wrap_exception)
194198
)
195199
self._worker_handler.daemon = True
196200
self._worker_handler._state = RUN
@@ -236,43 +240,61 @@ def __repr__(self):
236240
f'state={self._state} '
237241
f'pool_size={len(self._pool)}>')
238242

239-
def _join_exited_workers(self):
243+
@staticmethod
244+
def _join_exited_workers(pool):
240245
"""Cleanup after any worker processes which have exited due to reaching
241246
their specified lifetime. Returns True if any workers were cleaned up.
242247
"""
243248
cleaned = False
244-
for i in reversed(range(len(self._pool))):
245-
worker = self._pool[i]
249+
for i in reversed(range(len(pool))):
250+
worker = pool[i]
246251
if worker.exitcode is not None:
247252
# worker exited
248253
util.debug('cleaning up worker %d' % i)
249254
worker.join()
250255
cleaned = True
251-
del self._pool[i]
256+
del pool[i]
252257
return cleaned
253258

254259
def _repopulate_pool(self):
260+
return self._repopulate_pool_static(self._ctx, self.Process,
261+
self._processes,
262+
self._pool, self._inqueue,
263+
self._outqueue, self._initializer,
264+
self._initargs,
265+
self._maxtasksperchild,
266+
self._wrap_exception)
267+
268+
@staticmethod
269+
def _repopulate_pool_static(ctx, Process, processes, pool, inqueue,
270+
outqueue, initializer, initargs,
271+
maxtasksperchild, wrap_exception):
255272
"""Bring the number of pool processes up to the specified number,
256273
for use after reaping workers which have exited.
257274
"""
258-
for i in range(self._processes - len(self._pool)):
259-
w = self.Process(target=worker,
260-
args=(self._inqueue, self._outqueue,
261-
self._initializer,
262-
self._initargs, self._maxtasksperchild,
263-
self._wrap_exception)
264-
)
275+
for i in range(processes - len(pool)):
276+
w = Process(ctx, target=worker,
277+
args=(inqueue, outqueue,
278+
initializer,
279+
initargs, maxtasksperchild,
280+
wrap_exception))
265281
w.name = w.name.replace('Process', 'PoolWorker')
266282
w.daemon = True
267283
w.start()
268-
self._pool.append(w)
284+
pool.append(w)
269285
util.debug('added worker')
270286

271-
def _maintain_pool(self):
287+
@staticmethod
288+
def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue,
289+
initializer, initargs, maxtasksperchild,
290+
wrap_exception):
272291
"""Clean up any exited workers and start replacements for them.
273292
"""
274-
if self._join_exited_workers():
275-
self._repopulate_pool()
293+
if Pool._join_exited_workers(pool):
294+
Pool._repopulate_pool_static(ctx, Process, processes, pool,
295+
inqueue, outqueue, initializer,
296+
initargs, maxtasksperchild,
297+
wrap_exception)
276298

277299
def _setup_queues(self):
278300
self._inqueue = self._ctx.SimpleQueue()
@@ -331,7 +353,7 @@ def imap(self, func, iterable, chunksize=1):
331353
'''
332354
self._check_running()
333355
if chunksize == 1:
334-
result = IMapIterator(self._cache)
356+
result = IMapIterator(self)
335357
self._taskqueue.put(
336358
(
337359
self._guarded_task_generation(result._job, func, iterable),
@@ -344,7 +366,7 @@ def imap(self, func, iterable, chunksize=1):
344366
"Chunksize must be 1+, not {0:n}".format(
345367
chunksize))
346368
task_batches = Pool._get_tasks(func, iterable, chunksize)
347-
result = IMapIterator(self._cache)
369+
result = IMapIterator(self)
348370
self._taskqueue.put(
349371
(
350372
self._guarded_task_generation(result._job,
@@ -360,7 +382,7 @@ def imap_unordered(self, func, iterable, chunksize=1):
360382
'''
361383
self._check_running()
362384
if chunksize == 1:
363-
result = IMapUnorderedIterator(self._cache)
385+
result = IMapUnorderedIterator(self)
364386
self._taskqueue.put(
365387
(
366388
self._guarded_task_generation(result._job, func, iterable),
@@ -372,7 +394,7 @@ def imap_unordered(self, func, iterable, chunksize=1):
372394
raise ValueError(
373395
"Chunksize must be 1+, not {0!r}".format(chunksize))
374396
task_batches = Pool._get_tasks(func, iterable, chunksize)
375-
result = IMapUnorderedIterator(self._cache)
397+
result = IMapUnorderedIterator(self)
376398
self._taskqueue.put(
377399
(
378400
self._guarded_task_generation(result._job,
@@ -388,7 +410,7 @@ def apply_async(self, func, args=(), kwds={}, callback=None,
388410
Asynchronous version of `apply()` method.
389411
'''
390412
self._check_running()
391-
result = ApplyResult(self._cache, callback, error_callback)
413+
result = ApplyResult(self, callback, error_callback)
392414
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
393415
return result
394416

@@ -417,7 +439,7 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
417439
chunksize = 0
418440

419441
task_batches = Pool._get_tasks(func, iterable, chunksize)
420-
result = MapResult(self._cache, chunksize, len(iterable), callback,
442+
result = MapResult(self, chunksize, len(iterable), callback,
421443
error_callback=error_callback)
422444
self._taskqueue.put(
423445
(
@@ -430,16 +452,20 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
430452
return result
431453

432454
@staticmethod
433-
def _handle_workers(pool):
455+
def _handle_workers(cache, taskqueue, ctx, Process, processes, pool,
456+
inqueue, outqueue, initializer, initargs,
457+
maxtasksperchild, wrap_exception):
434458
thread = threading.current_thread()
435459

436460
# Keep maintaining workers until the cache gets drained, unless the pool
437461
# is terminated.
438-
while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
439-
pool._maintain_pool()
462+
while thread._state == RUN or (cache and thread._state != TERMINATE):
463+
Pool._maintain_pool(ctx, Process, processes, pool, inqueue,
464+
outqueue, initializer, initargs,
465+
maxtasksperchild, wrap_exception)
440466
time.sleep(0.1)
441467
# send sentinel to stop workers
442-
pool._taskqueue.put(None)
468+
taskqueue.put(None)
443469
util.debug('worker handler exiting')
444470

445471
@staticmethod
@@ -656,13 +682,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
656682

657683
class ApplyResult(object):
658684

659-
def __init__(self, cache, callback, error_callback):
685+
def __init__(self, pool, callback, error_callback):
686+
self._pool = pool
660687
self._event = threading.Event()
661688
self._job = next(job_counter)
662-
self._cache = cache
689+
self._cache = pool._cache
663690
self._callback = callback
664691
self._error_callback = error_callback
665-
cache[self._job] = self
692+
self._cache[self._job] = self
666693

667694
def ready(self):
668695
return self._event.is_set()
@@ -692,6 +719,7 @@ def _set(self, i, obj):
692719
self._error_callback(self._value)
693720
self._event.set()
694721
del self._cache[self._job]
722+
self._pool = None
695723

696724
AsyncResult = ApplyResult # create alias -- see #17805
697725

@@ -701,16 +729,16 @@ def _set(self, i, obj):
701729

702730
class MapResult(ApplyResult):
703731

704-
def __init__(self, cache, chunksize, length, callback, error_callback):
705-
ApplyResult.__init__(self, cache, callback,
732+
def __init__(self, pool, chunksize, length, callback, error_callback):
733+
ApplyResult.__init__(self, pool, callback,
706734
error_callback=error_callback)
707735
self._success = True
708736
self._value = [None] * length
709737
self._chunksize = chunksize
710738
if chunksize <= 0:
711739
self._number_left = 0
712740
self._event.set()
713-
del cache[self._job]
741+
del self._cache[self._job]
714742
else:
715743
self._number_left = length//chunksize + bool(length % chunksize)
716744

@@ -724,6 +752,7 @@ def _set(self, i, success_result):
724752
self._callback(self._value)
725753
del self._cache[self._job]
726754
self._event.set()
755+
self._pool = None
727756
else:
728757
if not success and self._success:
729758
# only store first exception
@@ -735,22 +764,24 @@ def _set(self, i, success_result):
735764
self._error_callback(self._value)
736765
del self._cache[self._job]
737766
self._event.set()
767+
self._pool = None
738768

739769
#
740770
# Class whose instances are returned by `Pool.imap()`
741771
#
742772

743773
class IMapIterator(object):
744774

745-
def __init__(self, cache):
775+
def __init__(self, pool):
776+
self._pool = pool
746777
self._cond = threading.Condition(threading.Lock())
747778
self._job = next(job_counter)
748-
self._cache = cache
779+
self._cache = pool._cache
749780
self._items = collections.deque()
750781
self._index = 0
751782
self._length = None
752783
self._unsorted = {}
753-
cache[self._job] = self
784+
self._cache[self._job] = self
754785

755786
def __iter__(self):
756787
return self
@@ -761,12 +792,14 @@ def next(self, timeout=None):
761792
item = self._items.popleft()
762793
except IndexError:
763794
if self._index == self._length:
795+
self._pool = None
764796
raise StopIteration from None
765797
self._cond.wait(timeout)
766798
try:
767799
item = self._items.popleft()
768800
except IndexError:
769801
if self._index == self._length:
802+
self._pool = None
770803
raise StopIteration from None
771804
raise TimeoutError from None
772805

@@ -792,13 +825,15 @@ def _set(self, i, obj):
792825

793826
if self._index == self._length:
794827
del self._cache[self._job]
828+
self._pool = None
795829

796830
def _set_length(self, length):
797831
with self._cond:
798832
self._length = length
799833
if self._index == self._length:
800834
self._cond.notify()
801835
del self._cache[self._job]
836+
self._pool = None
802837

803838
#
804839
# Class whose instances are returned by `Pool.imap_unordered()`
@@ -813,6 +848,7 @@ def _set(self, i, obj):
813848
self._cond.notify()
814849
if self._index == self._length:
815850
del self._cache[self._job]
851+
self._pool = None
816852

817853
#
818854
#
@@ -822,7 +858,7 @@ class ThreadPool(Pool):
822858
_wrap_exception = False
823859

824860
@staticmethod
825-
def Process(*args, **kwds):
861+
def Process(ctx, *args, **kwds):
826862
from .dummy import Process
827863
return Process(*args, **kwds)
828864

‎Lib/test/_test_multiprocessing.py‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2593,7 +2593,6 @@ def test_resource_warning(self):
25932593
pool = None
25942594
support.gc_collect()
25952595

2596-
25972596
def raising():
25982597
raise KeyError("key")
25992598

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Fix a reference issue inside :class:`multiprocessing.Pool` that caused
2+
the pool to remain alive if it was deleted without being closed or
3+
terminated explicitly. A new strong reference is added to the pool
4+
iterators to link the lifetime of the pool to the lifetime of its
5+
iterators so the pool does not get destroyed if a pool iterator is
6+
still alive.

0 commit comments

Comments
 (0)