# Source code for recipes.cad_icassp_2026.baseline.compute_stoi

"""Compute the STOI scores."""

from __future__ import annotations

import json
import logging
from pathlib import Path

import hydra
import numpy as np
import torch
from omegaconf import DictConfig
from pystoi import stoi as compute_stoi
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from tqdm import tqdm

from clarity.utils.file_io import read_jsonl, write_jsonl
from clarity.utils.flac_encoder import read_flac_signal
from clarity.utils.signal_processing import resample
from recipes.cad_icassp_2026.baseline.shared_predict_utils import (
    input_align,
    load_vocals,
)

logger = logging.getLogger(__name__)


def compute_stoi_for_signal(
    cfg: DictConfig,
    record: dict,
    data_root: str | Path,
    estimated_vocals: np.ndarray,
) -> float:
    """Compute the stoi score for a given signal.

    The processed signal is read from
    ``<data_root>/audio/<cfg.split>/signals/<record["signal"]>.flac``. It is
    resampled to ``cfg.data.sample_rate`` if its rate differs. It is then scored
    channel by channel against ``estimated_vocals``, which is used as the
    reference. The better of the two channel scores is returned.

    Both signals are divided by the same factor: the peak absolute value of
    the processed signal. ``estimated_vocals`` is NOT modified in place.

    Args:
        cfg (DictConfig): configuration object
        record (dict): the metadata dict for the signal (must contain "signal")
        data_root (str | Path): root path to the dataset
        estimated_vocals (np.ndarray): estimated vocals signal; indexed as
            ``[:, channel]`` with channels 0 and 1 used

    Returns:
        float: maximum of the left- and right-channel stoi scores
    """
    signal_name = record["signal"]

    # Load processed signal
    signal_path = (
        Path(data_root) / "audio" / cfg.split / "signals" / f"{signal_name}.flac"
    )
    signal, proc_sr = read_flac_signal(signal_path)
    if proc_sr != cfg.data.sample_rate:
        logger.info(f"Resampling {signal_path} to {cfg.data.sample_rate} Hz")
        signal = resample(signal, proc_sr, cfg.data.sample_rate)

    # Scale both signals by the same factor so their relative level is kept.
    # Use out-of-place division so the caller's ``estimated_vocals`` array is
    # left untouched. Skip the scaling for an all-zero signal: dividing by a
    # zero peak would turn both signals into NaN.
    signal_norm_factor = np.max(np.abs(signal))
    if signal_norm_factor > 0:
        signal = signal / signal_norm_factor
        estimated_vocals = estimated_vocals / signal_norm_factor

    # Compute STOI score per channel (0 = left, 1 = right)
    stoi_score_left = compute_single_stoi(
        estimated_vocals[:, 0],
        signal[:, 0],
        cfg.data.sample_rate,
        cfg.baseline.stoi_sample_rate,
    )
    stoi_score_right = compute_single_stoi(
        estimated_vocals[:, 1],
        signal[:, 1],
        cfg.data.sample_rate,
        cfg.baseline.stoi_sample_rate,
    )
    return float(np.max([stoi_score_left, stoi_score_right]))
def compute_single_stoi(
    reference: np.ndarray,
    processed: np.ndarray,
    fsamp: int,
    stoi_fsamp: int = 10000,
) -> float:
    """Score one channel of a processed signal against a reference with STOI.

    Both inputs are resampled from ``fsamp`` to ``stoi_fsamp``. They are then
    aligned with ``input_align`` and passed to ``pystoi.stoi``, with the
    reference as the clean signal.

    Args:
        reference (np.ndarray): Reference signal.
        processed (np.ndarray): Processed signal.
        fsamp (int): Sampling frequency of both inputs.
        stoi_fsamp (int): Sampling frequency for STOI computation.
            Default is 10000 Hz.

    Returns:
        float: STOI score.
    """
    target_rate = int(stoi_fsamp)
    # Arguments evaluate left to right: reference is resampled before processed.
    ref_at_rate, proc_at_rate = input_align(
        resample(reference, fsamp, stoi_fsamp),
        resample(processed, fsamp, stoi_fsamp),
        fsamp=target_rate,
    )
    return compute_stoi(ref_at_rate, proc_at_rate, target_rate)
# pylint: disable = no-value-for-parameter
@hydra.main(config_path="configs", config_name="config", version_base=None)
def run_compute_stoi(cfg: DictConfig) -> None:
    """Run the STOI score computation.

    Reads ``<cadenza_data_root>/<dataset>/metadata/<split>_metadata.json``.
    It then selects this batch's share of the records
    (``records[batch - 1 :: n_batches]``) and skips any signal already present
    in this batch's JSONL results file. Each remaining signal is scored and
    appended to that file one result at a time, so an interrupted run can be
    resumed.

    Args:
        cfg (DictConfig): configuration object (supplied by Hydra)

    Raises:
        ValueError: if ``cfg.baseline.name`` is not ``"stoi"``.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if cfg.baseline.name != "stoi":
        raise ValueError(
            f"Expected cfg.baseline.name to be 'stoi', got {cfg.baseline.name!r}"
        )
    logger.info(f"Running {cfg.baseline.system} baseline on {cfg.split} set...")

    # Load the set of signals for which we need to compute scores
    dataroot = Path(cfg.data.cadenza_data_root) / cfg.data.dataset
    dataset_filename = dataroot / "metadata" / f"{cfg.split}_metadata.json"
    with dataset_filename.open("r", encoding="utf-8") as fp:
        records = json.load(fp)
    total_records = len(records)

    # Load existing results file if present (one results file per batch)
    batch_str = (
        f".{cfg.baseline.batch}_{cfg.baseline.n_batches}"
        if cfg.baseline.n_batches > 1
        else ""
    )
    results_file = Path(
        f"{cfg.data.dataset}.{cfg.split}.{cfg.baseline.system}{batch_str}.jsonl"
    )
    results = read_jsonl(str(results_file)) if results_file.exists() else []
    scored_signals = {result["signal"] for result in results}

    # Take this batch's share of the FULL record list first, then drop signals
    # that are already scored. If the filter ran before the slice, a resumed
    # run would see a shorter list and shift which records land in this batch.
    # It would then score other batches' signals and skip some of its own.
    records = records[cfg.baseline.batch - 1 :: cfg.baseline.n_batches]
    records = [
        record for record in records if record["signal"] not in scored_signals
    ]

    # Prepare audio source separation model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    separation_model = HDEMUCS_HIGH_MUSDB_PLUS.get_model()
    separation_model.to(device)

    # Iterate over the signals that need scoring
    logger.info(f"Computing scores for {len(records)} out of {total_records} signals")
    if cfg.baseline.separator.keep_vocals:
        logger.info("Saving estimated vocals. If exist, they will not be recomputed.")

    for record in tqdm(records):
        signal_name = record["signal"]

        # Load unprocessed signal to estimate vocals
        estimated_vocals = load_vocals(
            dataroot, record, cfg, separation_model, device=device
        )
        stoi = compute_stoi_for_signal(cfg, record, dataroot, estimated_vocals)

        # Results are appended to the results file to allow interruption
        result = {"signal": signal_name, f"{cfg.baseline.system}": stoi}
        write_jsonl(str(results_file), [result])
if __name__ == "__main__":
    # No ``cfg`` argument here: the ``@hydra.main`` decorator on
    # ``run_compute_stoi`` builds and injects the configuration.
    run_compute_stoi()