"""Compute the HASPI scores."""
from __future__ import annotations

import csv
import functools
import hashlib
import json
import logging
from pathlib import Path

import hydra
import numpy as np
from omegaconf import DictConfig
from scipy.io import wavfile
from tqdm import tqdm

from clarity.evaluator.haspi import haspi_v2_be
from clarity.utils.audiogram import Listener
from clarity.utils.file_io import read_jsonl, write_jsonl
logger = logging.getLogger(__name__)
# Standard audiograms for each severity level
#
# These are based on those defined in the Clarity library utils/audiogram.py
# which originated from ones used by Moore, Stone, Baer and Glasberg.
#
# See utils/audiogram.py for more details.
MILD_LISTENER = {
    "name": "Mild Listener",
    # Audiogram centre frequencies in Hz (shared by all three severity profiles)
    "audiogram_cfs": np.array([250, 500, 1000, 2000, 3000, 4000, 6000, 8000]),
    # Hearing levels in dB HL; these standard profiles are symmetric, so the
    # left- and right-ear arrays hold the same values.
    "audiogram_levels_l": np.array([10, 15, 19, 25, 28, 31, 35, 38]),
    "audiogram_levels_r": np.array([10, 15, 19, 25, 28, 31, 35, 38]),
}
MOD_LISTENER = {
    "name": "Moderate Listener",
    "audiogram_cfs": np.array([250, 500, 1000, 2000, 3000, 4000, 6000, 8000]),
    "audiogram_levels_l": np.array([20, 20, 25, 35, 40, 45, 50, 55]),
    "audiogram_levels_r": np.array([20, 20, 25, 35, 40, 45, 50, 55]),
}
MOD_SEV_LISTENER = {
    "name": "Moderately Severe Listener",
    "audiogram_cfs": np.array([250, 500, 1000, 2000, 3000, 4000, 6000, 8000]),
    "audiogram_levels_l": np.array([19, 28, 40, 52, 56, 58, 58, 63]),
    "audiogram_levels_r": np.array([19, 28, 40, 52, 56, 58, 58, 63]),
}
def set_seed_with_string(seed_string: str) -> None:
    """Seed NumPy's global RNG deterministically from a string.

    The string is hashed with MD5 and the digest is folded into the range
    accepted by ``np.random.seed`` (here, eight decimal digits).
    """
    digest = hashlib.md5(seed_string.encode("utf-8")).hexdigest()
    np.random.seed(int(digest, 16) % 100_000_000)
def parse_signal_name(signal_name: str) -> dict:
    """Parse a CEC2 signal name into its component parts.

    Args:
        signal_name (str): name of the form "<cec>_<system>_<scene>_<listener>",
            e.g. "CEC2_E032_S09318_L0254"

    Returns:
        dict: with keys "scene", "listener", "system" and "cec"

    Raises:
        ValueError: if the name does not split into four non-empty parts
    """
    try:
        cec, system, scene, listener = signal_name.split("_")
    except ValueError as err:
        # Wrong number of "_"-separated parts; re-raise with a clear message.
        raise ValueError(f"Invalid CEC2 signal name: {signal_name}") from err
    # Bug fix: the original tested `system == []`, which is always False for a
    # string, so names with an empty system part slipped through unchecked.
    if not (cec and system and scene and listener):
        raise ValueError(f"Invalid CEC2 signal name: {signal_name}")
    return {
        "scene": scene,
        "listener": listener,
        "system": system,
        "cec": cec,
    }
@functools.lru_cache(maxsize=None)
def _load_severity_map(data_root: str) -> dict:
    """Return a listener_id -> severity mapping from metadata/listeners.csv.

    Cached so the CSV is read once per data root, not once per scored signal.
    """
    with open(Path(data_root) / "metadata" / "listeners.csv", encoding="utf8") as f:
        return {row["listener_id"]: row["severity"] for row in csv.DictReader(f)}


def compute_haspi_for_signal(record: dict, data_root: str, split: str) -> float:
    """Compute the HASPI score for a given signal.

    Args:
        record (dict): the metadata dict for the signal
        data_root (str): root directory holding the split's signal data
        split (str): name of the data split (e.g. "train" or "dev")

    Returns:
        float: HASPI score

    Raises:
        ValueError: if the processed and reference sample rates differ
    """
    signal_name = record["signal"]
    signal_dir = Path(data_root) / split / "signals"
    ref_dir = Path(data_root) / split / "references"
    listener_data_dict = {
        "Mild": MILD_LISTENER,
        "Moderate": MOD_LISTENER,
        "Moderately severe": MOD_SEV_LISTENER,
    }
    if "hearing_loss" in record:
        # For DEV DATA the hearing loss is directly in the metadata...
        listener_severity = record["hearing_loss"]
        # ... and a simple one-to-one mapping between signal name and reference signal
        ref_signal = f"{signal_name}_ref"
    else:
        # For the TRAINING DATA it needs to be looked up from listeners.csv
        signal_parts = parse_signal_name(signal_name)
        listener_severity = _load_severity_map(str(data_root))[
            signal_parts["listener"]
        ]
        # ... and the reference signal name is formed as follows.
        ref_signal = f"{signal_parts['cec']}_{signal_parts['scene']}_ref"
    listener = Listener.from_dict(listener_data_dict[listener_severity])
    # Retrieve signals and convert 16-bit PCM samples to floats in [-1, 1)
    sr_proc, proc = wavfile.read(signal_dir / f"{signal_name}.wav")
    sr_ref, ref = wavfile.read(ref_dir / f"{ref_signal}.wav")
    # Explicit check replaces the original `assert`, which is silently
    # stripped when Python runs with -O.
    if sr_ref != sr_proc:
        raise ValueError(
            f"Sample rate mismatch for {signal_name}: "
            f"reference={sr_ref}, processed={sr_proc}"
        )
    proc = proc / 32768.0
    ref = ref / 32768.0
    # Compute haspi score using library code; signals are stereo (left, right)
    haspi_score = haspi_v2_be(
        reference_left=ref[:, 0],
        reference_right=ref[:, 1],
        processed_left=proc[:, 0],
        processed_right=proc[:, 1],
        sample_rate=sr_proc,
        listener=listener,
    )
    return haspi_score
# pylint: disable = no-value-for-parameter
@hydra.main(config_path=".", config_name="config", version_base=None)
def run_calculate_haspi(cfg: DictConfig) -> None:
    """Run the HASPI score computation."""
    # Load the full set of signal records for this dataset/split
    metadata_file = (
        Path(cfg.clarity_data_root)
        / cfg.dataset
        / "metadata"
        / f"CPC3.{cfg.split}.json"
    )
    with metadata_file.open("r", encoding="utf-8") as fp:
        all_records = json.load(fp)

    dataroot = Path(cfg.clarity_data_root) / cfg.dataset

    # When running as one of several parallel batches, each batch gets its
    # own results file tagged with "<batch>_<n_batches>".
    n_batches = cfg.compute_haspi.n_batches
    batch = cfg.compute_haspi.batch
    batch_str = f".{batch}_{n_batches}" if n_batches > 1 else ""
    results_file = Path(f"{cfg.dataset}.{cfg.split}.haspi{batch_str}.jsonl")

    # Collect the signals already scored in a previous (interrupted) run
    already_scored = set()
    if results_file.exists():
        already_scored = {
            result["signal"] for result in read_jsonl(str(results_file))
        }

    # Keep only unscored signals, then take this batch's interleaved slice
    pending = [
        record for record in all_records if record["signal"] not in already_scored
    ]
    pending = pending[batch - 1 :: n_batches]

    logger.info(f"Computing scores for {len(pending)} signals")
    for record in tqdm(pending):
        signal_name = record["signal"]
        if cfg.compute_haspi.set_random_seed:
            set_seed_with_string(signal_name)
        haspi = compute_haspi_for_signal(record, dataroot, cfg.split)
        # Append each result immediately so the run can be safely interrupted
        write_jsonl(str(results_file), [{"signal": signal_name, "haspi": haspi}])


if __name__ == "__main__":
    run_calculate_haspi()