Skip to content

Commit 1463518

Browse files
authored
[2.7] bpo-31234: Join threads explicitly in tests (#7406)
* Add support.wait_threads_exit(): context manager looping at exit until the number of threads decreases to its original number. * Add some missing thread.join() * test_asyncore.test_send(): call explicitly t.join() because the cleanup function is only called outside the test method, whereas the method has a @test_support.reap_threads decorator * test_hashlib: replace threading.Event with thread.join() * test_thread: * Use wait_threads_exit() context manager * Replace test_support with support * test_forkinthread(): check child process exit status in the main thread to better handle error.
1 parent fadcd44 commit 1463518

File tree

6 files changed

+122
-73
lines changed

6 files changed

+122
-73
lines changed

‎Lib/test/support/__init__.py‎

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,6 +1722,43 @@ def decorator(*args):
17221722
threading_cleanup(*key)
17231723
return decorator
17241724

1725+
1726+
@contextlib.contextmanager
1727+
def wait_threads_exit(timeout=60.0):
1728+
"""
1729+
bpo-31234: Context manager to wait until all threads created in the with
1730+
statement exit.
1731+
1732+
Use thread.count() to check if threads exited. Indirectly, wait until
1733+
threads exit the internal t_bootstrap() C function of the thread module.
1734+
1735+
threading_setup() and threading_cleanup() are designed to emit a warning
1736+
if a test leaves running threads in the background. This context manager
1737+
is designed to cleanup threads started by the thread.start_new_thread()
1738+
which doesn't allow to wait for thread exit, whereas thread.Thread has a
1739+
join() method.
1740+
"""
1741+
old_count = thread._count()
1742+
try:
1743+
yield
1744+
finally:
1745+
start_time = time.time()
1746+
deadline = start_time + timeout
1747+
while True:
1748+
count = thread._count()
1749+
if count <= old_count:
1750+
break
1751+
if time.time() > deadline:
1752+
dt = time.time() - start_time
1753+
msg = ("wait_threads() failed to cleanup %s "
1754+
"threads after %.1f seconds "
1755+
"(count: %s, old count: %s)"
1756+
% (count - old_count, dt, count, old_count))
1757+
raise AssertionError(msg)
1758+
time.sleep(0.010)
1759+
gc_collect()
1760+
1761+
17251762
def reap_children():
17261763
"""Use this function at the end of test_main() whenever sub-processes
17271764
are started. This will help ensure that no extra children (zombies)

‎Lib/test/test_asyncore.py‎

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -727,19 +727,20 @@ def test_quick_connect(self):
727727
server = TCPServer()
728728
t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
729729
t.start()
730-
self.addCleanup(t.join)
731-
732-
for x in xrange(20):
733-
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
734-
s.settimeout(.2)
735-
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
736-
struct.pack('ii', 1, 0))
737-
try:
738-
s.connect(server.address)
739-
except socket.error:
740-
pass
741-
finally:
742-
s.close()
730+
try:
731+
for x in xrange(20):
732+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
733+
s.settimeout(.2)
734+
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
735+
struct.pack('ii', 1, 0))
736+
try:
737+
s.connect(server.address)
738+
except socket.error:
739+
pass
740+
finally:
741+
s.close()
742+
finally:
743+
t.join()
743744

744745

745746
class TestAPI_UseSelect(BaseTestAPI):

‎Lib/test/test_hashlib.py‎

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -371,25 +371,25 @@ def test_threaded_hashing(self):
371371
data = smallest_data*200000
372372
expected_hash = hashlib.sha1(data*num_threads).hexdigest()
373373

374-
def hash_in_chunks(chunk_size, event):
374+
def hash_in_chunks(chunk_size):
375375
index = 0
376376
while index < len(data):
377377
hasher.update(data[index:index+chunk_size])
378378
index += chunk_size
379-
event.set()
380379

381-
events = []
380+
threads = []
382381
for threadnum in xrange(num_threads):
383382
chunk_size = len(data) // (10**threadnum)
384383
assert chunk_size > 0
385384
assert chunk_size % len(smallest_data) == 0
386-
event = threading.Event()
387-
events.append(event)
388-
threading.Thread(target=hash_in_chunks,
389-
args=(chunk_size, event)).start()
390-
391-
for event in events:
392-
event.wait()
385+
thread = threading.Thread(target=hash_in_chunks,
386+
args=(chunk_size,))
387+
threads.append(thread)
388+
389+
for thread in threads:
390+
thread.start()
391+
for thread in threads:
392+
thread.join()
393393

394394
self.assertEqual(expected_hash, hasher.hexdigest())
395395

‎Lib/test/test_httpservers.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def run(self):
6666

6767
def stop(self):
6868
self.server.shutdown()
69+
self.join()
6970

7071

7172
class BaseTestCase(unittest.TestCase):

‎Lib/test/test_smtplib.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,14 @@ def setUp(self):
306306
self.sock.settimeout(15)
307307
self.port = test_support.bind_port(self.sock)
308308
servargs = (self.evt, self.respdata, self.sock)
309-
threading.Thread(target=server, args=servargs).start()
309+
self.thread = threading.Thread(target=server, args=servargs)
310+
self.thread.start()
310311
self.evt.wait()
311312
self.evt.clear()
312313

313314
def tearDown(self):
314315
self.evt.wait()
316+
self.thread.join()
315317
sys.stdout = self.old_stdout
316318

317319
def testLineTooLong(self):

‎Lib/test/test_thread.py‎

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22
import unittest
33
import random
4-
from test import test_support
5-
thread = test_support.import_module('thread')
4+
from test import support
5+
thread = support.import_module('thread')
66
import time
77
import sys
88
import weakref
@@ -17,7 +17,7 @@
1717

1818
def verbose_print(arg):
1919
"""Helper function for printing out debugging output."""
20-
if test_support.verbose:
20+
if support.verbose:
2121
with _print_mutex:
2222
print arg
2323

@@ -34,8 +34,8 @@ def setUp(self):
3434
self.running = 0
3535
self.next_ident = 0
3636

37-
key = test_support.threading_setup()
38-
self.addCleanup(test_support.threading_cleanup, *key)
37+
key = support.threading_setup()
38+
self.addCleanup(support.threading_cleanup, *key)
3939

4040

4141
class ThreadRunningTests(BasicThreadTest):
@@ -60,12 +60,13 @@ def task(self, ident):
6060
self.done_mutex.release()
6161

6262
def test_starting_threads(self):
63-
# Basic test for thread creation.
64-
for i in range(NUMTASKS):
65-
self.newtask()
66-
verbose_print("waiting for tasks to complete...")
67-
self.done_mutex.acquire()
68-
verbose_print("all tasks done")
63+
with support.wait_threads_exit():
64+
# Basic test for thread creation.
65+
for i in range(NUMTASKS):
66+
self.newtask()
67+
verbose_print("waiting for tasks to complete...")
68+
self.done_mutex.acquire()
69+
verbose_print("all tasks done")
6970

7071
def test_stack_size(self):
7172
# Various stack size tests.
@@ -95,12 +96,13 @@ def test_nt_and_posix_stack_size(self):
9596
verbose_print("trying stack_size = (%d)" % tss)
9697
self.next_ident = 0
9798
self.created = 0
98-
for i in range(NUMTASKS):
99-
self.newtask()
99+
with support.wait_threads_exit():
100+
for i in range(NUMTASKS):
101+
self.newtask()
100102

101-
verbose_print("waiting for all tasks to complete")
102-
self.done_mutex.acquire()
103-
verbose_print("all tasks done")
103+
verbose_print("waiting for all tasks to complete")
104+
self.done_mutex.acquire()
105+
verbose_print("all tasks done")
104106

105107
thread.stack_size(0)
106108

@@ -110,25 +112,28 @@ def test__count(self):
110112
mut = thread.allocate_lock()
111113
mut.acquire()
112114
started = []
115+
113116
def task():
114117
started.append(None)
115118
mut.acquire()
116119
mut.release()
117-
thread.start_new_thread(task, ())
118-
while not started:
119-
time.sleep(0.01)
120-
self.assertEqual(thread._count(), orig + 1)
121-
# Allow the task to finish.
122-
mut.release()
123-
# The only reliable way to be sure that the thread ended from the
124-
# interpreter's point of view is to wait for the function object to be
125-
# destroyed.
126-
done = []
127-
wr = weakref.ref(task, lambda _: done.append(None))
128-
del task
129-
while not done:
130-
time.sleep(0.01)
131-
self.assertEqual(thread._count(), orig)
120+
121+
with support.wait_threads_exit():
122+
thread.start_new_thread(task, ())
123+
while not started:
124+
time.sleep(0.01)
125+
self.assertEqual(thread._count(), orig + 1)
126+
# Allow the task to finish.
127+
mut.release()
128+
# The only reliable way to be sure that the thread ended from the
129+
# interpreter's point of view is to wait for the function object to be
130+
# destroyed.
131+
done = []
132+
wr = weakref.ref(task, lambda _: done.append(None))
133+
del task
134+
while not done:
135+
time.sleep(0.01)
136+
self.assertEqual(thread._count(), orig)
132137

133138
def test_save_exception_state_on_error(self):
134139
# See issue #14474
@@ -143,14 +148,13 @@ def mywrite(self, *args):
143148
real_write(self, *args)
144149
c = thread._count()
145150
started = thread.allocate_lock()
146-
with test_support.captured_output("stderr") as stderr:
151+
with support.captured_output("stderr") as stderr:
147152
real_write = stderr.write
148153
stderr.write = mywrite
149154
started.acquire()
150-
thread.start_new_thread(task, ())
151-
started.acquire()
152-
while thread._count() > c:
153-
time.sleep(0.01)
155+
with support.wait_threads_exit():
156+
thread.start_new_thread(task, ())
157+
started.acquire()
154158
self.assertIn("Traceback", stderr.getvalue())
155159

156160

@@ -182,13 +186,14 @@ def enter(self):
182186
class BarrierTest(BasicThreadTest):
183187

184188
def test_barrier(self):
185-
self.bar = Barrier(NUMTASKS)
186-
self.running = NUMTASKS
187-
for i in range(NUMTASKS):
188-
thread.start_new_thread(self.task2, (i,))
189-
verbose_print("waiting for tasks to end")
190-
self.done_mutex.acquire()
191-
verbose_print("tasks done")
189+
with support.wait_threads_exit():
190+
self.bar = Barrier(NUMTASKS)
191+
self.running = NUMTASKS
192+
for i in range(NUMTASKS):
193+
thread.start_new_thread(self.task2, (i,))
194+
verbose_print("waiting for tasks to end")
195+
self.done_mutex.acquire()
196+
verbose_print("tasks done")
192197

193198
def task2(self, ident):
194199
for i in range(NUMTRIPS):
@@ -226,8 +231,9 @@ def setUp(self):
226231

227232
@unittest.skipIf(sys.platform.startswith('win'),
228233
"This test is only appropriate for POSIX-like systems.")
229-
@test_support.reap_threads
234+
@support.reap_threads
230235
def test_forkinthread(self):
236+
non_local = {'status': None}
231237
def thread1():
232238
try:
233239
pid = os.fork() # fork in a thread
@@ -246,11 +252,13 @@ def thread1():
246252
else: # parent
247253
os.close(self.write_fd)
248254
pid, status = os.waitpid(pid, 0)
249-
self.assertEqual(status, 0)
255+
non_local['status'] = status
250256

251-
thread.start_new_thread(thread1, ())
252-
self.assertEqual(os.read(self.read_fd, 2), "OK",
253-
"Unable to fork() in thread")
257+
with support.wait_threads_exit():
258+
thread.start_new_thread(thread1, ())
259+
self.assertEqual(os.read(self.read_fd, 2), "OK",
260+
"Unable to fork() in thread")
261+
self.assertEqual(non_local['status'], 0)
254262

255263
def tearDown(self):
256264
try:
@@ -265,7 +273,7 @@ def tearDown(self):
265273

266274

267275
def test_main():
268-
test_support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
276+
support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
269277
TestForkInThread)
270278

271279
if __name__ == "__main__":

0 commit comments

Comments
 (0)