# bazarr/libs/ffsubsync/speech_transformers.py
# -*- coding: utf-8 -*-
import io
import logging
import os
import subprocess
import sys
from contextlib import contextmanager
from datetime import timedelta
from typing import cast, Callable, Dict, List, Optional, Union

import ffmpeg
import numpy as np
import tqdm

from ffsubsync.constants import (
    DEFAULT_ENCODING,
    DEFAULT_MAX_SUBTITLE_SECONDS,
    DEFAULT_SCALE_FACTOR,
    DEFAULT_START_SECONDS,
    SAMPLE_RATE,
)
from ffsubsync.ffmpeg_utils import ffmpeg_bin_path, subprocess_args
from ffsubsync.generic_subtitles import GenericSubtitle
from ffsubsync.sklearn_shim import Pipeline, TransformerMixin
from ffsubsync.subtitle_parser import make_subtitle_parser
from ffsubsync.subtitle_transformers import SubtitleScaler

logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)


def make_subtitle_speech_pipeline(
    fmt: str = "srt",
    encoding: str = DEFAULT_ENCODING,
    caching: bool = False,
    max_subtitle_seconds: int = DEFAULT_MAX_SUBTITLE_SECONDS,
    start_seconds: int = DEFAULT_START_SECONDS,
    scale_factor: Optional[float] = DEFAULT_SCALE_FACTOR,
    parser=None,
    **kwargs,
) -> Union[Pipeline, Callable[[float], Pipeline]]:
    if parser is None:
        parser = make_subtitle_parser(
            fmt,
            encoding=encoding,
            caching=caching,
            max_subtitle_seconds=max_subtitle_seconds,
            start_seconds=start_seconds,
            **kwargs,
        )
    assert parser.encoding == encoding
    assert parser.max_subtitle_seconds == max_subtitle_seconds
    assert parser.start_seconds == start_seconds

    def subpipe_maker(framerate_ratio: float) -> Pipeline:
        return Pipeline(
            [
                ("parse", parser),
                ("scale", SubtitleScaler(framerate_ratio)),
                (
                    "speech_extract",
                    SubtitleSpeechTransformer(
                        sample_rate=SAMPLE_RATE,
                        start_seconds=start_seconds,
                        framerate_ratio=framerate_ratio,
                    ),
                ),
            ]
        )

    if scale_factor is None:
        return subpipe_maker
    return subpipe_maker(scale_factor)
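

# Illustrative usage sketch (not part of the original module). With the
# default scale factor the factory returns a ready-to-fit Pipeline, mirroring
# how try_fit_using_embedded_subs below consumes it; with scale_factor=None it
# returns a maker that builds one pipeline per candidate framerate ratio.
# "movie.srt" is a hypothetical input path.
def _example_subtitle_pipeline() -> None:
    pipe = cast(Pipeline, make_subtitle_speech_pipeline(fmt="srt"))
    pipe.fit("movie.srt")
    speech_step = pipe.steps[-1][1]  # the SubtitleSpeechTransformer
    print(speech_step.subtitle_speech_results_.shape)

    # Deferred construction: try several framerate ratios with one parser.
    subpipe_maker = make_subtitle_speech_pipeline(scale_factor=None)
    pipe_24_to_25 = subpipe_maker(25.0 / 24.0)
    print(pipe_24_to_25.steps[-1][0])  # "speech_extract"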


def _make_auditok_detector(
    sample_rate: int, frame_rate: int, non_speech_label: float
) -> Callable[[bytes], np.ndarray]:
    try:
        from auditok import (
            BufferAudioSource,
            ADSFactory,
            AudioEnergyValidator,
            StreamTokenizer,
        )
    except ImportError as e:
        logger.error(
            """Error: auditok not installed!
        Consider installing it with `pip install auditok`. Note that auditok
        is GPLv3 licensed, which means that successfully importing it at
        runtime creates a derivative work that is GPLv3 licensed. For personal
        use this is fine, but note that any commercial use that relies on
        auditok must be open source as per the GPLv3!*
        *Not legal advice. Consult with a lawyer.
        """
        )
        raise e
    bytes_per_frame = 2
    frames_per_window = frame_rate // sample_rate
    validator = AudioEnergyValidator(sample_width=bytes_per_frame, energy_threshold=50)
    tokenizer = StreamTokenizer(
        validator=validator,
        min_length=0.2 * sample_rate,
        max_length=int(5 * sample_rate),
        max_continuous_silence=0.25 * sample_rate,
    )

    def _detect(asegment: bytes) -> np.ndarray:
        asource = BufferAudioSource(
            data_buffer=asegment,
            sampling_rate=frame_rate,
            sample_width=bytes_per_frame,
            channels=1,
        )
        ads = ADSFactory.ads(audio_source=asource, block_dur=1.0 / sample_rate)
        ads.open()
        tokens = tokenizer.tokenize(ads)
        length = (
            len(asegment) // bytes_per_frame + frames_per_window - 1
        ) // frames_per_window
        media_bstring = np.zeros(length + 1)
        for token in tokens:
            media_bstring[token[1]] = 1.0
            media_bstring[token[2] + 1] = non_speech_label - 1.0
        return np.clip(np.cumsum(media_bstring)[:-1], 0.0, 1.0)

    return _detect
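

# Illustrative sketch (not in the original module) of the start/stop encoding
# used by _detect above: each auditok token (token[1] = start window,
# token[2] = end window) contributes +1 at its start and a compensating step
# one past its end, so the cumulative sum recovers a per-window label array.
# Token positions here are made up; non_speech_label = 0.0 mirrors a fully
# confident "not speech" label.
def _example_token_encoding() -> None:
    non_speech_label = 0.0
    media_bstring = np.zeros(8 + 1)  # one extra slot for the final end marker
    for start, end in [(1, 2), (5, 6)]:  # hypothetical token windows
        media_bstring[start] = 1.0
        media_bstring[end + 1] = non_speech_label - 1.0
    print(np.clip(np.cumsum(media_bstring)[:-1], 0.0, 1.0))
    # -> [0. 1. 1. 0. 0. 1. 1. 0.]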


def _make_webrtcvad_detector(
    sample_rate: int, frame_rate: int, non_speech_label: float
) -> Callable[[bytes], np.ndarray]:
    import webrtcvad

    vad = webrtcvad.Vad()
    vad.set_mode(3)  # set non-speech pruning aggressiveness from 0 to 3
    window_duration = 1.0 / sample_rate  # duration in seconds
    frames_per_window = int(window_duration * frame_rate + 0.5)
    bytes_per_frame = 2

    def _detect(asegment: bytes) -> np.ndarray:
        media_bstring = []
        failures = 0
        for start in range(0, len(asegment) // bytes_per_frame, frames_per_window):
            stop = min(start + frames_per_window, len(asegment) // bytes_per_frame)
            try:
                is_speech = vad.is_speech(
                    asegment[start * bytes_per_frame : stop * bytes_per_frame],
                    sample_rate=frame_rate,
                )
            except Exception:
                is_speech = False
                failures += 1
            # webrtcvad has low recall on mode 3, so treat non-speech as "not sure"
            media_bstring.append(1.0 if is_speech else non_speech_label)
        return np.array(media_bstring)

    return _detect
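

# Illustrative check (not in the original module): webrtcvad only accepts
# 10/20/30 ms frames at 8/16/32/48 kHz. With a 100 Hz speech sample rate each
# window above is exactly 10 ms of audio, so the slices handed to
# vad.is_speech are valid. One second of silent 16-bit mono PCM at 48 kHz
# should yield 100 windows, all labeled non_speech_label.
def _example_webrtcvad_detector() -> None:
    detect = _make_webrtcvad_detector(
        sample_rate=100, frame_rate=48000, non_speech_label=0.0
    )
    silence = bytes(2 * 48000)  # 1 s of zeros: 48000 frames * 2 bytes each
    labels = detect(silence)
    print(len(labels), labels.max())  # expect 100 and 0.0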


def _make_silero_detector(
    sample_rate: int, frame_rate: int, non_speech_label: float
) -> Callable[[bytes], np.ndarray]:
    import torch

    window_duration = 1.0 / sample_rate  # duration in seconds
    frames_per_window = int(window_duration * frame_rate + 0.5)
    bytes_per_frame = 1
    model, _ = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=False,
        onnx=False,
    )
    exception_logged = False

    def _detect(asegment) -> np.ndarray:
        nonlocal exception_logged
        asegment = np.frombuffer(asegment, np.int16).astype(np.float32) / (1 << 15)
        asegment = torch.FloatTensor(asegment)
        media_bstring = []
        failures = 0
        for start in range(0, len(asegment) // bytes_per_frame, frames_per_window):
            stop = min(start + frames_per_window, len(asegment))
            try:
                speech_prob = model(
                    asegment[start * bytes_per_frame : stop * bytes_per_frame],
                    frame_rate,
                ).item()
            except Exception:
                if not exception_logged:
                    exception_logged = True
                    logger.exception("exception occurred during speech detection")
                speech_prob = 0.0
                failures += 1
            media_bstring.append(1.0 - (1.0 - speech_prob) * (1.0 - non_speech_label))
        return np.array(media_bstring)

    return _detect
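

# Illustrative note (not in the original module): silero emits a speech
# probability per window, and 1 - (1 - p) * (1 - non_speech_label) linearly
# rescales it so that p=0 maps to non_speech_label and p=1 maps to 1.0.
def _example_silero_label_mixing() -> None:
    non_speech_label = 0.5
    for p in (0.0, 0.5, 1.0):
        print(1.0 - (1.0 - p) * (1.0 - non_speech_label))  # 0.5, 0.75, 1.0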


class ComputeSpeechFrameBoundariesMixin:
    def __init__(self) -> None:
        self.start_frame_: Optional[int] = None
        self.end_frame_: Optional[int] = None

    @property
    def num_frames(self) -> Optional[int]:
        if self.start_frame_ is None or self.end_frame_ is None:
            return None
        return self.end_frame_ - self.start_frame_

    def fit_boundaries(
        self, speech_frames: np.ndarray
    ) -> "ComputeSpeechFrameBoundariesMixin":
        nz = np.nonzero(speech_frames > 0.5)[0]
        if len(nz) > 0:
            self.start_frame_ = int(np.min(nz))
            self.end_frame_ = int(np.max(nz))
        return self
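

# Illustrative example (not in the original module): fit_boundaries records
# the first and last frame whose speech score exceeds 0.5.
def _example_fit_boundaries() -> None:
    mixin = ComputeSpeechFrameBoundariesMixin()
    mixin.fit_boundaries(np.array([0.0, 0.2, 0.9, 1.0, 0.1, 0.8, 0.0]))
    print(mixin.start_frame_, mixin.end_frame_, mixin.num_frames)  # 2 5 3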


class VideoSpeechTransformer(TransformerMixin):
    def __init__(
        self,
        vad: str,
        sample_rate: int,
        frame_rate: int,
        non_speech_label: float,
        start_seconds: int = 0,
        ffmpeg_path: Optional[str] = None,
        ref_stream: Optional[str] = None,
        vlc_mode: bool = False,
        gui_mode: bool = False,
    ) -> None:
        super(VideoSpeechTransformer, self).__init__()
        self.vad: str = vad
        self.sample_rate: int = sample_rate
        self.frame_rate: int = frame_rate
        self._non_speech_label: float = non_speech_label
        self.start_seconds: int = start_seconds
        self.ffmpeg_path: Optional[str] = ffmpeg_path
        self.ref_stream: Optional[str] = ref_stream
        self.vlc_mode: bool = vlc_mode
        self.gui_mode: bool = gui_mode
        self.video_speech_results_: Optional[np.ndarray] = None

    def try_fit_using_embedded_subs(self, fname: str) -> None:
        embedded_subs = []
        embedded_subs_times = []
        if self.ref_stream is None:
            # check first 5; should cover 99% of movies
            streams_to_try: List[str] = list(map("0:s:{}".format, range(5)))
        else:
            streams_to_try = [self.ref_stream]
        for stream in streams_to_try:
            ffmpeg_args = [
                ffmpeg_bin_path(
                    "ffmpeg", self.gui_mode, ffmpeg_resources_path=self.ffmpeg_path
                )
            ]
            ffmpeg_args.extend(
                [
                    "-loglevel",
                    "fatal",
                    "-nostdin",
                    "-i",
                    fname,
                    "-map",
                    "{}".format(stream),
                    "-f",
                    "srt",
                    "-",
                ]
            )
            process = subprocess.Popen(
                ffmpeg_args, **subprocess_args(include_stdout=True)
            )
            output = io.BytesIO(process.communicate()[0])
            if process.returncode != 0:
                break
            pipe = cast(
                Pipeline,
                make_subtitle_speech_pipeline(start_seconds=self.start_seconds),
            ).fit(output)
            speech_step = pipe.steps[-1][1]
            embedded_subs.append(speech_step)
            embedded_subs_times.append(speech_step.max_time_)
        if len(embedded_subs) == 0:
            if self.ref_stream is None:
                error_msg = "Video file appears to lack subtitle stream"
            else:
                error_msg = "Stream {} not found".format(self.ref_stream)
            raise ValueError(error_msg)
        # use longest set of embedded subs
        subs_to_use = embedded_subs[int(np.argmax(embedded_subs_times))]
        self.video_speech_results_ = subs_to_use.subtitle_speech_results_

    def fit(self, fname: str, *_) -> "VideoSpeechTransformer":
        if "subs" in self.vad and (
            self.ref_stream is None or self.ref_stream.startswith("0:s:")
        ):
            try:
                logger.info("Checking video for subtitles stream...")
                self.try_fit_using_embedded_subs(fname)
                logger.info("...success!")
                return self
            except Exception as e:
                logger.info(e)
        try:
            total_duration = (
                float(
                    ffmpeg.probe(
                        fname,
                        cmd=ffmpeg_bin_path(
                            "ffprobe",
                            self.gui_mode,
                            ffmpeg_resources_path=self.ffmpeg_path,
                        ),
                    )["format"]["duration"]
                )
                - self.start_seconds
            )
        except Exception as e:
            logger.warning(e)
            total_duration = None
        if "webrtc" in self.vad:
            detector = _make_webrtcvad_detector(
                self.sample_rate, self.frame_rate, self._non_speech_label
            )
        elif "auditok" in self.vad:
            detector = _make_auditok_detector(
                self.sample_rate, self.frame_rate, self._non_speech_label
            )
        elif "silero" in self.vad:
            detector = _make_silero_detector(
                self.sample_rate, self.frame_rate, self._non_speech_label
            )
        else:
            raise ValueError("unknown vad: %s" % self.vad)
        media_bstring: List[np.ndarray] = []
        ffmpeg_args = [
            ffmpeg_bin_path(
                "ffmpeg", self.gui_mode, ffmpeg_resources_path=self.ffmpeg_path
            )
        ]
        if self.start_seconds > 0:
            ffmpeg_args.extend(
                [
                    "-ss",
                    str(timedelta(seconds=self.start_seconds)),
                ]
            )
        ffmpeg_args.extend(["-loglevel", "fatal", "-nostdin", "-i", fname])
        if self.ref_stream is not None and self.ref_stream.startswith("0:a:"):
            ffmpeg_args.extend(["-map", self.ref_stream])
        ffmpeg_args.extend(
            [
                "-f",
                "s16le",
                "-ac",
                "1",
                "-acodec",
                "pcm_s16le",
                "-ar",
                str(self.frame_rate),
                "-",
            ]
        )
        process = subprocess.Popen(ffmpeg_args, **subprocess_args(include_stdout=True))
        bytes_per_frame = 2
        frames_per_window = bytes_per_frame * self.frame_rate // self.sample_rate
        windows_per_buffer = 10000
        simple_progress = 0.0

        redirect_stderr = None
        tqdm_extra_args = {}
        should_print_redirected_stderr = self.gui_mode
        if self.gui_mode:
            try:
                from contextlib import redirect_stderr  # type: ignore

                tqdm_extra_args["file"] = sys.stdout
            except ImportError:
                should_print_redirected_stderr = False
        if redirect_stderr is None:

            @contextmanager
            def redirect_stderr(enter_result=None):
                yield enter_result

        assert redirect_stderr is not None
        pbar_output = io.StringIO()
        with redirect_stderr(pbar_output):
            with tqdm.tqdm(
                total=total_duration, disable=self.vlc_mode, **tqdm_extra_args
            ) as pbar:
                while True:
                    in_bytes = process.stdout.read(
                        frames_per_window * windows_per_buffer
                    )
                    if not in_bytes:
                        break
                    newstuff = len(in_bytes) / float(bytes_per_frame) / self.frame_rate
                    if (
                        total_duration is not None
                        and simple_progress + newstuff > total_duration
                    ):
                        newstuff = total_duration - simple_progress
                    simple_progress += newstuff
                    pbar.update(newstuff)
                    if self.vlc_mode and total_duration is not None:
                        print("%d" % int(simple_progress * 100.0 / total_duration))
                        sys.stdout.flush()
                    if should_print_redirected_stderr:
                        assert self.gui_mode
                        # no need to flush since we pass -u to do unbuffered output for gui mode
                        print(pbar_output.read())
                    if "silero" not in self.vad:
                        in_bytes = np.frombuffer(in_bytes, np.uint8)
                    media_bstring.append(detector(in_bytes))
        process.wait()
        if len(media_bstring) == 0:
            raise ValueError(
                "Unable to detect speech. "
                "Perhaps try specifying a different stream / track, or a different vad."
            )
        self.video_speech_results_ = np.concatenate(media_bstring)
        logger.info("total of speech segments: %s", np.sum(self.video_speech_results_))
        return self

    def transform(self, *_) -> np.ndarray:
        return self.video_speech_results_
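

# Illustrative usage sketch (not part of the original module): extract a
# per-window speech array from a video's audio track. "movie.mkv" is a
# hypothetical path; "subs_then_webrtc" exercises the embedded-subtitle
# shortcut before falling back to webrtcvad, per the substring checks in
# fit() above.
def _example_video_speech() -> None:
    vst = VideoSpeechTransformer(
        vad="subs_then_webrtc",
        sample_rate=SAMPLE_RATE,
        frame_rate=48000,
        non_speech_label=0.0,
    )
    vst.fit("movie.mkv")
    speech = vst.transform()  # one value per 1 / SAMPLE_RATE second window
    print(speech.shape)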


_PAIRED_NESTER: Dict[str, str] = {
    "(": ")",
    "{": "}",
    "[": "]",
    # FIXME: False positive sometimes when there are html tags, e.g. <i> Hello? </i>
    # '<': '>',
}


# TODO: need way better metadata detector
def _is_metadata(content: str, is_beginning_or_end: bool) -> bool:
    content = content.strip()
    if len(content) == 0:
        return True
    if (
        content[0] in _PAIRED_NESTER.keys()
        and content[-1] == _PAIRED_NESTER[content[0]]
    ):
        return True
    if is_beginning_or_end:
        if "english" in content.lower():
            return True
        if " - " in content:
            return True
    return False
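

# Illustrative examples (not in the original module) of what the heuristic
# above filters: bracketed sound cues anywhere, plus credit-style lines
# ("english", " - ") when they appear as the first or last subtitle.
def _example_is_metadata() -> None:
    print(_is_metadata("[dramatic music]", False))  # True: paired brackets
    print(_is_metadata("Subs - SomeGroup", True))  # True: " - " at the edges
    print(_is_metadata("Hello there!", False))  # False: ordinary dialogue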


class SubtitleSpeechTransformer(TransformerMixin, ComputeSpeechFrameBoundariesMixin):
    def __init__(
        self, sample_rate: int, start_seconds: int = 0, framerate_ratio: float = 1.0
    ) -> None:
        super(SubtitleSpeechTransformer, self).__init__()
        self.sample_rate: int = sample_rate
        self.start_seconds: int = start_seconds
        self.framerate_ratio: float = framerate_ratio
        self.subtitle_speech_results_: Optional[np.ndarray] = None
        self.max_time_: Optional[float] = None

    def fit(self, subs: List[GenericSubtitle], *_) -> "SubtitleSpeechTransformer":
        max_time = 0
        for sub in subs:
            max_time = max(max_time, sub.end.total_seconds())
        self.max_time_ = max_time - self.start_seconds
        samples = np.zeros(int(max_time * self.sample_rate) + 2, dtype=float)
        start_frame = float("inf")
        end_frame = 0
        for i, sub in enumerate(subs):
            if _is_metadata(sub.content, i == 0 or i + 1 == len(subs)):
                continue
            start = int(
                round(
                    (sub.start.total_seconds() - self.start_seconds) * self.sample_rate
                )
            )
            start_frame = min(start_frame, start)
            duration = sub.end.total_seconds() - sub.start.total_seconds()
            end = start + int(round(duration * self.sample_rate))
            end_frame = max(end_frame, end)
            samples[start:end] = min(1.0 / self.framerate_ratio, 1.0)
        self.subtitle_speech_results_ = samples
        self.fit_boundaries(self.subtitle_speech_results_)
        return self

    def transform(self, *_) -> np.ndarray:
        assert self.subtitle_speech_results_ is not None
        return self.subtitle_speech_results_
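

# Illustrative sketch (not in the original module): turning a parsed subtitle
# list into a speech array. GenericSubtitle instances normally come from the
# parser step of the pipeline, so they are taken as a parameter here.
def _example_subtitle_speech(subs: List[GenericSubtitle]) -> None:
    sst = SubtitleSpeechTransformer(sample_rate=SAMPLE_RATE)
    sst.fit(subs)
    speech = sst.transform()
    # boundaries come from ComputeSpeechFrameBoundariesMixin.fit_boundaries
    print(speech.sum(), sst.num_frames, sst.max_time_)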


class DeserializeSpeechTransformer(TransformerMixin):
    def __init__(self, non_speech_label: float) -> None:
        super(DeserializeSpeechTransformer, self).__init__()
        self._non_speech_label: float = non_speech_label
        self.deserialized_speech_results_: Optional[np.ndarray] = None

    def fit(self, fname, *_) -> "DeserializeSpeechTransformer":
        speech = np.load(fname)
        if hasattr(speech, "files"):
            if "speech" in speech.files:
                speech = speech["speech"]
            else:
                raise ValueError(
                    'could not find "speech" array in '
                    "serialized file; only contains: %s" % speech.files
                )
        speech[speech < 1.0] = self._non_speech_label
        self.deserialized_speech_results_ = speech
        return self

    def transform(self, *_) -> np.ndarray:
        assert self.deserialized_speech_results_ is not None
        return self.deserialized_speech_results_
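

# Illustrative round trip (not in the original module): serialize a speech
# array under the "speech" key that fit() looks for, then reload it.
# "speech.npz" is a hypothetical path.
def _example_deserialize_speech() -> None:
    np.savez("speech.npz", speech=np.array([0.0, 0.3, 1.0, 1.0]))
    dst = DeserializeSpeechTransformer(non_speech_label=0.0)
    dst.fit("speech.npz")
    print(dst.transform())  # sub-1.0 entries relabeled: [0. 0. 1. 1.]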