The following code is intended to implement a "shared lock" that protects a resource: multiple "reader" threads can hold concurrent access to the resource, while a single "writer" thread gets exclusive access. I have tried to make the implementation "fair", i.e. to ensure that both readers and writers eventually acquire the lock (assuming that every lock that is acquired is eventually released).
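To make the intended notion of fairness concrete, here is a minimal, thread-free model of the admission policy. It is an illustrative sketch only: `FairRWModel`, `request`, `release` and `snapshot` are hypothetical names, not part of the code below. Requests join a FIFO queue, and a request at the head is admitted only if it is compatible with the current holders. As a result, a reader that arrives after a waiting writer cannot overtake it. The scenario replayed here is the one described in the class docstring.

```python
from collections import deque


class FairRWModel:
    """Thread-free model of a FIFO-fair readers-writer admission policy.

    Hypothetical illustration: requests are identified by name, and the model
    only tracks who holds the lock and who is queued; no threads are involved.
    """

    def __init__(self):
        self.readers = set()    # names currently holding shared access
        self.writer = None      # name currently holding exclusive access
        self.waiting = deque()  # queued (name, exclusive) requests, FIFO

    def _admit(self):
        # Admit requests from the head of the queue for as long as they are
        # compatible with the current holders; stop at the first one that is
        # not, so nothing queued behind a blocked request can overtake it.
        while self.waiting:
            name, exclusive = self.waiting[0]
            if exclusive:
                if self.readers or self.writer is not None:
                    return
                self.writer = name
            else:
                if self.writer is not None:
                    return
                self.readers.add(name)
            self.waiting.popleft()

    def request(self, name, exclusive):
        self.waiting.append((name, exclusive))
        self._admit()

    def release(self, name):
        if self.writer == name:
            self.writer = None
        else:
            self.readers.remove(name)
        self._admit()

    def snapshot(self):
        """Return (sorted readers, writer, queued names)."""
        return sorted(self.readers), self.writer, [n for n, _ in self.waiting]


# Readers R1 and R2 hold the lock, then writer W3 arrives, then reader R4.
model = FairRWModel()
model.request("R1", exclusive=False)
model.request("R2", exclusive=False)
model.request("W3", exclusive=True)   # blocked: readers hold the lock
model.request("R4", exclusive=False)  # queued behind W3, not admitted
print(model.snapshot())  # (['R1', 'R2'], None, ['W3', 'R4'])
model.release("R1")
model.release("R2")                   # last reader leaves: W3 is admitted
print(model.snapshot())  # ([], 'W3', ['R4'])
model.release("W3")                   # writer leaves: R4 is admitted
print(model.snapshot())  # (['R4'], None, [])
```

Because a new request is appended to the queue before anything is admitted, a request that arrives when the queue is empty and the lock is compatible is granted immediately. This mirrors the fast path taken when `self._waiting` is empty.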
My main concern is its correctness.
Are there any scenarios where the code could indefinitely block a thread?
Could the code be implemented more simply or more efficiently?
The Code:
```python
"""
This module implements a "shared lock" built on a threading.Condition
instance and its underlying lock.

We allow multiple threads that are only "reading" a resource to
concurrently acquire a "non-exclusive lock" on that resource, while
preventing a "writer" thread from acquiring an "exclusive lock" as long
as any thread holds the lock. Likewise, while a thread holds exclusive
access, all other threads are blocked from acquiring the lock until the
exclusive access is released.
"""
from contextlib import contextmanager
from threading import Condition, Thread, current_thread
from collections import deque


class SharedLock:
    """
    Usage:

        shared_lock = SharedLock()
        ...
        with shared_lock(False):  # request shared access
            ...
        with shared_lock(True):  # request exclusive access
            ...

    The code tries to be "fair", ensuring that all threads can eventually
    acquire the lock, by maintaining a queue of waiting threads. Consider
    this scenario:

    1. Reader thread 1 requests and acquires shared access at time t=0.
    2. Reader thread 2 requests and acquires shared access at time t=1.
    3. Writer thread 3 requests exclusive access at time t=2 but is blocked
       because the lock is already held.
    4. Reader thread 4 requests shared access at time t=3. In principle, we
       could give it shared access, since only other "reader" threads hold
       the lock. But if we did so, there might always be at least one
       reader thread holding a shared lock, and the writer thread would
       never get access.

    Instead, writer thread 3 is placed in a queue of threads waiting to
    acquire access, and reader thread 4 is enqueued behind it. Thus, when
    reader threads 1 and 2 eventually release their shared locks, the
    thread at the head of the waiting queue, i.e. writer thread 3, is
    given exclusive access.
    """

    def __init__(self):
        self._cond = Condition()
        self._readers = 0              # number of threads holding shared access
        self._writer_running = False   # True while a thread holds exclusive access
        self._waiting = deque()        # FIFO queue of blocked threads

    @contextmanager
    def __call__(self, exclusive: bool = False):
        if not exclusive:  # Shared access
            with self._cond:
                if self._writer_running or self._waiting:
                    this_thread = current_thread()
                    # Add us to the end of the waiting queue:
                    self._waiting.append(this_thread)
                    # And wait until there is no writer running and
                    # we are next to run:
                    self._cond.wait_for(lambda:
                        not self._writer_running and
                        self._waiting[0] is this_thread
                    )
                    self._waiting.popleft()
                self._readers += 1
                self._cond.notify_all()  # Other readers may be waiting
            try:
                yield self
            finally:
                with self._cond:
                    self._readers -= 1
                    if not self._readers:
                        self._cond.notify_all()
        else:  # Exclusive access
            with self._cond:
                # Queue up unless the lock is free and nobody is waiting:
                if self._readers or self._writer_running or self._waiting:
                    this_thread = current_thread()
                    self._waiting.append(this_thread)
                    # Wait until nobody holds the lock and we are next to run:
                    self._cond.wait_for(lambda:
                        not self._readers and
                        not self._writer_running and
                        self._waiting[0] is this_thread
                    )
                    self._waiting.popleft()
                self._writer_running = True
            try:
                yield self
            finally:
                with self._cond:
                    self._writer_running = False
                    self._cond.notify_all()


if __name__ == '__main__':
    import time

    shared_lock = SharedLock()

    def reader():
        with shared_lock(False):
            print('start reading', time.time())
            time.sleep(1)
            print('end reading', time.time())

    def writer():
        with shared_lock(True):
            print('start writing', time.time())
            time.sleep(1)
            print('end writing', time.time())

    def test():
        reader_threads1 = [
            Thread(target=reader) for _ in range(4)
        ]
        writer_threads = [
            Thread(target=writer) for _ in range(2)
        ]
        reader_threads2 = [
            Thread(target=reader) for _ in range(2)
        ]
        for t in reader_threads1:
            t.start()
        time.sleep(.01)  # Give readers a chance to start
        for t in writer_threads:
            t.start()
        time.sleep(.01)  # Give writers a chance to start
        for t in reader_threads2:
            t.start()
        for t in reader_threads1:
            t.join()
        for t in writer_threads:
            t.join()
        for t in reader_threads2:
            t.join()

    test()
```
Prints:

```
start reading 1767303995.195186
start reading 1767303995.195186
start reading 1767303995.195186
start reading 1767303995.195186
end reading 1767303996.1962683
end reading 1767303996.1963234
end reading 1767303996.1963234
end reading 1767303996.1963234
start writing 1767303996.197262
end writing 1767303997.1977108
start writing 1767303997.1977108
end writing 1767303998.1983483
start reading 1767303998.1983912
start reading 1767303998.1983912
end reading 1767303999.199174
end reading 1767303999.19917
```
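Each thread prints its "end" line while it still holds the lock, so the order of the printed lines reflects the order in which access was granted and released. Rather than comparing timestamps by eye, a small checker can replay that order and assert the exclusion invariants. This is a sketch; `check_trace` is a hypothetical helper, not part of the module above, and it reads only the first two words of each line.

```python
def check_trace(events):
    """Replay (action, role) events and check readers-writer exclusion.

    `action` is "start" or "end"; `role` is "reading" or "writing".
    Returns the largest number of readers seen holding the lock at once,
    and raises AssertionError if a writer ever overlaps another holder.
    """
    readers = writers = max_readers = 0
    for action, role in events:
        if role == "reading":
            if action == "start":
                assert writers == 0, "reader started while a writer held the lock"
                readers += 1
                max_readers = max(max_readers, readers)
            else:
                readers -= 1
        else:  # "writing"
            if action == "start":
                assert readers == 0 and writers == 0, "writer overlapped another holder"
                writers += 1
            else:
                writers -= 1
    assert readers == 0 and writers == 0, "trace ended with the lock still held"
    return max_readers


# The printed output, in the order it appeared (timestamps are ignored).
TRACE = """\
start reading 1767303995.195186
start reading 1767303995.195186
start reading 1767303995.195186
start reading 1767303995.195186
end reading 1767303996.1962683
end reading 1767303996.1963234
end reading 1767303996.1963234
end reading 1767303996.1963234
start writing 1767303996.197262
end writing 1767303997.1977108
start writing 1767303997.1977108
end writing 1767303998.1983483
start reading 1767303998.1983912
start reading 1767303998.1983912
end reading 1767303999.199174
end reading 1767303999.19917
"""

events = [tuple(line.split()[:2]) for line in TRACE.splitlines()]
print(check_trace(events))  # 4
```

A passing check confirms only that this particular run respected the invariants. It does not prove that the lock cannot deadlock or starve a thread under other interleavings.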