aboutsummaryrefslogtreecommitdiffhomepage
path: root/libs/ffsubsync/aligners.py
blob: aebfe128d4dd3a24cd5ef3dbe984fd536d988db4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# -*- coding: utf-8 -*- 
import logging
import math

import numpy as np
from .sklearn_shim import TransformerMixin

# NOTE(review): this runs at import time and configures the process-wide root
# logger at INFO level (stdlib basicConfig is a no-op if the root logger
# already has handlers). Importing this module therefore changes global
# logging configuration as a side effect.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class FailedToFindAlignmentException(Exception):
    """Raised when no acceptable alignment offset can be selected."""


class FFTAligner(TransformerMixin):
    """Find the lag that best aligns one 0/1 sequence against another.

    Both sequences are mapped from {0, 1} to {-1, +1} and cross-correlated
    with a single zero-padded FFT. ``fit`` records the lag with the largest
    correlation; ``transform`` reports it.

    Attributes set by ``fit``:
        best_offset_: the lag ``d`` (a Python ``int``) maximising
            ``sum_j ref[j + d] * sub[j]`` over the -1/+1 encodings, i.e.
            ``substring[j]`` lines up with ``refstring[j + d]``.
        best_score_: the correlation value at ``best_offset_``.
        get_score_: the ``get_score`` flag passed to the most recent ``fit``.
    """

    def __init__(self):
        self.best_offset_ = None
        self.best_score_ = None
        self.get_score_ = False

    def fit(self, refstring, substring, get_score=False):
        """Cross-correlate ``substring`` against ``refstring``.

        :param refstring: reference sequence of 0/1 values, or a string of
            '0'/'1' characters.
        :param substring: sequence to align, in the same encoding.
        :param get_score: when True, ``transform`` returns
            ``(best_score_, best_offset_)`` instead of just the offset.
        :return: ``self``
        :raises ValueError: if both sequences are empty.
        """
        # Strings such as '0110' are turned into lists of ints first.
        refstring, substring = [
            list(map(int, s)) if isinstance(s, str) else s
            for s in [refstring, substring]
        ]
        # {0, 1} -> {-1, +1}: agreeing samples add to the score and
        # disagreeing samples subtract from it.
        refstring, substring = [
            2 * np.array(s).astype(float) - 1 for s in (refstring, substring)
        ]
        combined_length = len(substring) + len(refstring)
        if combined_length == 0:
            raise ValueError('refstring and substring cannot both be empty')
        # Smallest power of two >= combined_length, computed exactly.
        # (math.log(x, 2) is a float quotient that can round above an exact
        # integer for exact powers of two, needlessly doubling the FFT size.)
        total_length = 1 << (combined_length - 1).bit_length()
        extra_zeros = total_length - combined_length
        # Padding both signals to total_length >= len(ref) + len(sub) keeps
        # the circular convolution free of wrap-around, so it equals the
        # linear cross-correlation of the two sequences.
        subft = np.fft.fft(
            np.append(np.zeros(extra_zeros + len(refstring)), substring)
        )
        refft = np.fft.fft(
            np.flip(np.append(refstring, np.zeros(len(substring) + extra_zeros)), 0)
        )
        convolve = np.real(np.fft.ifft(subft * refft))
        best_idx = int(np.argmax(convolve))
        # Index m of the circular result holds lag
        # total_length - 1 - len(substring) - m.
        self.best_offset_ = len(convolve) - 1 - best_idx - len(substring)
        self.best_score_ = convolve[best_idx]
        self.get_score_ = get_score
        return self

    def transform(self, *_):
        """Return ``(best_score_, best_offset_)`` when the last ``fit`` used
        ``get_score=True``; otherwise return ``best_offset_`` alone."""
        if self.get_score_:
            return self.best_score_, self.best_offset_
        return self.best_offset_


class MaxScoreAligner(TransformerMixin):
    """Score several candidate subtitle pipelines and keep the best one.

    Each candidate is aligned against the reference with ``base_aligner``,
    whose ``fit_transform(refstring, substring, get_score=True)`` is expected
    to return ``(score, offset)``. ``transform`` returns the
    ``(offset, subpipe)`` pair with the highest score, optionally restricted
    to candidates with ``abs(offset) <= max_offset_samples``.
    """

    def __init__(self, base_aligner, sample_rate=None, max_offset_seconds=None):
        """
        :param base_aligner: an aligner instance, or a class that is
            instantiated here with no arguments.
        :param sample_rate: rate of the bitstrings; multiplied by
            ``max_offset_seconds`` to get the offset limit in samples.
        :param max_offset_seconds: largest allowed absolute offset. The limit
            is disabled unless both this and ``sample_rate`` are given.
        """
        if isinstance(base_aligner, type):
            self.base_aligner = base_aligner()
        else:
            self.base_aligner = base_aligner
        self.max_offset_seconds = max_offset_seconds
        if sample_rate is None or max_offset_seconds is None:
            self.max_offset_samples = None
        else:
            self.max_offset_samples = abs(max_offset_seconds * sample_rate)
        # List of ((score, offset), subpipe) tuples filled in by fit().
        self._scores = []

    def fit(self, refstring, subpipes):
        """Score every candidate in ``subpipes`` against ``refstring``.

        :param refstring: reference sequence passed to ``base_aligner``.
        :param subpipes: one candidate or a list of candidates. A candidate
            with a ``transform`` method is evaluated via ``transform(None)``;
            anything else is used directly as the substring.
        :return: ``self``

        NOTE: results are appended to any scores from earlier ``fit`` calls;
        they are not reset here.
        """
        if not isinstance(subpipes, list):
            subpipes = [subpipes]
        for subpipe in subpipes:
            if hasattr(subpipe, 'transform'):
                substring = subpipe.transform(None)
            else:
                substring = subpipe
            self._scores.append((
                self.base_aligner.fit_transform(
                    refstring, substring, get_score=True
                ),
                subpipe
            ))
        return self

    def transform(self, *_):
        """Return ``(offset, subpipe)`` for the best-scoring admissible candidate.

        :raises FailedToFindAlignmentException: if no candidate has been
            scored, or if every candidate's offset exceeds the configured limit.
        """
        if not self._scores:
            # Without this check the message below would suggest raising
            # --max-offset-seconds (possibly "larger than None"), which cannot
            # help when nothing was scored at all.
            raise FailedToFindAlignmentException(
                'Synchronization failed; no candidate alignments were computed'
            )
        scores = self._scores
        if self.max_offset_samples is not None:
            scores = [s for s in scores if abs(s[0][1]) <= self.max_offset_samples]
        if not scores:
            raise FailedToFindAlignmentException('Synchronization failed; consider passing '
                                                 '--max-offset-seconds with a number larger than '
                                                 '{}'.format(self.max_offset_seconds))
        # Ties keep the earliest candidate, since max() returns the first maximum.
        (_, offset), subpipe = max(scores, key=lambda x: x[0][0])
        return offset, subpipe