summaryrefslogtreecommitdiffhomepage
path: root/libs/tqdm/contrib/concurrent.py
diff options
context:
space:
mode:
Diffstat (limited to 'libs/tqdm/contrib/concurrent.py')
-rw-r--r--libs/tqdm/contrib/concurrent.py66
1 files changed, 45 insertions, 21 deletions
diff --git a/libs/tqdm/contrib/concurrent.py b/libs/tqdm/contrib/concurrent.py
index 197a5f8c5..ccb5e1252 100644
--- a/libs/tqdm/contrib/concurrent.py
+++ b/libs/tqdm/contrib/concurrent.py
@@ -2,9 +2,12 @@
Thin wrappers around `concurrent.futures`.
"""
from __future__ import absolute_import
-from tqdm import TqdmWarning
-from tqdm.auto import tqdm as tqdm_auto
-from copy import deepcopy
+
+from contextlib import contextmanager
+
+from ..auto import tqdm as tqdm_auto
+from ..std import TqdmWarning
+
try:
from operator import length_hint
except ImportError:
@@ -23,10 +26,25 @@ except ImportError:
def cpu_count():
return 4
import sys
+
__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['thread_map', 'process_map']
+@contextmanager
+def ensure_lock(tqdm_class, lock_name=""):
+ """get (create if necessary) and then restore `tqdm_class`'s lock"""
+ old_lock = getattr(tqdm_class, '_lock', None) # don't create a new lock
+ lock = old_lock or tqdm_class.get_lock() # maybe create a new lock
+ lock = getattr(lock, lock_name, lock) # maybe subtype
+ tqdm_class.set_lock(lock)
+ yield lock
+ if old_lock is None:
+ del tqdm_class._lock
+ else:
+ tqdm_class.set_lock(old_lock)
+
+
def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
"""
Implementation of `thread_map` and `process_map`.
@@ -36,25 +54,26 @@ def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
tqdm_class : [default: tqdm.auto.tqdm].
max_workers : [default: min(32, cpu_count() + 4)].
chunksize : [default: 1].
+ lock_name : [default: "":str].
"""
- kwargs = deepcopy(tqdm_kwargs)
+ kwargs = tqdm_kwargs.copy()
if "total" not in kwargs:
- kwargs["total"] = len(iterables[0])
+ kwargs["total"] = length_hint(iterables[0])
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
chunksize = kwargs.pop("chunksize", 1)
- pool_kwargs = dict(max_workers=max_workers)
- sys_version = sys.version_info[:2]
- if sys_version >= (3, 7):
- # share lock in case workers are already using `tqdm`
- pool_kwargs.update(
- initializer=tqdm_class.set_lock, initargs=(tqdm_class.get_lock(),))
- map_args = {}
- if not (3, 0) < sys_version < (3, 5):
- map_args.update(chunksize=chunksize)
- with PoolExecutor(**pool_kwargs) as ex:
- return list(tqdm_class(
- ex.map(fn, *iterables, **map_args), **kwargs))
+ lock_name = kwargs.pop("lock_name", "")
+ with ensure_lock(tqdm_class, lock_name=lock_name) as lk:
+ pool_kwargs = {'max_workers': max_workers}
+ sys_version = sys.version_info[:2]
+ if sys_version >= (3, 7):
+ # share lock in case workers are already using `tqdm`
+ pool_kwargs.update(initializer=tqdm_class.set_lock, initargs=(lk,))
+ map_args = {}
+ if not (3, 0) < sys_version < (3, 5):
+ map_args.update(chunksize=chunksize)
+ with PoolExecutor(**pool_kwargs) as ex:
+ return list(tqdm_class(ex.map(fn, *iterables, **map_args), **kwargs))
def thread_map(fn, *iterables, **tqdm_kwargs):
@@ -64,9 +83,9 @@ def thread_map(fn, *iterables, **tqdm_kwargs):
Parameters
----------
- tqdm_class : optional
+ tqdm_class : optional
`tqdm` class to use for bars [default: tqdm.auto.tqdm].
- max_workers : int, optional
+ max_workers : int, optional
Maximum number of workers to spawn; passed to
`concurrent.futures.ThreadPoolExecutor.__init__`.
[default: max(32, cpu_count() + 4)].
@@ -84,13 +103,15 @@ def process_map(fn, *iterables, **tqdm_kwargs):
----------
tqdm_class : optional
`tqdm` class to use for bars [default: tqdm.auto.tqdm].
- max_workers : int, optional
+ max_workers : int, optional
Maximum number of workers to spawn; passed to
`concurrent.futures.ProcessPoolExecutor.__init__`.
[default: min(32, cpu_count() + 4)].
- chunksize : int, optional
+ chunksize : int, optional
Size of chunks sent to worker processes; passed to
`concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
+ lock_name : str, optional
+ Member of `tqdm_class.get_lock()` to use [default: mp_lock].
"""
from concurrent.futures import ProcessPoolExecutor
if iterables and "chunksize" not in tqdm_kwargs:
@@ -103,4 +124,7 @@ def process_map(fn, *iterables, **tqdm_kwargs):
" This may seriously degrade multiprocess performance."
" Set `chunksize=1` or more." % longest_iterable_len,
TqdmWarning, stacklevel=2)
+ if "lock_name" not in tqdm_kwargs:
+ tqdm_kwargs = tqdm_kwargs.copy()
+ tqdm_kwargs["lock_name"] = "mp_lock"
return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)