mirror of
https://github.com/evilhero/mylar
synced 2024-12-27 10:06:52 +00:00
134 lines
3.9 KiB
Python
134 lines
3.9 KiB
Python
|
"""
Generic thread pool class. Modeled after Java's ThreadPoolExecutor.
Please note that this ThreadPool does *not* fully implement the PEP 3148
ThreadPool!
"""
|
||
|
|
||
|
from threading import Thread, Lock, currentThread
|
||
|
from weakref import ref
|
||
|
import logging
|
||
|
import atexit
|
||
|
|
||
|
try:
|
||
|
from queue import Queue, Empty
|
||
|
except ImportError:
|
||
|
from Queue import Queue, Empty
|
||
|
|
||
|
logger = logging.getLogger(__name__)

# Weak references to every pool that has been created and not yet shut down.
# Weak references are used so that this registry does not keep pools alive.
_threadpools = set()


# Worker threads are daemonic in order to let the interpreter exit without
# an explicit shutdown of the thread pool. The following trick is necessary
# to allow worker threads to finish cleanly.
def _shutdown_all():
    """Shut down every registered pool that is still alive.

    Registered with :mod:`atexit` below so it runs at interpreter exit.
    """
    # Iterate over a snapshot: pool.shutdown() removes the pool's entry from
    # _threadpools, which would otherwise mutate the set during iteration.
    for pool_ref in tuple(_threadpools):
        pool = pool_ref()
        # A dead weak reference returns None; there is nothing to shut down.
        if pool is not None:
            pool.shutdown()


atexit.register(_shutdown_all)
|
||
|
|
||
|
|
||
|
class ThreadPool(object):
    """Pool of daemon worker threads fed from a single job queue.

    Threads are created lazily by :meth:`submit`, up to ``max_threads``.
    Threads started while the pool holds fewer than ``core_threads`` workers
    block on the queue indefinitely; the others exit after waiting
    ``keepalive`` seconds without receiving a job.
    """

    def __init__(self, core_threads=0, max_threads=20, keepalive=1):
        """
        :param core_threads: maximum number of persistent threads in the pool
        :param max_threads: maximum number of total threads in the pool
            (raised to at least ``core_threads`` and never less than 1)
        :param keepalive: seconds to keep non-core worker threads waiting
            for new tasks
        """
        self.core_threads = core_threads
        self.max_threads = max(max_threads, core_threads, 1)
        self.keepalive = keepalive
        self._queue = Queue()
        self._threads_lock = Lock()
        self._threads = set()
        self._shutdown = False

        _threadpools.add(ref(self))
        # Log the effective (clamped) limit rather than the raw argument.
        logger.info('Started thread pool with %d core threads and %d maximum '
                    'threads', self.core_threads, self.max_threads)

    def _adjust_threadcount(self):
        """Start one more worker thread if the pool is below ``max_threads``."""
        with self._threads_lock:
            if self.num_threads < self.max_threads:
                # The new thread is a core thread while the pool is still
                # below its core size.
                self._add_thread(self.num_threads < self.core_threads)

    def _add_thread(self, core):
        """Create, start and register a daemon worker thread.

        Called with ``_threads_lock`` held, so the new thread cannot
        deregister itself before it has been added to ``_threads``.

        :param core: True if the worker should wait for jobs indefinitely
        """
        t = Thread(target=self._run_jobs, args=(core,))
        t.daemon = True
        t.start()
        self._threads.add(t)

    def _run_jobs(self, core):
        """Worker loop: run queued jobs until idle timeout or shutdown.

        :param core: True to block on the queue without a timeout; False to
            give up after ``keepalive`` seconds without a job
        """
        logger.debug('Started worker thread')
        block = True
        timeout = None
        if not core:
            block = self.keepalive > 0
            timeout = self.keepalive

        while True:
            try:
                func, args, kwargs = self._queue.get(block, timeout)
            except Empty:
                # Non-core thread sat idle for the whole keepalive period.
                break

            # Checked after every get(): shutdown() enqueues dummy items
            # purely to wake up blocked workers, and anything still queued
            # at that point is not executed.
            if self._shutdown:
                break

            try:
                func(*args, **kwargs)
            except:  # noqa: E722
                # Deliberately catches everything: a failing job must never
                # kill the worker before it has deregistered itself.
                logger.exception('Error in worker thread')

        with self._threads_lock:
            self._threads.remove(currentThread())

        logger.debug('Exiting worker thread')

    @property
    def num_threads(self):
        """Number of worker threads currently registered in the pool."""
        return len(self._threads)

    def submit(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` for execution by a worker thread.

        :raises RuntimeError: if the pool has already been shut down
        """
        if self._shutdown:
            raise RuntimeError('Cannot schedule new tasks after shutdown')

        self._queue.put((func, args, kwargs))
        self._adjust_threadcount()

    def shutdown(self, wait=True):
        """Stop accepting jobs and tell all worker threads to exit.

        Calling this on an already shut down pool does nothing.

        :param wait: if True, block until the worker threads have exited
        """
        if self._shutdown:
            return

        logger.info('Shutting down thread pool')
        self._shutdown = True
        # discard() rather than remove(): a missing registry entry is not an
        # error worth failing shutdown over.
        _threadpools.discard(ref(self))

        with self._threads_lock:
            # One dummy item per registered worker so that every thread
            # blocked in Queue.get() wakes up and sees the shutdown flag.
            for _ in range(self.num_threads):
                self._queue.put((None, None, None))
            threads = tuple(self._threads)

        if wait:
            # Join outside the lock: exiting workers need it to deregister.
            for thread in threads:
                thread.join()

    def __repr__(self):
        if self.max_threads:
            threadcount = '%d/%d' % (self.num_threads, self.max_threads)
        else:
            threadcount = '%d' % self.num_threads

        return '<ThreadPool at %x; threads=%s>' % (id(self), threadcount)
|