summaryrefslogtreecommitdiffhomepage
path: root/libs/ffsubsync/subtitle_transformers.py
blob: dcf0664f5b9721e1b6b1c327552f109f75f24e6d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
from datetime import timedelta
import logging
import numbers

from ffsubsync.generic_subtitles import GenericSubtitle, GenericSubtitlesFile, SubsMixin
from ffsubsync.sklearn_shim import TransformerMixin

# Import-time logging setup: request INFO-level default configuration for the
# root logger (basicConfig is a no-op if the root logger already has
# handlers), then create the logger named after this module.
# NOTE(review): configuring the root logger from a library module affects the
# whole process -- confirm this is intended for embedders.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)


class SubtitleShifter(SubsMixin, TransformerMixin):
    """Transformer that applies a fixed time offset to a subtitles file.

    ``fit`` asks the subtitles object for an offset copy and caches it on
    ``subs_``; ``transform`` hands that cached value back.
    """

    def __init__(self, td_seconds):
        """Store the offset to apply.

        Args:
            td_seconds: A ``timedelta`` (stored unchanged) or any value
                accepted by ``timedelta(seconds=...)``.
        """
        # Two-argument super(): the lookup starts *after* SubsMixin in the MRO.
        super(SubsMixin, self).__init__()
        is_delta = isinstance(td_seconds, timedelta)
        self.td_seconds = td_seconds if is_delta else timedelta(seconds=td_seconds)

    def fit(self, subs: GenericSubtitlesFile, *_):
        """Cache the result of ``subs.offset(self.td_seconds)``.

        Any extra positional arguments are accepted and ignored.

        Returns:
            This transformer, so calls can be chained.
        """
        shifted = subs.offset(self.td_seconds)
        self.subs_ = shifted
        return self

    def transform(self, *_):
        """Return the value cached by the most recent ``fit`` call."""
        return self.subs_


class SubtitleScaler(SubsMixin, TransformerMixin):
    """Transformer that multiplies every subtitle timestamp by a constant.

    ``fit`` builds a rescaled copy of each subtitle and caches the resulting
    collection on ``subs_``; ``transform`` returns that cached value.
    """

    def __init__(self, scale_factor):
        """Store the multiplier.

        Args:
            scale_factor: Must be a ``numbers.Number``; this is checked with
                an ``assert`` (so the check is skipped under ``python -O``).
        """
        assert isinstance(scale_factor, numbers.Number)
        # Two-argument super(): the lookup starts *after* SubsMixin in the MRO.
        super(SubsMixin, self).__init__()
        self.scale_factor = scale_factor

    def fit(self, subs: GenericSubtitlesFile, *_):
        """Rescale the start and end of every entry in ``subs``.

        Each entry becomes a new ``GenericSubtitle`` with scaled ``start`` and
        ``end`` and the original ``inner`` payload. The new list is passed to
        ``subs.clone_props_for_subs`` and the returned value is cached.
        Extra positional arguments are accepted and ignored.

        Returns:
            This transformer, so calls can be chained.
        """
        factor = self.scale_factor

        def _scaled(moment):
            # Go through total_seconds() and rebuild the timedelta rather than
            # multiplying the timedelta directly (form kept from py2 support).
            return timedelta(seconds=moment.total_seconds() * factor)

        rescaled = [
            GenericSubtitle(_scaled(sub.start), _scaled(sub.end), sub.inner)
            for sub in subs
        ]
        self.subs_ = subs.clone_props_for_subs(rescaled)
        return self

    def transform(self, *_):
        """Return the value cached by the most recent ``fit`` call."""
        return self.subs_


class SubtitleMerger(SubsMixin, TransformerMixin):
    """Transformer that combines a reference subtitle stream with another one.

    Entries from the two streams are walked in order of ``start``. Each entry
    of one stream is paired (via ``merge_with``) with the nearest-starting
    entry of the other stream where the walk finds one; entries left without
    a partner are emitted unchanged.

    NOTE(review): the walk assumes both streams are already ordered by
    ``start`` -- nothing here sorts them; confirm against callers.
    """

    def __init__(self, reference_subs, first="reference"):
        """Store the reference stream and the pairing order.

        Args:
            reference_subs: Iterable of subtitle entries to merge against.
            first: ``"reference"`` or ``"output"``. Selects which stream is
                handed to the merger as its first argument; entries of that
                stream are always the receiver of ``merge_with``. Checked
                with an ``assert`` (so skipped under ``python -O``).
        """
        assert first in ("reference", "output")
        # Two-argument super(): the lookup starts *after* SubsMixin in the MRO.
        super(SubsMixin, self).__init__()
        self.reference_subs = reference_subs
        self.first = first

    def fit(self, output_subs: GenericSubtitlesFile, *_):
        """Merge ``output_subs`` with the stored reference stream.

        The merged entries are collected into a list, passed to
        ``output_subs.clone_props_for_subs`` and the returned value is cached
        on ``subs_``. Extra positional arguments are accepted and ignored.

        Returns:
            This transformer, so calls can be chained.
        """

        def _merger_gen(a, b):
            """Yield entries of ``a`` and ``b`` in start order, pairing neighbours."""
            ita, itb = iter(a), iter(b)
            cur_a = next(ita, None)
            cur_b = next(itb, None)
            while True:
                # One or both streams exhausted: drain whatever is left.
                if cur_a is None and cur_b is None:
                    return
                elif cur_a is None:
                    while cur_b is not None:
                        yield cur_b
                        cur_b = next(itb, None)
                    return
                elif cur_b is None:
                    while cur_a is not None:
                        yield cur_a
                        cur_a = next(ita, None)
                    return
                # else: neither are None
                # Orient the pair so that cur_a starts no later than cur_b.
                # `swapped` remembers that a/b are temporarily exchanged so
                # the original `a` entry can stay the merge_with receiver.
                if cur_a.start < cur_b.start:
                    swapped = False
                else:
                    swapped = True
                    cur_a, cur_b = cur_b, cur_a
                    ita, itb = itb, ita
                # Advance `a` past everything that starts before cur_b.
                # Entries followed by another entry that is still before
                # cur_b are emitted alone. Afterwards prev_a is the last `a`
                # entry before cur_b (or cur_a itself when the starts tie)
                # and cur_a is the first `a` entry at or after cur_b.
                prev_a = cur_a
                while prev_a is not None and cur_a.start < cur_b.start:
                    cur_a = next(ita, None)
                    if cur_a is None or cur_a.start < cur_b.start:
                        yield prev_a
                        prev_a = cur_a
                if prev_a is None:
                    # `a` ran out while scanning: drain `b`.
                    while cur_b is not None:
                        yield cur_b
                        cur_b = next(itb, None)
                    return
                if cur_b.start - prev_a.start < cur_a.start - cur_b.start:
                    # prev_a is strictly closer to cur_b: pair those two and
                    # keep cur_a pending for the next round.
                    if swapped:
                        yield cur_b.merge_with(prev_a)
                        ita, itb = itb, ita
                        cur_a, cur_b = cur_b, cur_a
                        cur_a = next(ita, None)
                    else:
                        yield prev_a.merge_with(cur_b)
                        cur_b = next(itb, None)
                else:
                    # cur_a is at least as close to cur_b, so those two get
                    # paired. If the scan above moved past prev_a, it has not
                    # been emitted yet: yield it unmerged first so it is not
                    # silently dropped from the output.
                    if prev_a is not cur_a:
                        yield prev_a
                    if swapped:
                        yield cur_b.merge_with(cur_a)
                        ita, itb = itb, ita
                    else:
                        yield cur_a.merge_with(cur_b)
                    # Both heads were consumed by the merge; the stale
                    # cur_a/cur_b orientation does not matter because both
                    # are replaced from the (restored) iterators.
                    cur_a = next(ita, None)
                    cur_b = next(itb, None)

        if self.first == "reference":
            first, second = self.reference_subs, output_subs
        else:
            first, second = output_subs, self.reference_subs
        merged_subs = list(_merger_gen(first, second))
        self.subs_ = output_subs.clone_props_for_subs(merged_subs)
        return self

    def transform(self, *_):
        """Return the value cached by the most recent ``fit`` call."""
        return self.subs_