# Source code for recipes.cad_icassp_2026.baseline.predict

"""Make intelligibility predictions from HASPI scores."""

from __future__ import annotations

import logging

import hydra
from omegaconf import DictConfig

from recipes.cad_icassp_2026.baseline.shared_predict_utils import (
    LogisticModel,
    load_dataset_with_score,
)

log = logging.getLogger(__name__)


# pylint: disable = no-value-for-parameter
@hydra.main(config_path="configs", config_name="config", version_base=None)
def predict_dev(cfg: DictConfig) -> None:
    """Fit a logistic mapping on the train split and predict the valid split.

    The predictor is the score column named by ``cfg.baseline.system``.
    Set ``config.baseline`` to ``stoi``, ``whisper_mixture`` or
    ``whisper_vocals`` depending on which baseline you want to run.

    A ``LogisticModel`` is fitted from that score column to the
    ``correctness`` column of the train split. It is then applied to the
    same score column of the valid split. Predictions are written to
    ``<cfg.data.dataset>.<cfg.baseline.system>.valid.predict.csv`` with the
    header ``signal_ID,intelligibility_score``.

    Args:
        cfg: Hydra configuration. ``cfg.baseline.system`` and
            ``cfg.data.dataset`` are read here. The whole config is also
            passed to ``load_dataset_with_score``.
    """
    # Name of the score column used as the regressor. str() matches the
    # conversion the original f-string performed on the config value.
    score_column = str(cfg.baseline.system)

    # Load the metadata file (with baseline scores) for both splits.
    log.info("Loading dataset...")
    records_train_df = load_dataset_with_score(cfg, "train")
    records_valid_df = load_dataset_with_score(cfg, "valid")

    # Fit the score -> correctness mapping on the train split only.
    log.info("Making the fitting model...")
    model = LogisticModel()
    model.fit(records_train_df[score_column], records_train_df.correctness)

    # Make predictions for all items in the valid (dev) split.
    log.info("Starting predictions...")
    records_valid_df["predicted_correctness"] = model.predict(
        records_valid_df[score_column]
    )

    # Save results to CSV. The "signal" / "predicted_correctness" columns are
    # written under the header names signal_ID / intelligibility_score.
    output_file = f"{cfg.data.dataset}.{score_column}.valid.predict.csv"
    records_valid_df[["signal", "predicted_correctness"]].to_csv(
        output_file,
        index=False,
        header=["signal_ID", "intelligibility_score"],
        mode="w",
    )
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    log.info("Predictions saved to %s", output_file)
if __name__ == "__main__":
    # @hydra.main builds ``cfg`` from the config files and command line, so
    # the entry point is deliberately called without arguments (hence the
    # pylint no-value-for-parameter suppression above).
    predict_dev()