diff options
Diffstat (limited to 'libs/tqdm/keras.py')
-rw-r--r-- | libs/tqdm/keras.py | 47 |
1 files changed, 33 insertions, 14 deletions
diff --git a/libs/tqdm/keras.py b/libs/tqdm/keras.py index 27623c099..523e62e94 100644 --- a/libs/tqdm/keras.py +++ b/libs/tqdm/keras.py @@ -1,9 +1,13 @@ from __future__ import absolute_import, division + +from copy import copy +from functools import partial + from .auto import tqdm as tqdm_auto -from copy import deepcopy + try: import keras -except ImportError as e: +except (ImportError, AttributeError) as e: try: from tensorflow import keras except ImportError: @@ -13,14 +17,14 @@ __all__ = ['TqdmCallback'] class TqdmCallback(keras.callbacks.Callback): - """`keras` callback for epoch and batch progress""" + """Keras callback for epoch and batch progress.""" @staticmethod def bar2callback(bar, pop=None, delta=(lambda logs: 1)): def callback(_, logs=None): n = delta(logs) if logs: if pop: - logs = deepcopy(logs) + logs = copy(logs) [logs.pop(i, 0) for i in pop] bar.set_postfix(logs, refresh=False) bar.update(n) @@ -28,7 +32,7 @@ class TqdmCallback(keras.callbacks.Callback): return callback def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1, - tqdm_class=tqdm_auto): + tqdm_class=tqdm_auto, **tqdm_kwargs): """ Parameters ---------- @@ -41,9 +45,13 @@ class TqdmCallback(keras.callbacks.Callback): 0: epoch, 1: batch (transient), 2: batch. [default: 1]. Will be set to `0` unless both `data_size` and `batch_size` are given. - tqdm_class : optional + tqdm_class : optional `tqdm` class to use for bars [default: `tqdm.auto.tqdm`]. + tqdm_kwargs : optional + Any other arguments used for all bars. """ + if tqdm_kwargs: + tqdm_class = partial(tqdm_class, **tqdm_kwargs) self.tqdm_class = tqdm_class self.epoch_bar = tqdm_class(total=epochs, unit='epoch') self.on_epoch_end = self.bar2callback(self.epoch_bar) @@ -53,20 +61,21 @@ class TqdmCallback(keras.callbacks.Callback): self.batches = batches = None self.verbose = verbose if verbose == 1: - self.batch_bar = tqdm_class(total=batches, unit='batch', - leave=False) + self.batch_bar = tqdm_class(total=batches, unit='batch', leave=False) self.on_batch_end = self.bar2callback( - self.batch_bar, - pop=['batch', 'size'], + self.batch_bar, pop=['batch', 'size'], delta=lambda logs: logs.get('size', 1)) def on_train_begin(self, *_, **__): params = self.params.get auto_total = params('epochs', params('nb_epoch', None)) - if auto_total is not None: + if auto_total is not None and auto_total != self.epoch_bar.total: self.epoch_bar.reset(total=auto_total) - def on_epoch_begin(self, *_, **__): + def on_epoch_begin(self, epoch, *_, **__): + if self.epoch_bar.n < epoch: + ebar = self.epoch_bar + ebar.n = ebar.last_print_n = ebar.initial = epoch if self.verbose: params = self.params.get total = params('samples', params( @@ -78,8 +87,7 @@ class TqdmCallback(keras.callbacks.Callback): total=total, unit='batch', leave=True, unit_scale=1 / (params('batch_size', 1) or 1)) self.on_batch_end = self.bar2callback( - self.batch_bar, - pop=['batch', 'size'], + self.batch_bar, pop=['batch', 'size'], delta=lambda logs: logs.get('size', 1)) elif self.verbose == 1: self.batch_bar.unit_scale = 1 / (params('batch_size', 1) or 1) @@ -92,6 +100,17 @@ class TqdmCallback(keras.callbacks.Callback): self.batch_bar.close() self.epoch_bar.close() + def display(self): + """Displays in the current cell in Notebooks.""" + container = getattr(self.epoch_bar, 'container', None) + if container is None: + return + from .notebook import display + display(container) + batch_bar = getattr(self, 'batch_bar', None) + if batch_bar is not None: + display(batch_bar.container) + @staticmethod def _implements_train_batch_hooks(): return True |