diff options
Diffstat (limited to 'libs/tqdm/contrib/concurrent.py')
-rw-r--r-- | libs/tqdm/contrib/concurrent.py | 66 |
1 files changed, 45 insertions, 21 deletions
diff --git a/libs/tqdm/contrib/concurrent.py b/libs/tqdm/contrib/concurrent.py index 197a5f8c5..ccb5e1252 100644 --- a/libs/tqdm/contrib/concurrent.py +++ b/libs/tqdm/contrib/concurrent.py @@ -2,9 +2,12 @@ Thin wrappers around `concurrent.futures`. """ from __future__ import absolute_import -from tqdm import TqdmWarning -from tqdm.auto import tqdm as tqdm_auto -from copy import deepcopy + +from contextlib import contextmanager + +from ..auto import tqdm as tqdm_auto +from ..std import TqdmWarning + try: from operator import length_hint except ImportError: @@ -23,10 +26,25 @@ except ImportError: def cpu_count(): return 4 import sys + __author__ = {"github.com/": ["casperdcl"]} __all__ = ['thread_map', 'process_map'] +@contextmanager +def ensure_lock(tqdm_class, lock_name=""): + """get (create if necessary) and then restore `tqdm_class`'s lock""" + old_lock = getattr(tqdm_class, '_lock', None) # don't create a new lock + lock = old_lock or tqdm_class.get_lock() # maybe create a new lock + lock = getattr(lock, lock_name, lock) # maybe subtype + tqdm_class.set_lock(lock) + yield lock + if old_lock is None: + del tqdm_class._lock + else: + tqdm_class.set_lock(old_lock) + + def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): """ Implementation of `thread_map` and `process_map`. @@ -36,25 +54,26 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs): tqdm_class : [default: tqdm.auto.tqdm]. max_workers : [default: min(32, cpu_count() + 4)]. chunksize : [default: 1]. + lock_name : [default: "":str]. """ - kwargs = deepcopy(tqdm_kwargs) + kwargs = tqdm_kwargs.copy() if "total" not in kwargs: - kwargs["total"] = len(iterables[0]) + kwargs["total"] = length_hint(iterables[0]) tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4)) chunksize = kwargs.pop("chunksize", 1) - pool_kwargs = dict(max_workers=max_workers) - sys_version = sys.version_info[:2] - if sys_version >= (3, 7): - # share lock in case workers are already using `tqdm` - pool_kwargs.update( - initializer=tqdm_class.set_lock, initargs=(tqdm_class.get_lock(),)) - map_args = {} - if not (3, 0) < sys_version < (3, 5): - map_args.update(chunksize=chunksize) - with PoolExecutor(**pool_kwargs) as ex: - return list(tqdm_class( - ex.map(fn, *iterables, **map_args), **kwargs)) + lock_name = kwargs.pop("lock_name", "") + with ensure_lock(tqdm_class, lock_name=lock_name) as lk: + pool_kwargs = {'max_workers': max_workers} + sys_version = sys.version_info[:2] + if sys_version >= (3, 7): + # share lock in case workers are already using `tqdm` + pool_kwargs.update(initializer=tqdm_class.set_lock, initargs=(lk,)) + map_args = {} + if not (3, 0) < sys_version < (3, 5): + map_args.update(chunksize=chunksize) + with PoolExecutor(**pool_kwargs) as ex: + return list(tqdm_class(ex.map(fn, *iterables, **map_args), **kwargs)) def thread_map(fn, *iterables, **tqdm_kwargs): @@ -64,9 +83,9 @@ def thread_map(fn, *iterables, **tqdm_kwargs): Parameters ---------- - tqdm_class : optional + tqdm_class : optional `tqdm` class to use for bars [default: tqdm.auto.tqdm]. - max_workers : int, optional + max_workers : int, optional Maximum number of workers to spawn; passed to `concurrent.futures.ThreadPoolExecutor.__init__`. [default: max(32, cpu_count() + 4)]. @@ -84,13 +103,15 @@ def process_map(fn, *iterables, **tqdm_kwargs): ---------- tqdm_class : optional `tqdm` class to use for bars [default: tqdm.auto.tqdm]. - max_workers : int, optional + max_workers : int, optional Maximum number of workers to spawn; passed to `concurrent.futures.ProcessPoolExecutor.__init__`. [default: min(32, cpu_count() + 4)]. - chunksize : int, optional + chunksize : int, optional Size of chunks sent to worker processes; passed to `concurrent.futures.ProcessPoolExecutor.map`. [default: 1]. + lock_name : str, optional + Member of `tqdm_class.get_lock()` to use [default: mp_lock]. """ from concurrent.futures import ProcessPoolExecutor if iterables and "chunksize" not in tqdm_kwargs: @@ -103,4 +124,7 @@ def process_map(fn, *iterables, **tqdm_kwargs): " This may seriously degrade multiprocess performance." " Set `chunksize=1` or more." % longest_iterable_len, TqdmWarning, stacklevel=2) + if "lock_name" not in tqdm_kwargs: + tqdm_kwargs = tqdm_kwargs.copy() + tqdm_kwargs["lock_name"] = "mp_lock" return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs) |