# Source code for so_vits_svc_fork.dataset

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from random import Random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from .hparams import HParams


class TextAudioDataset(Dataset):
    """Dataset that serves preprocessed ``*.data.pt`` feature dictionaries.

    The file list named by ``hps.data.training_files`` (or
    ``hps.data.validation_files``) is read as UTF-8 text, one path per line.
    Each listed path gets ``.data.pt`` appended to its file name, and the
    resulting paths are shuffled once with a ``Random`` seeded from
    ``hps.train.seed``.
    """

    def __init__(self, hps: HParams, is_validation: bool = False):
        """Build the shuffled list of ``.data.pt`` paths.

        Args:
            hps: Hyper-parameters; this class reads ``data.training_files``,
                ``data.validation_files``, ``data.hop_length`` and
                ``train.seed`` from it.
            is_validation: When true, use the validation file list instead of
                the training one.
        """
        filelist = (
            hps.data.validation_files if is_validation else hps.data.training_files
        )
        self.datapaths = []
        for line in Path(filelist).read_text("utf-8").splitlines():
            entry = Path(line)
            self.datapaths.append(entry.parent / (entry.name + ".data.pt"))
        self.hps = hps
        # The same generator is reused later to pick crop positions.
        self.random = Random(hps.train.seed)
        self.random.shuffle(self.datapaths)
        # Items whose "mel_spec" has more than this many columns are cropped.
        self.max_spec_len = 800

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Load one item from disk and randomly crop it if it is too long.

        Args:
            index: Position in the shuffled path list.

        Returns:
            The loaded dictionary. If ``data["mel_spec"].shape[1]`` exceeds
            ``max_spec_len``, every entry except ``"spk"`` is replaced by a
            window of ``max_spec_len - 10`` positions starting at a random
            offset; ``"audio"`` is windowed on dim 1 with both bounds
            multiplied by ``hop_length``, all other entries on their last
            axis.
        """
        with Path(self.datapaths[index]).open("rb") as stream:
            data = torch.load(stream, weights_only=True, map_location="cpu")

        n_frames = data["mel_spec"].shape[1]
        hop_len = self.hps.data.hop_length
        if n_frames > self.max_spec_len:
            first = self.random.randint(0, n_frames - self.max_spec_len)
            # The window is 10 positions shorter than max_spec_len; the reason
            # for that margin is not visible in this file.
            last = first + self.max_spec_len - 10
            frame_window = slice(first, last)
            # presumably "audio" is (channels, samples) with hop_len samples
            # per spectrogram frame -- TODO confirm against the preprocessor
            sample_window = slice(first * hop_len, last * hop_len)
            for key, value in data.items():
                if key == "spk":
                    continue
                if key == "audio":
                    data[key] = value[:, sample_window]
                else:
                    data[key] = value[..., frame_window]

        torch.cuda.empty_cache()
        return data

    def __len__(self) -> int:
        """Return the number of listed data files."""
        return len(self.datapaths)
def _pad_stack(array: Sequence[torch.Tensor]) -> torch.Tensor: max_idx = torch.argmax(torch.tensor([x_.shape[-1] for x_ in array])) max_x = array[max_idx] x_padded = [F.pad(x_, (0, max_x.shape[-1] - x_.shape[-1]), mode="constant", value=0) for x_ in array] return torch.stack(x_padded)
class TextAudioCollate(nn.Module):
    """Collate per-item feature dictionaries into padded batch tensors."""

    def forward(
        self, batch: Sequence[dict[str, torch.Tensor]]
    ) -> tuple[torch.Tensor, ...]:
        """Merge a list of items into one tuple of batched tensors.

        ``None`` entries are dropped and the remaining items are ordered by
        ``mel_spec.shape[1]``, longest first. Every key except ``"spk"`` is
        zero-padded and stacked with ``_pad_stack``; ``"spk"`` values are
        gathered into a tensor with one single-element row per item.

        Args:
            batch: Items to combine; each must provide the keys returned
                below.

        Returns:
            ``(content, f0, spec, mel_spec, audio, spk, lengths, uv)`` where
            ``lengths`` holds each item's unpadded ``mel_spec.shape[1]`` as a
            long tensor, in the sorted order.
        """
        samples = sorted(
            (entry for entry in batch if entry is not None),
            key=lambda entry: entry["mel_spec"].shape[1],
            reverse=True,
        )
        lengths = torch.tensor(
            [entry["mel_spec"].shape[1] for entry in samples]
        ).long()

        results = {}
        for key in samples[0].keys():
            column = [entry[key] for entry in samples]
            if key == "spk":
                results[key] = torch.tensor([[value] for value in column]).cpu()
            else:
                results[key] = _pad_stack(column).cpu()

        content = results["content"]
        f0 = results["f0"]
        spec = results["spec"]
        mel_spec = results["mel_spec"]
        audio = results["audio"]
        spk = results["spk"]
        uv = results["uv"]
        return (content, f0, spec, mel_spec, audio, spk, lengths, uv)