diff options
Diffstat (limited to 'libs/auditok/plotting.py')
-rwxr-xr-x | libs/auditok/plotting.py | 150 |
1 files changed, 150 insertions, 0 deletions
diff --git a/libs/auditok/plotting.py b/libs/auditok/plotting.py new file mode 100755 index 000000000..eca5877f4 --- /dev/null +++ b/libs/auditok/plotting.py @@ -0,0 +1,150 @@ +import matplotlib.pyplot as plt +import numpy as np + +AUDITOK_PLOT_THEME = { + "figure": {"facecolor": "#482a36", "alpha": 0.2}, + "plot": {"facecolor": "#282a36"}, + "energy_threshold": { + "color": "#e31f8f", + "linestyle": "--", + "linewidth": 1, + }, + "signal": {"color": "#40d970", "linestyle": "-", "linewidth": 1}, + "detections": { + "facecolor": "#777777", + "edgecolor": "#ff8c1a", + "linewidth": 1, + "alpha": 0.75, + }, +} + + +def _make_time_axis(nb_samples, sampling_rate): + sample_duration = 1 / sampling_rate + x = np.linspace(0, sample_duration * (nb_samples - 1), nb_samples) + return x + + +def _plot_line(x, y, theme, xlabel=None, ylabel=None, **kwargs): + color = theme.get("color", theme.get("c")) + ls = theme.get("linestyle", theme.get("ls")) + lw = theme.get("linewidth", theme.get("lw")) + plt.plot(x, y, c=color, ls=ls, lw=lw, **kwargs) + plt.xlabel(xlabel, fontsize=8) + plt.ylabel(ylabel, fontsize=8) + + +def _plot_detections(subplot, detections, theme): + fc = theme.get("facecolor", theme.get("fc")) + ec = theme.get("edgecolor", theme.get("ec")) + ls = theme.get("linestyle", theme.get("ls")) + lw = theme.get("linewidth", theme.get("lw")) + alpha = theme.get("alpha") + for (start, end) in detections: + subplot.axvspan(start, end, fc=fc, ec=ec, ls=ls, lw=lw, alpha=alpha) + + +def plot( + audio_region, + scale_signal=True, + detections=None, + energy_threshold=None, + show=True, + figsize=None, + save_as=None, + dpi=120, + theme="auditok", +): + y = np.asarray(audio_region) + if len(y.shape) == 1: + y = y.reshape(1, -1) + nb_subplots, nb_samples = y.shape + sampling_rate = audio_region.sampling_rate + time_axis = _make_time_axis(nb_samples, sampling_rate) + if energy_threshold is not None: + eth_log10 = energy_threshold * np.log(10) / 10 + amplitude_threshold = np.sqrt(np.exp(eth_log10)) + else: + amplitude_threshold = None + if detections is None: + detections = [] + else: + # End of detection corresponds to the end of the last sample but + # to stay compatible with the time axis of signal plotting we want end + # of detection to correspond to the *start* of the that last sample. + detections = [ + (start, end - (1 / sampling_rate)) for (start, end) in detections + ] + if theme == "auditok": + theme = AUDITOK_PLOT_THEME + + fig = plt.figure(figsize=figsize, dpi=dpi) + fig_theme = theme.get("figure", theme.get("fig", {})) + fig_fc = fig_theme.get("facecolor", fig_theme.get("ffc")) + fig_alpha = fig_theme.get("alpha", 1) + fig.patch.set_facecolor(fig_fc) + fig.patch.set_alpha(fig_alpha) + + plot_theme = theme.get("plot", {}) + plot_fc = plot_theme.get("facecolor", plot_theme.get("pfc")) + + if nb_subplots > 2 and nb_subplots % 2 == 0: + nb_rows = nb_subplots // 2 + nb_columns = 2 + else: + nb_rows = nb_subplots + nb_columns = 1 + + for sid, samples in enumerate(y, 1): + ax = fig.add_subplot(nb_rows, nb_columns, sid) + ax.set_facecolor(plot_fc) + if scale_signal: + std = samples.std() + if std > 0: + mean = samples.mean() + std = samples.std() + samples = (samples - mean) / std + max_ = samples.max() + plt.ylim(-1.5 * max_, 1.5 * max_) + if amplitude_threshold is not None: + if scale_signal and std > 0: + amp_th = (amplitude_threshold - mean) / std + else: + amp_th = amplitude_threshold + eth_theme = theme.get("energy_threshold", theme.get("eth", {})) + _plot_line( + [time_axis[0], time_axis[-1]], + [amp_th] * 2, + eth_theme, + label="Detection threshold", + ) + if sid == 1: + legend = plt.legend( + ["Detection threshold"], + facecolor=fig_fc, + framealpha=0.1, + bbox_to_anchor=(0.0, 1.15, 1.0, 0.102), + loc=2, + ) + legend = plt.gca().add_artist(legend) + + signal_theme = theme.get("signal", {}) + _plot_line( + time_axis, + samples, + signal_theme, + xlabel="Time (seconds)", + ylabel="Signal{}".format(" (scaled)" if scale_signal else ""), + ) + detections_theme = theme.get("detections", {}) + _plot_detections(ax, detections, detections_theme) + plt.title("Channel {}".format(sid), fontsize=10) + + plt.xticks(fontsize=8) + plt.yticks(fontsize=8) + plt.tight_layout() + + if save_as is not None: + plt.savefig(save_as, dpi=dpi) + if show: + plt.show() |