Make ThreadPool a context manager to prevent memory leaks
This commit is contained in:
parent
c4f65a5d7b
commit
09e65e1d95
3 changed files with 99 additions and 99 deletions
|
@ -149,21 +149,19 @@ class TestNoparallel:
|
|||
|
||||
def testMultithreadMix(self, queue_spawn):
|
||||
obj1 = ExampleClass()
|
||||
thread_pool = ThreadPool.ThreadPool(10)
|
||||
with ThreadPool.ThreadPool(10) as thread_pool:
|
||||
s = time.time()
|
||||
t1 = queue_spawn(obj1.countBlocking, 5)
|
||||
time.sleep(0.01)
|
||||
t2 = thread_pool.spawn(obj1.countBlocking, 5)
|
||||
time.sleep(0.01)
|
||||
t3 = thread_pool.spawn(obj1.countBlocking, 5)
|
||||
time.sleep(0.3)
|
||||
t4 = gevent.spawn(obj1.countBlocking, 5)
|
||||
threads = [t1, t2, t3, t4]
|
||||
for thread in threads:
|
||||
assert thread.get() == "counted:5"
|
||||
|
||||
s = time.time()
|
||||
t1 = queue_spawn(obj1.countBlocking, 5)
|
||||
time.sleep(0.01)
|
||||
t2 = thread_pool.spawn(obj1.countBlocking, 5)
|
||||
time.sleep(0.01)
|
||||
t3 = thread_pool.spawn(obj1.countBlocking, 5)
|
||||
time.sleep(0.3)
|
||||
t4 = gevent.spawn(obj1.countBlocking, 5)
|
||||
threads = [t1, t2, t3, t4]
|
||||
for thread in threads:
|
||||
assert thread.get() == "counted:5"
|
||||
|
||||
time_taken = time.time() - s
|
||||
assert obj1.counted == 5
|
||||
assert 0.5 < time_taken < 0.7
|
||||
thread_pool.kill()
|
||||
time_taken = time.time() - s
|
||||
assert obj1.counted == 5
|
||||
assert 0.5 < time_taken < 0.7
|
||||
|
|
|
@ -9,31 +9,29 @@ from util import ThreadPool
|
|||
|
||||
class TestThreadPool:
|
||||
def testExecutionOrder(self):
|
||||
pool = ThreadPool.ThreadPool(4)
|
||||
with ThreadPool.ThreadPool(4) as pool:
|
||||
events = []
|
||||
|
||||
events = []
|
||||
@pool.wrap
|
||||
def blocker():
|
||||
events.append("S")
|
||||
out = 0
|
||||
for i in range(10000000):
|
||||
if i == 3000000:
|
||||
events.append("M")
|
||||
out += 1
|
||||
events.append("D")
|
||||
return out
|
||||
|
||||
@pool.wrap
|
||||
def blocker():
|
||||
events.append("S")
|
||||
out = 0
|
||||
for i in range(10000000):
|
||||
if i == 3000000:
|
||||
events.append("M")
|
||||
out += 1
|
||||
events.append("D")
|
||||
return out
|
||||
threads = []
|
||||
for i in range(3):
|
||||
threads.append(gevent.spawn(blocker))
|
||||
gevent.joinall(threads)
|
||||
|
||||
threads = []
|
||||
for i in range(3):
|
||||
threads.append(gevent.spawn(blocker))
|
||||
gevent.joinall(threads)
|
||||
assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3
|
||||
|
||||
assert events == ["S"] * 3 + ["M"] * 3 + ["D"] * 3
|
||||
|
||||
res = blocker()
|
||||
assert res == 10000000
|
||||
pool.kill()
|
||||
res = blocker()
|
||||
assert res == 10000000
|
||||
|
||||
def testLockBlockingSameThread(self):
|
||||
lock = ThreadPool.Lock()
|
||||
|
@ -60,89 +58,88 @@ class TestThreadPool:
|
|||
time.sleep(0.5)
|
||||
lock.release()
|
||||
|
||||
pool = ThreadPool.ThreadPool(10)
|
||||
threads = [
|
||||
pool.spawn(locker),
|
||||
pool.spawn(locker),
|
||||
gevent.spawn(locker),
|
||||
pool.spawn(locker)
|
||||
]
|
||||
time.sleep(0.1)
|
||||
with ThreadPool.ThreadPool(10) as pool:
|
||||
threads = [
|
||||
pool.spawn(locker),
|
||||
pool.spawn(locker),
|
||||
gevent.spawn(locker),
|
||||
pool.spawn(locker)
|
||||
]
|
||||
time.sleep(0.1)
|
||||
|
||||
s = time.time()
|
||||
s = time.time()
|
||||
|
||||
lock.acquire(True, 5.0)
|
||||
lock.acquire(True, 5.0)
|
||||
|
||||
unlock_taken = time.time() - s
|
||||
unlock_taken = time.time() - s
|
||||
|
||||
assert 1.8 < unlock_taken < 2.2
|
||||
assert 1.8 < unlock_taken < 2.2
|
||||
|
||||
gevent.joinall(threads)
|
||||
gevent.joinall(threads)
|
||||
|
||||
def testMainLoopCallerThreadId(self):
|
||||
main_thread_id = threading.current_thread().ident
|
||||
pool = ThreadPool.ThreadPool(5)
|
||||
with ThreadPool.ThreadPool(5) as pool:
|
||||
def getThreadId(*args, **kwargs):
|
||||
return threading.current_thread().ident
|
||||
|
||||
def getThreadId(*args, **kwargs):
|
||||
return threading.current_thread().ident
|
||||
t = pool.spawn(getThreadId)
|
||||
assert t.get() != main_thread_id
|
||||
|
||||
t = pool.spawn(getThreadId)
|
||||
assert t.get() != main_thread_id
|
||||
|
||||
t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId))
|
||||
assert t.get() == main_thread_id
|
||||
t = pool.spawn(lambda: ThreadPool.main_loop.call(getThreadId))
|
||||
assert t.get() == main_thread_id
|
||||
|
||||
def testMainLoopCallerGeventSpawn(self):
|
||||
main_thread_id = threading.current_thread().ident
|
||||
pool = ThreadPool.ThreadPool(5)
|
||||
def waiter():
|
||||
time.sleep(1)
|
||||
return threading.current_thread().ident
|
||||
with ThreadPool.ThreadPool(5) as pool:
|
||||
def waiter():
|
||||
time.sleep(1)
|
||||
return threading.current_thread().ident
|
||||
|
||||
def geventSpawner():
|
||||
event = ThreadPool.main_loop.call(gevent.spawn, waiter)
|
||||
def geventSpawner():
|
||||
event = ThreadPool.main_loop.call(gevent.spawn, waiter)
|
||||
|
||||
with pytest.raises(Exception) as greenlet_err:
|
||||
event.get()
|
||||
assert str(greenlet_err.value) == "cannot switch to a different thread"
|
||||
with pytest.raises(Exception) as greenlet_err:
|
||||
event.get()
|
||||
assert str(greenlet_err.value) == "cannot switch to a different thread"
|
||||
|
||||
waiter_thread_id = ThreadPool.main_loop.call(event.get)
|
||||
return waiter_thread_id
|
||||
waiter_thread_id = ThreadPool.main_loop.call(event.get)
|
||||
return waiter_thread_id
|
||||
|
||||
s = time.time()
|
||||
waiter_thread_id = pool.apply(geventSpawner)
|
||||
assert main_thread_id == waiter_thread_id
|
||||
time_taken = time.time() - s
|
||||
assert 0.9 < time_taken < 1.2
|
||||
s = time.time()
|
||||
waiter_thread_id = pool.apply(geventSpawner)
|
||||
assert main_thread_id == waiter_thread_id
|
||||
time_taken = time.time() - s
|
||||
assert 0.9 < time_taken < 1.2
|
||||
|
||||
def testEvent(self):
|
||||
pool = ThreadPool.ThreadPool(5)
|
||||
event = ThreadPool.Event()
|
||||
with ThreadPool.ThreadPool(5) as pool:
|
||||
event = ThreadPool.Event()
|
||||
|
||||
def setter():
|
||||
time.sleep(1)
|
||||
event.set("done!")
|
||||
def setter():
|
||||
time.sleep(1)
|
||||
event.set("done!")
|
||||
|
||||
def getter():
|
||||
return event.get()
|
||||
def getter():
|
||||
return event.get()
|
||||
|
||||
pool.spawn(setter)
|
||||
t_gevent = gevent.spawn(getter)
|
||||
t_pool = pool.spawn(getter)
|
||||
s = time.time()
|
||||
assert event.get() == "done!"
|
||||
time_taken = time.time() - s
|
||||
gevent.joinall([t_gevent, t_pool])
|
||||
pool.spawn(setter)
|
||||
t_gevent = gevent.spawn(getter)
|
||||
t_pool = pool.spawn(getter)
|
||||
s = time.time()
|
||||
assert event.get() == "done!"
|
||||
time_taken = time.time() - s
|
||||
gevent.joinall([t_gevent, t_pool])
|
||||
|
||||
assert t_gevent.get() == "done!"
|
||||
assert t_pool.get() == "done!"
|
||||
assert t_gevent.get() == "done!"
|
||||
assert t_pool.get() == "done!"
|
||||
|
||||
assert 0.9 < time_taken < 1.2
|
||||
assert 0.9 < time_taken < 1.2
|
||||
|
||||
with pytest.raises(Exception) as err:
|
||||
event.set("another result")
|
||||
with pytest.raises(Exception) as err:
|
||||
event.set("another result")
|
||||
|
||||
assert "Event already has value" in str(err.value)
|
||||
assert "Event already has value" in str(err.value)
|
||||
|
||||
def testMemoryLeak(self):
|
||||
import gc
|
||||
|
@ -153,10 +150,9 @@ class TestThreadPool:
|
|||
return "ok"
|
||||
|
||||
def poolTest():
|
||||
pool = ThreadPool.ThreadPool(5)
|
||||
for i in range(20):
|
||||
pool.spawn(worker)
|
||||
pool.kill()
|
||||
with ThreadPool.ThreadPool(5) as pool:
|
||||
for i in range(20):
|
||||
pool.spawn(worker)
|
||||
|
||||
for i in range(5):
|
||||
poolTest()
|
||||
|
|
|
@ -55,6 +55,12 @@ class ThreadPool:
|
|||
del self.pool
|
||||
self.pool = None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.kill()
|
||||
|
||||
|
||||
lock_pool = gevent.threadpool.ThreadPool(50)
|
||||
main_thread_id = threading.current_thread().ident
|
||||
|
|
Loading…
Reference in a new issue