""" Generic thread pool class. Modeled after Java's ThreadPoolExecutor. Please note that this ThreadPool does *not* fully implement the PEP 3148 ThreadPool! """ from threading import Thread, Lock, currentThread from weakref import ref import logging import atexit try: from queue import Queue, Empty except ImportError: from Queue import Queue, Empty logger = logging.getLogger(__name__) _threadpools = set() # Worker threads are daemonic in order to let the interpreter exit without # an explicit shutdown of the thread pool. The following trick is necessary # to allow worker threads to finish cleanly. def _shutdown_all(): for pool_ref in tuple(_threadpools): pool = pool_ref() if pool: pool.shutdown() atexit.register(_shutdown_all) class ThreadPool(object): def __init__(self, core_threads=0, max_threads=20, keepalive=1): """ :param core_threads: maximum number of persistent threads in the pool :param max_threads: maximum number of total threads in the pool :param thread_class: callable that creates a Thread object :param keepalive: seconds to keep non-core worker threads waiting for new tasks """ self.core_threads = core_threads self.max_threads = max(max_threads, core_threads, 1) self.keepalive = keepalive self._queue = Queue() self._threads_lock = Lock() self._threads = set() self._shutdown = False _threadpools.add(ref(self)) logger.info('Started thread pool with %d core threads and %s maximum ' 'threads', core_threads, max_threads or 'unlimited') def _adjust_threadcount(self): self._threads_lock.acquire() try: if self.num_threads < self.max_threads: self._add_thread(self.num_threads < self.core_threads) finally: self._threads_lock.release() def _add_thread(self, core): t = Thread(target=self._run_jobs, args=(core,)) t.setDaemon(True) t.start() self._threads.add(t) def _run_jobs(self, core): logger.debug('Started worker thread') block = True timeout = None if not core: block = self.keepalive > 0 timeout = self.keepalive while True: try: func, args, kwargs = self._queue.get(block, timeout) except Empty: break if self._shutdown: break try: func(*args, **kwargs) except: logger.exception('Error in worker thread') self._threads_lock.acquire() self._threads.remove(currentThread()) self._threads_lock.release() logger.debug('Exiting worker thread') @property def num_threads(self): return len(self._threads) def submit(self, func, *args, **kwargs): if self._shutdown: raise RuntimeError('Cannot schedule new tasks after shutdown') self._queue.put((func, args, kwargs)) self._adjust_threadcount() def shutdown(self, wait=True): if self._shutdown: return logging.info('Shutting down thread pool') self._shutdown = True _threadpools.remove(ref(self)) self._threads_lock.acquire() for _ in range(self.num_threads): self._queue.put((None, None, None)) self._threads_lock.release() if wait: self._threads_lock.acquire() threads = tuple(self._threads) self._threads_lock.release() for thread in threads: thread.join() def __repr__(self): if self.max_threads: threadcount = '%d/%d' % (self.num_threads, self.max_threads) else: threadcount = '%d' % self.num_threads return '' % (id(self), threadcount)