from __future__ import annotations
from collections.abc import Sequence
from logging import getLogger
from pathlib import Path
from typing import Literal
import librosa
import numpy as np
import soundfile
import torch
from cm_time import timer
from tqdm import tqdm
from so_vits_svc_fork.inference.core import RealtimeVC, RealtimeVC2, Svc
from so_vits_svc_fork.utils import get_optimal_device
LOG = getLogger(__name__)


def _collect_io_pairs(
    input_paths: Sequence[Path],
    output_paths: Sequence[Path],
    recursive: bool,
) -> list[tuple[Path, Path]]:
    """Pair every input audio path with its output path, expanding directories.

    Inputs and outputs are matched positionally. A directory input is walked
    with ``rglob("*.*")``; each *file* found is mapped to the same relative
    location under the paired output directory. Non-directory inputs are
    passed through unchanged, even if they do not exist. The caller reports
    load failures later.

    Raises:
        ValueError: if an input is a directory and ``recursive`` is False.
    """
    pairs: list[tuple[Path, Path]] = []
    for src, dst in zip(input_paths, output_paths):
        if not src.is_dir():
            pairs.append((src, dst))
            continue
        if not recursive:
            raise ValueError(f"input_path is a directory, but recursive is False: {src}")
        # Map only the files discovered under *this* directory. Mapping every
        # previously collected path would call relative_to() on unrelated paths.
        pairs.extend(
            (found, dst / found.relative_to(src))
            for found in src.rglob("*.*")
            if found.is_file()
        )
    return pairs


def infer(
    *,
    # paths
    input_path: Path | str | Sequence[Path | str],
    output_path: Path | str | Sequence[Path | str],
    model_path: Path | str,
    config_path: Path | str,
    recursive: bool = False,
    # svc config
    speaker: int | str,
    cluster_model_path: Path | str | None = None,
    transpose: int = 0,
    auto_predict_f0: bool = False,
    cluster_infer_ratio: float = 0,
    noise_scale: float = 0.4,
    f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
    # slice config
    db_thresh: int = -40,
    pad_seconds: float = 0.5,
    chunk_seconds: float = 0.5,
    absolute_thresh: bool = False,
    max_chunk_seconds: float = 40,
    device: str | torch.device = get_optimal_device(),
):
    """Convert audio files with a so-vits-svc model and write the results.

    Each input path is paired with the output path at the same position. A
    directory input requires ``recursive=True``. Every ``*.*`` file below it
    is converted and written to the same relative location under the paired
    output directory, and parent directories are created as needed.

    Each file is loaded with librosa at the model's ``target_sample`` rate,
    converted with ``Svc.infer_silence`` and written with soundfile at that
    same rate. Files that fail to load are logged and skipped.

    Args:
        input_path: An audio file or directory, or a sequence of them.
        output_path: Destination(s); must have the same length as *input_path*.
        model_path: Model checkpoint, passed to ``Svc`` as ``net_g_path``.
        config_path: Model config file, passed to ``Svc``.
        recursive: Allow directory inputs.
        speaker, transpose, auto_predict_f0, cluster_infer_ratio, noise_scale,
        f0_method, db_thresh, pad_seconds, chunk_seconds, absolute_thresh,
        max_chunk_seconds: Forwarded unchanged to ``Svc.infer_silence``.
        cluster_model_path: Optional cluster model, passed to ``Svc``.
        device: Passed to ``Svc``. Note that the default is computed once by
            ``get_optimal_device()`` when this module is imported.

    Raises:
        ValueError: if *input_path* and *output_path* differ in length, or a
            directory input is given with ``recursive=False``.
    """
    if isinstance(input_path, (str, Path)):
        input_path = [input_path]
    if isinstance(output_path, (str, Path)):
        output_path = [output_path]
    if len(input_path) != len(output_path):
        raise ValueError(f"input_path and output_path must have same length, but got {len(input_path)} and {len(output_path)}")
    model_path = Path(model_path)
    config_path = Path(config_path)
    pairs = _collect_io_pairs(
        [Path(p) for p in input_path],
        [Path(p) for p in output_path],
        recursive,
    )

    cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
    svc_model = Svc(
        net_g_path=model_path.as_posix(),
        config_path=config_path.as_posix(),
        cluster_model_path=(cluster_model_path.as_posix() if cluster_model_path else None),
        device=device,
    )
    try:
        # A progress bar is pointless for a single file.
        pbar = tqdm(pairs, disable=len(pairs) == 1)
        for src, dst in pbar:
            pbar.set_description(f"{src}")
            try:
                audio, _ = librosa.load(str(src), sr=svc_model.target_sample)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole batch.
                LOG.error(f"Failed to load {src}")
                LOG.exception(e)
                continue
            dst.parent.mkdir(parents=True, exist_ok=True)
            audio = svc_model.infer_silence(
                audio.astype(np.float32),
                speaker=speaker,
                transpose=transpose,
                auto_predict_f0=auto_predict_f0,
                cluster_infer_ratio=cluster_infer_ratio,
                noise_scale=noise_scale,
                f0_method=f0_method,
                db_thresh=db_thresh,
                pad_seconds=pad_seconds,
                chunk_seconds=chunk_seconds,
                absolute_thresh=absolute_thresh,
                max_chunk_seconds=max_chunk_seconds,
            )
            soundfile.write(str(dst), audio, svc_model.target_sample)
    finally:
        # Drop the model reference first so empty_cache() can release its memory.
        del svc_model
        torch.cuda.empty_cache()
def _resolve_device_index(
    requested: int | str | None,
    devices: Sequence[dict],
    default: int,
    label: str,
) -> int:
    """Turn a user-supplied audio device (index, name or None) into an index.

    Names are matched exactly against ``devices[i]["name"]`` and the first
    match wins. Unknown names, ``None`` and indices ``>= len(devices)`` fall
    back to *default*. *label* ("Input"/"Output") only prefixes log messages.
    """
    if isinstance(requested, str):
        index = next((i for i, d in enumerate(devices) if d["name"] == requested), None)
        if index is None:
            LOG.warning(f"{label} device {requested} not found, using default")
        requested = index
    elif requested is not None and requested >= len(devices):
        LOG.warning(f"{label} device index {requested} is out of range, using default")
    if requested is None or requested >= len(devices):
        return default
    return requested


def realtime(
    *,
    # paths
    model_path: Path | str,
    config_path: Path | str,
    # svc config
    speaker: str,
    cluster_model_path: Path | str | None = None,
    transpose: int = 0,
    auto_predict_f0: bool = False,
    cluster_infer_ratio: float = 0,
    noise_scale: float = 0.4,
    f0_method: Literal["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"] = "dio",
    # slice config
    db_thresh: int = -40,
    pad_seconds: float = 0.5,
    chunk_seconds: float = 0.5,
    # realtime config
    crossfade_seconds: float = 0.05,
    additional_infer_before_seconds: float = 0.2,
    additional_infer_after_seconds: float = 0.1,
    block_seconds: float = 0.5,
    version: int = 2,
    input_device: int | str | None = None,
    output_device: int | str | None = None,
    device: str | torch.device = get_optimal_device(),
    passthrough_original: bool = False,
):
    """Run live voice conversion from an input audio device to an output device.

    The function loads the model and wraps it in ``RealtimeVC`` when
    ``version == 1``, or in ``RealtimeVC2`` for any other value. It warms the
    model up with one inference on 1 s of silence. It then opens a
    single-channel sounddevice stream at the model's sample rate and converts
    each block of ``block_seconds`` inside the stream callback. It never
    returns normally; it runs until an exception (e.g. Ctrl+C) ends the loop.

    Args:
        model_path, config_path, cluster_model_path, device: Passed to ``Svc``.
            The ``device`` default is computed once at import time.
        speaker, transpose, auto_predict_f0, cluster_infer_ratio, noise_scale,
        f0_method, db_thresh, chunk_seconds: Forwarded to ``model.process``
            for every block.
        pad_seconds: Forwarded to ``model.process`` only when ``version == 1``.
        crossfade_seconds, additional_infer_before_seconds,
        additional_infer_after_seconds: Converted to sample counts and used
            only by ``RealtimeVC`` (``version == 1``).
        block_seconds: Stream block size; also the denominator of the logged
            real-time factor (RTF).
        input_device, output_device: Device index, exact device name, or None
            for the sounddevice default.
        passthrough_original: Output the average of the input and the
            converted audio instead of the converted audio alone.
    """
    import sounddevice as sd

    model_path = Path(model_path)
    config_path = Path(config_path)
    cluster_model_path = Path(cluster_model_path) if cluster_model_path else None
    svc_model = Svc(
        net_g_path=model_path.as_posix(),
        config_path=config_path.as_posix(),
        cluster_model_path=(cluster_model_path.as_posix() if cluster_model_path else None),
        device=device,
    )

    LOG.info("Creating realtime model...")
    if version == 1:
        model = RealtimeVC(
            svc_model=svc_model,
            crossfade_len=int(crossfade_seconds * svc_model.target_sample),
            additional_infer_before_len=int(additional_infer_before_seconds * svc_model.target_sample),
            additional_infer_after_len=int(additional_infer_after_seconds * svc_model.target_sample),
        )
    else:
        model = RealtimeVC2(
            svc_model=svc_model,
        )

    # LOG all device info
    devices = sd.query_devices()
    LOG.info(f"Device: {devices}")
    input_device = _resolve_device_index(input_device, devices, sd.default.device[0], "Input")
    output_device = _resolve_device_index(output_device, devices, sd.default.device[1], "Output")
    LOG.info(f"Input Device: {devices[input_device]['name']}, Output Device: {devices[output_device]['name']}")

    # The first inference is noticeably slower than later ones, and the model
    # behaves the same either way. Run one dummy inference on 1 s of silence
    # so the first real audio block does not pay that cost.
    LOG.info("Warming up the model...")
    svc_model.infer(
        speaker=speaker,
        transpose=transpose,
        auto_predict_f0=auto_predict_f0,
        cluster_infer_ratio=cluster_infer_ratio,
        noise_scale=noise_scale,
        f0_method=f0_method,
        audio=np.zeros(svc_model.target_sample, dtype=np.float32),
    )

    # These arguments are identical for every audio block, so build them once
    # instead of once per callback invocation.
    process_kwargs = dict(
        # svc config
        speaker=speaker,
        transpose=transpose,
        auto_predict_f0=auto_predict_f0,
        cluster_infer_ratio=cluster_infer_ratio,
        noise_scale=noise_scale,
        f0_method=f0_method,
        # slice config
        db_thresh=db_thresh,
        chunk_seconds=chunk_seconds,
    )
    if version == 1:
        # Only the RealtimeVC (version 1) path receives pad_seconds.
        process_kwargs["pad_seconds"] = pad_seconds

    def callback(
        indata: np.ndarray,
        outdata: np.ndarray,
        frames: int,
        time: object,  # sounddevice passes a timing struct here, not an int
        status: sd.CallbackFlags,
    ) -> None:
        # Runs once per audio block: use lazy %-args so no string is built
        # when DEBUG logging is disabled.
        LOG.debug("Frames: %s, Status: %s, Shape: %s, Time: %s", frames, status, indata.shape, time)
        # Downmix the input channels to one float32 signal.
        mono_input = indata.mean(axis=1).astype(np.float32)
        with timer() as t:
            inference = model.process(
                input_audio=mono_input,
                **process_kwargs,
            ).reshape(-1, 1)
            if passthrough_original:
                outdata[:] = (indata + inference) / 2
            else:
                outdata[:] = inference
        rtf = t.elapsed / block_seconds
        LOG.info("Realtime inference time: %.3fs, RTF: %.3f", t.elapsed, rtf)
        if rtf > 1:
            LOG.warning("RTF is too high, consider increasing block_seconds")

    try:
        with sd.Stream(
            device=(input_device, output_device),
            channels=1,
            callback=callback,
            samplerate=svc_model.target_sample,
            blocksize=int(block_seconds * svc_model.target_sample),
            latency="low",
        ) as stream:
            LOG.info(f"Latency: {stream.latency}")
            while True:
                sd.sleep(1000)
    finally:
        # Drop the model references first so empty_cache() can release their memory.
        del model, svc_model
        torch.cuda.empty_cache()