Source code for recipes.cad2.task1.baseline.evaluate

"""Evaluate the enhanced signals using HAAQI and Whisper"""

from __future__ import annotations

import json
import logging
from pathlib import Path

import hydra
import numpy as np
import pyloudnorm as pyln
import torch.nn
import whisper
from jiwer import compute_measures
from omegaconf import DictConfig

from clarity.enhancer.multiband_compressor import MultibandCompressor
from clarity.evaluator.haaqi import compute_haaqi
from clarity.evaluator.msbg.msbg import Ear
from clarity.utils.audiogram import Listener
from clarity.utils.flac_encoder import read_flac_signal, save_flac_signal
from clarity.utils.results_support import ResultsFile
from clarity.utils.signal_processing import compute_rms, resample

logger = logging.getLogger(__name__)


def make_scene_listener_list(scenes_listeners: dict, small_test: bool = False) -> list:
    """Make the list of scene-listener pairings to process.

    Args:
        scenes_listeners (dict): Dictionary of scenes and listeners.
        small_test (bool): Whether to use a small test set.

    Returns:
        list: List of scene-listener pairings.
    """
    scene_listener_pairs = [
        (scene, listener)
        for scene in scenes_listeners
        for listener in scenes_listeners[scene]
    ]

    # Can define a standard 'small_test' that keeps only 1 in every 400 pairs
    if small_test:
        scene_listener_pairs = scene_listener_pairs[::400]

    return scene_listener_pairs
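
# Illustrative usage of make_scene_listener_list (a minimal sketch with made-up
# scene and listener IDs, not data from the challenge):
#
#     scenes_listeners = {"S0001": ["L0001", "L0002"], "S0002": ["L0001"]}
#     make_scene_listener_list(scenes_listeners)
#     # -> [("S0001", "L0001"), ("S0001", "L0002"), ("S0002", "L0001")]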

def compute_intelligibility(
    enhanced_signal: np.ndarray,
    segment_metadata: dict,
    scorer: torch.nn.Module,
    listener: Listener,
    sample_rate: int,
    save_intermediate: bool = False,
    path_intermediate: str | Path | None = None,
    equiv_0db_spl: float = 100,
) -> tuple[float, float, dict]:
    """
    Compute the intelligibility score for the enhanced signal using the Whisper model.

    The MSBG hearing loss model is applied to the enhanced signal before it is
    transcribed with Whisper.

    Args:
        enhanced_signal: The enhanced signal
        segment_metadata: The metadata of the segment
        scorer: The Whisper model
        listener: The listener
        sample_rate: The sample rate of the signal
        save_intermediate: Save the intermediate signal
        path_intermediate: The path to save the intermediate signal
        equiv_0db_spl: The equivalent 0 dB SPL

    Returns:
        The intelligibility scores for the left and right channels, and the
        reference and hypothesis transcriptions.
    """
    lyrics = {}

    if path_intermediate is None:
        path_intermediate = Path.cwd()
    if isinstance(path_intermediate, str):
        path_intermediate = Path(path_intermediate)

    ear = Ear(
        equiv_0db_spl=equiv_0db_spl,
        sample_rate=sample_rate,
    )

    reference = segment_metadata["text"]
    lyrics["reference"] = reference

    # Compute left ear
    ear.set_audiogram(listener.audiogram_left)
    enhanced_left = ear.process(enhanced_signal[:, 0])[0]
    left_path = Path(f"{path_intermediate.as_posix()}_left.flac")
    save_flac_signal(
        enhanced_left,
        left_path,
        44100,
        sample_rate,
    )
    hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False)["text"]
    lyrics["hypothesis_left"] = hypothesis
    left_results = compute_measures(reference, hypothesis)

    # Compute right ear
    ear.set_audiogram(listener.audiogram_right)
    enhanced_right = ear.process(enhanced_signal[:, 1])[0]
    right_path = Path(f"{path_intermediate.as_posix()}_right.flac")
    save_flac_signal(
        enhanced_right,
        right_path,
        44100,
        sample_rate,
    )
    hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False)["text"]
    lyrics["hypothesis_right"] = hypothesis
    right_results = compute_measures(reference, hypothesis)

    # Total number of reference words (the same reference is used for both ears)
    total_words = (
        right_results["substitutions"]
        + right_results["deletions"]
        + right_results["hits"]
    )

    if save_intermediate:
        enhanced_signal = np.stack([enhanced_left, enhanced_right], axis=1)
        save_flac_signal(
            enhanced_signal,
            path_intermediate,
            44100,
            sample_rate,
        )

    Path(left_path).unlink()
    Path(right_path).unlink()

    return (
        left_results["hits"] / total_words,
        right_results["hits"] / total_words,
        lyrics,
    )
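
# The per-ear intelligibility score returned above is the word-correct ratio
# derived from jiwer's compute_measures output, i.e.
#
#     score = hits / (hits + substitutions + deletions)
#
# For example (illustrative numbers only): 8 hits, 1 substitution and 1 deletion
# against a 10-word reference give 8 / 10 = 0.8. Insertions do not lower the
# score because total_words counts reference words only.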

def compute_quality(
    reference_signal: np.ndarray,
    enhanced_signal: np.ndarray,
    listener: Listener,
    config: DictConfig,
) -> tuple[float, float]:
    """Compute the HAAQI score for the left and right channels."""
    scores = []

    for channel in range(2):
        audiogram = (
            listener.audiogram_left if channel == 0 else listener.audiogram_right
        )
        s = compute_haaqi(
            processed_signal=resample(
                enhanced_signal[:, channel],
                config.remix_sample_rate,
                config.HAAQI_sample_rate,
            ),
            reference_signal=resample(
                reference_signal[:, channel],
                config.input_sample_rate,
                config.HAAQI_sample_rate,
            ),
            processed_sample_rate=config.HAAQI_sample_rate,
            reference_sample_rate=config.HAAQI_sample_rate,
            audiogram=audiogram,
            equalisation=2,
            level1=65 - 20 * np.log10(compute_rms(reference_signal[:, channel])),
        )
        scores.append(s)

    return scores[0], scores[1]
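
# Note on the level1 argument above: it is chosen so that the reference channel,
# at its measured RMS, corresponds to a presentation level of roughly 65 dB SPL.
# For example (illustrative numbers only), a channel with RMS 0.1 gives
# level1 = 65 - 20 * log10(0.1) = 85, i.e. an RMS of 1.0 would map to 85 dB SPL
# while the actual signal sits at 65 dB SPL.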

def load_reference_signal(
    path: str | Path,
    start_sample: int | None,
    end_sample: int | None,
    level_luft: float = -40.0,
) -> np.ndarray:
    """Load the reference signal."""
    if isinstance(path, str):
        path = Path(path)

    if start_sample is None:
        start_sample = 0
    if end_sample is None:
        end_sample = -1

    vocal, _ = read_flac_signal(path / "vocals.flac")

    accompaniment = np.zeros_like(vocal)
    for instrument in ["bass", "drums", "other"]:
        instrument_signal, sample_rate = read_flac_signal(path / f"{instrument}.flac")
        accompaniment += instrument_signal

    mixture = vocal * 10 ** (1 / 20) + accompaniment * 10 ** (-1 / 20)
    mixture = normalise_luft(mixture, sample_rate, level_luft)

    return mixture[start_sample:end_sample, :]
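
# The remix weights above apply a fixed vocal emphasis before loudness
# normalisation: 10 ** (1 / 20) ~= 1.122 boosts the vocals by 1 dB and
# 10 ** (-1 / 20) ~= 0.891 attenuates the accompaniment by 1 dB, after which
# the mixture is normalised to level_luft (default -40 LUFS).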

def normalise_luft(
    signal: np.ndarray, sample_rate: float, target_luft: float = -40.0
) -> np.ndarray:
    """
    Normalise the signal to a target loudness level.

    Args:
        signal: input signal to normalise
        sample_rate: sample rate of the signal
        target_luft: target loudness level in LUFS.

    Returns:
        np.ndarray: normalised signal
    """
    level_meter = pyln.Meter(int(sample_rate))
    input_level = level_meter.integrated_loudness(signal)
    return signal * (10 ** ((target_luft - input_level) / 20))
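
# Worked example for normalise_luft (illustrative numbers only): if the meter
# reports an integrated loudness of -20 LUFS and target_luft is -40, the gain is
# 10 ** ((-40 - (-20)) / 20) = 10 ** (-1) = 0.1, i.e. the signal is attenuated
# by 20 dB.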

@hydra.main(config_path="", config_name="config", version_base=None)
def run_compute_scores(config: DictConfig) -> None:
    """Compute the scores for the enhanced signals."""
    enhanced_folder = Path("enhanced_signals")
    logger.info(f"Evaluating from {enhanced_folder} directory")

    # Load listener audiograms and songs
    listener_dict = Listener.load_listener_dict(config.path.listeners_file)

    # Load alphas
    with Path(config.path.alphas_file).open("r", encoding="utf-8") as file:
        alphas = json.load(file)

    # Load scenes
    with Path(config.path.scenes_file).open("r", encoding="utf-8") as file:
        scenes = json.load(file)

    # Load scene-listeners
    with Path(config.path.scene_listeners_file).open("r", encoding="utf-8") as file:
        scenes_listeners = json.load(file)

    # Load songs
    with Path(config.path.musics_file).open("r", encoding="utf-8") as file:
        songs = json.load(file)

    # Load compressor params
    with Path(config.path.enhancer_params_file).open("r", encoding="utf-8") as file:
        enhancer_params = json.load(file)

    scores_headers = [
        "scene",
        "song",
        "listener",
        "lyrics",
        "hypothesis_left",
        "hypothesis_right",
        "haaqi_left",
        "haaqi_right",
        "haaqi_avg",
        "whisper_left",
        "whisper_right",
        "whisper_be",
        "alpha",
        "score",
    ]

    if config.evaluate.batch_size == 1:
        results_file = ResultsFile(
            "scores.csv",
            header_columns=scores_headers,
        )
    else:
        results_file = ResultsFile(
            f"scores_{config.evaluate.batch + 1}-{config.evaluate.batch_size}.csv",
            header_columns=scores_headers,
        )

    # Create the list of scene-listener pairs
    scene_listener_pairs = make_scene_listener_list(
        scenes_listeners, config.evaluate.small_test
    )
    scene_listener_pairs = scene_listener_pairs[
        config.evaluate.batch :: config.evaluate.batch_size
    ]

    # Create hearing aid
    enhancer = MultibandCompressor(
        crossover_frequencies=config.enhancer.crossover_frequencies,
        sample_rate=config.input_sample_rate,
    )

    intelligibility_scorer = whisper.load_model(config.evaluate.whisper_version)

    # Loop over the scene-listener pairs
    for idx, scene_listener_ids in enumerate(scene_listener_pairs, 1):
        # Iterate over the scene-listener pairs
        # The reference is the original signal
        logger.info(
            f"[{idx:04d}/{len(scene_listener_pairs):04d}] Processing scene-listener"
            f" pair: {scene_listener_ids}"
        )

        scene_id, listener_id = scene_listener_ids

        # Load scene details
        scene = scenes[scene_id]
        listener = listener_dict[listener_id]
        alpha = alphas[scene["alpha"]]

        #############################################################
        # REFERENCE SIGNAL
        # Load the reference signal
        start_sample = int(
            songs[scene["segment_id"]]["start_time"] * config.input_sample_rate
        )
        end_sample = int(
            songs[scene["segment_id"]]["end_time"] * config.input_sample_rate
        )
        reference = load_reference_signal(
            Path(config.path.music_dir) / songs[scene["segment_id"]]["path"],
            start_sample,
            end_sample,
        )

        # Get the listener's compressor params
        mbc_params_listener: dict[str, dict] = {"left": {}, "right": {}}
        for ear in ["left", "right"]:
            mbc_params_listener[ear]["release"] = config.enhancer.release
            mbc_params_listener[ear]["attack"] = config.enhancer.attack
            mbc_params_listener[ear]["threshold"] = config.enhancer.threshold
        mbc_params_listener["left"]["ratio"] = enhancer_params[listener_id]["cr_l"]
        mbc_params_listener["right"]["ratio"] = enhancer_params[listener_id]["cr_r"]
        mbc_params_listener["left"]["makeup_gain"] = enhancer_params[listener_id][
            "gain_l"
        ]
        mbc_params_listener["right"]["makeup_gain"] = enhancer_params[listener_id][
            "gain_r"
        ]

        # Apply compressor to reference signal
        enhancer.set_compressors(**mbc_params_listener["left"])
        left_reference = enhancer(signal=reference[:, 0])
        enhancer.set_compressors(**mbc_params_listener["right"])
        right_reference = enhancer(signal=reference[:, 1])

        # Reference signal amplified
        reference = np.stack([left_reference[0], right_reference[0]], axis=1)

        #############################################################
        # ENHANCED SIGNAL
        # Load the enhanced signals
        enhanced_signal_path = (
            enhanced_folder / f"{scene_id}_{listener_id}_A{alpha}_remix.flac"
        )
        enhanced_signal, _ = read_flac_signal(enhanced_signal_path)

        #############################################################
        # COMPUTE SCORES
        # Compute the HAAQI and Whisper scores
        haaqi_scores = compute_quality(reference, enhanced_signal, listener, config)
        whisper_left, whisper_right, lyrics_text = compute_intelligibility(
            enhanced_signal=enhanced_signal,
            segment_metadata=songs[scene["segment_id"]],
            scorer=intelligibility_scorer,
            listener=listener,
            sample_rate=config.remix_sample_rate,
            save_intermediate=config.evaluate.save_intermediate,
            path_intermediate=enhanced_signal_path.parent
            / f"{scene_id}_{listener_id}_A{alpha}_remix_hl.flac",
            equiv_0db_spl=config.evaluate.equiv_0db_spl,
        )
        max_whisper = np.max([whisper_left, whisper_right])

        results_file.add_result(
            {
                "scene": scene_id,
                "song": songs[scene["segment_id"]]["track_name"],
                "listener": listener_id,
                "lyrics": lyrics_text["reference"],
                "hypothesis_left": lyrics_text["hypothesis_left"],
                "hypothesis_right": lyrics_text["hypothesis_right"],
                "haaqi_left": haaqi_scores[0],
                "haaqi_right": haaqi_scores[1],
                "haaqi_avg": np.mean(haaqi_scores),
                "whisper_left": whisper_left,
                "whisper_right": whisper_right,
                "whisper_be": max_whisper,
                "alpha": alpha,
                "score": alpha * max_whisper + (1 - alpha) * np.mean(haaqi_scores),
            }
        )
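
# The final per-pair score written above combines both metrics with the
# scene-specific alpha:
#
#     score = alpha * max(whisper_left, whisper_right)
#             + (1 - alpha) * mean(haaqi_left, haaqi_right)
#
# For example (illustrative numbers only): alpha = 0.5, a better-ear Whisper
# score of 0.8 and a mean HAAQI of 0.6 give 0.5 * 0.8 + 0.5 * 0.6 = 0.7.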

# pylint: disable = no-value-for-parameter
if __name__ == "__main__":
    run_compute_scores()
    logger.info("Evaluation completed")