"""Shared utilities for STOI baseline prediction experiments."""
from __future__ import annotations
import json
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig
from scipy.optimize import curve_fit
from torch.nn import Module
from clarity.utils.file_io import read_jsonl, read_signal, write_signal
from clarity.utils.signal_processing import resample
from clarity.utils.source_separation_support import separate_sources
logger = logging.getLogger(__name__)
class LogisticModel:
"""Class to represent a logistic mapping.
Fits a logistic mapping from input values x to output values y.
"""
params: np.ndarray | None = None # The model params
def _logistic_mapping(self, x, x_0, k):
"""Logistic function
Args:
x - the input value
x_0 - logistic parameter: the x value of the logistic's midpoint
k - logistic parameter: the growth rate of the curve
Returns:
The output of the logistic function.
"""
return 100.0 / (1 + np.exp(-k * (x - x_0)))
def fit(self, x, y):
"""Fit a mapping from x values to y values."""
        initial_guess = [0.5, 1.0]  # Initial guess for x_0 and k
        self.params, _ = curve_fit(self._logistic_mapping, x, y, p0=initial_guess)
def predict(self, x):
"""Predict y values given x.
Raises:
TypeError: If the predict() method is called before fit().
"""
if self.params is None:
raise TypeError(
"params is None. Logistic fit() must be called before predict()."
)
return self._logistic_mapping(x, self.params[0], self.params[1])
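

# Illustrative sketch only: a minimal demonstration of the LogisticModel
# fit/predict cycle on synthetic data. The scores below are made up and do
# not come from any Cadenza dataset.
def _example_logistic_mapping() -> np.ndarray:
    """Fit a logistic mapping on synthetic scores and return predictions."""
    rng = np.random.default_rng(0)
    stoi_scores = rng.uniform(0.0, 1.0, size=50)  # hypothetical STOI values
    # Hypothetical listener correctness (%) following a noisy logistic curve
    correctness = 100.0 / (1 + np.exp(-8.0 * (stoi_scores - 0.5)))
    correctness = correctness + rng.normal(0.0, 2.0, size=50)
    mapper = LogisticModel()
    mapper.fit(stoi_scores, correctness)  # estimates x_0 and k
    return mapper.predict(stoi_scores)  # predictions in the range 0-100
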
def estimate_vocals(
    signal: np.ndarray | torch.Tensor,
    sample_rate: int,
    model: Module,
    device: str = "cpu",
) -> np.ndarray:
"""
Estimate vocals from the input signal using the pre-trained model.
Args:
signal (torch.Tensor | np.ndarray): Input audio signal.
sample_rate (int): Sample rate of the input signal.
model (torch.nn.Module): Pre-trained source separation model.
device (str): Device to run the model on ('cpu' or 'cuda').
Returns:
np.ndarray: Estimated vocals.
"""
if not isinstance(signal, torch.Tensor):
signal = torch.tensor(signal, dtype=torch.float32)
if signal.ndim == 1:
# HDemucs works with 2 channels, so we need to stack the signal
# to create a stereo signal
signal = torch.stack([signal, signal], dim=0)
signal = signal.to(device)
    # Normalise using the statistics of the mono mixture; the same statistics
    # are used below to undo the normalisation on the separated sources.
    ref = signal.mean(0)
    signal = (signal - ref.mean()) / ref.std()
    ref = ref.cpu().detach().numpy()
sources = separate_sources(
model,
signal[None],
sample_rate=sample_rate,
device=device,
)[0]
    # Undo the normalisation applied to the mixture before separation
    sources = sources * ref.std() + ref.mean()
    # Map each separated stem to its source name (e.g. 'vocals', 'drums')
    audios = dict(zip(model.sources, list(sources)))
    estimated_vocals = audios["vocals"]
    return estimated_vocals.cpu().detach().numpy()
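

# Illustrative sketch only: loading a pre-trained HDemucs model from
# torchaudio and separating vocals from random stereo audio. The use of the
# HDEMUCS_HIGH_MUSDB_PLUS pipeline is an assumption about the separator, not
# something this module mandates.
def _example_estimate_vocals() -> np.ndarray:
    """Separate vocals from one second of random stereo audio."""
    import torchaudio  # local import: only needed for this sketch

    bundle = torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS
    separation_model = bundle.get_model()
    mixture = np.random.randn(2, bundle.sample_rate).astype(np.float32)
    return estimate_vocals(mixture, bundle.sample_rate, separation_model)
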
def load_mixture(
dataroot: Path, record: dict, cfg: DictConfig
) -> tuple[np.ndarray, int]:
"""Load the mixture signal for a given record.
Args:
dataroot (Path): Root path to the dataset.
record (dict): Record containing signal metadata.
cfg (DictConfig): Configuration object.
Returns:
tuple[np.ndarray, int]: Mixture signal and its sample rate.
"""
signal_name = record["signal"]
if cfg.baseline.reference == "processed":
mix_signal_path = (
dataroot / "audio" / cfg.split / "signals" / f"{signal_name}.flac"
)
elif cfg.baseline.reference == "unprocessed":
mix_signal_path = (
dataroot
/ "audio"
/ cfg.split
/ "unprocessed"
/ f"{signal_name}_unproc.flac"
)
else:
raise ValueError(f"Unknown reference type: {cfg.baseline.reference}")
    # read_signal resamples to cfg.data.sample_rate if the file rate differs
    mixture = read_signal(
        mix_signal_path,
        sample_rate=cfg.data.sample_rate,
    )
    return mixture, cfg.data.sample_rate
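

# Illustrative sketch only: a minimal configuration for load_mixture. The
# dataset path, signal name and config values are hypothetical stand-ins for
# the real Hydra configuration.
def _example_load_mixture() -> tuple[np.ndarray, int]:
    """Load one processed mixture using a hand-built config."""
    from omegaconf import OmegaConf  # local import: only needed for this sketch

    cfg = OmegaConf.create(
        {
            "split": "valid",
            "data": {"sample_rate": 44100},
            "baseline": {"reference": "processed"},
        }
    )
    record = {"signal": "S0001"}  # hypothetical signal identifier
    return load_mixture(Path("/path/to/dataset"), record, cfg)
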
def load_vocals(
    dataroot: Path,
    record: dict,
    cfg: DictConfig,
    separation_model: Module,
    device: str = "cpu",
) -> np.ndarray:
"""Load or compute estimated vocals for a given record.
Args:
dataroot (Path): Root path to the dataset.
record (dict): Record containing signal metadata.
cfg (DictConfig): Configuration object.
separation_model: Pre-trained source separation model.
device (str): Device to run the model on ('cpu' or 'cuda').
Returns:
np.ndarray: Estimated vocals signal.
"""
signal_name = record["signal"]
vocals_path = Path("est_vocals") / cfg.split / f"{signal_name}_est_vocals.wav"
if not vocals_path.exists():
signal, signal_sr = load_mixture(dataroot, record, cfg)
if signal_sr != cfg.data.sample_rate:
            logger.info("Resampling mixture signal to %s Hz", cfg.data.sample_rate)
signal = resample(signal, signal_sr, cfg.data.sample_rate)
# Estimate vocals to create a processed signal
estimated_vocals = estimate_vocals(
signal.T,
cfg.baseline.separator.sample_rate,
separation_model,
device=device,
).T
if cfg.baseline.separator.keep_vocals:
vocals_path.parent.mkdir(parents=True, exist_ok=True)
write_signal(vocals_path, estimated_vocals, cfg.data.sample_rate)
else:
estimated_vocals = read_signal(vocals_path, cfg.data.sample_rate)
return estimated_vocals
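

# Illustrative sketch only: computing (or re-using cached) vocals for one
# record. The config mirrors _example_load_mixture and adds the separator
# settings that load_vocals expects; all values are hypothetical.
def _example_load_vocals() -> np.ndarray:
    """Estimate vocals for a single hypothetical record."""
    import torchaudio  # local import: only needed for this sketch

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "split": "valid",
            "data": {"sample_rate": 44100},
            "baseline": {
                "reference": "processed",
                "separator": {"sample_rate": 44100, "keep_vocals": False},
            },
        }
    )
    record = {"signal": "S0001"}  # hypothetical signal identifier
    model = torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS.get_model()
    return load_vocals(Path("/path/to/dataset"), record, cfg, model)
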
def load_dataset_with_score(cfg: DictConfig, split: str) -> pd.DataFrame:
    """Load dataset metadata and attach per-signal system scores.

    Args:
        cfg (DictConfig): Configuration object.
        split (str): Dataset split to load ('train' or 'valid').

    Returns:
        pd.DataFrame: DataFrame containing dataset records with added scores.
    """
dataset_filename = (
Path(cfg.data.cadenza_data_root)
/ cfg.data.dataset
/ "metadata"
/ f"{split}_metadata.json"
)
with dataset_filename.open("r", encoding="utf-8") as fp:
records = json.load(fp)
    # Load the system scores (e.g. STOI or Whisper) and add them to the records
    system_path = f"{cfg.data.dataset}.{split}.{cfg.baseline.system}.jsonl"
    system_score = read_jsonl(system_path)
    system_score_index = {
        record["signal"]: record[cfg.baseline.system] for record in system_score
    }
    for record in records:
        record[cfg.baseline.system] = system_score_index[record["signal"]]
return pd.DataFrame(records)
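

# Illustrative sketch only: an end-to-end use of this module, fitting the
# logistic mapping from a system score to listener correctness. The
# "correctness" metadata column is an assumption about the dataset, not
# guaranteed by this module.
def _example_fit_baseline(cfg: DictConfig) -> LogisticModel:
    """Fit a score-to-correctness mapping on the training split."""
    train_df = load_dataset_with_score(cfg, "train")
    mapper = LogisticModel()
    # Map the raw system score (e.g. STOI) onto the correctness labels
    mapper.fit(train_df[cfg.baseline.system], train_df["correctness"])
    return mapper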