Source code for clarity.evaluator.haspi.haspi

"""HASPI intelligibility Index"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Final

import numpy as np

from clarity.evaluator.haspi.eb import ear_model
from clarity.evaluator.haspi.ebm import (
    cepstral_correlation_coef,
    env_filter,
    fir_modulation_filter,
    modulation_cross_correlation,
)
from clarity.evaluator.haspi.ip import get_neural_net, nn_feed_forward_ensemble
from clarity.utils.audiogram import Audiogram, Listener

if TYPE_CHECKING:
    from numpy import ndarray


# HASPI assumes the following audiogram frequencies:
HASPI_AUDIOGRAM_FREQUENCIES: Final = np.array([250, 500, 1000, 2000, 4000, 6000])
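
# A hedged illustration (not part of the original module): audiograms measured
# at other frequencies are accepted by haspi_v2, which warns and resamples them
# onto the six frequencies above. The Audiogram keyword construction and the
# variable names below are assumptions; Audiogram.resample is the method the
# function body itself uses:
#
#     full_audiogram = Audiogram(
#         levels=np.array([20, 25, 30, 35, 45, 50, 55, 60]),
#         frequencies=np.array([250, 500, 1000, 2000, 3000, 4000, 6000, 8000]),
#     )
#     haspi_audiogram = full_audiogram.resample(HASPI_AUDIOGRAM_FREQUENCIES)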


def haspi_v2(  # pylint: disable=too-many-arguments,too-many-locals
    reference: ndarray,
    reference_sample_rate: float,
    processed: ndarray,
    processed_sample_rate: float,
    audiogram: Audiogram,
    level1: float = 65.0,
    f_lp: float = 320.0,
    itype: int = 0,
) -> tuple[float, ndarray]:
    """
    Compute the HASPI intelligibility index using the auditory model followed
    by computing the envelope cepstral correlation and BM vibration high-level
    covariance. The reference signal presentation level for NH listeners is
    assumed to be 65 dB SPL. The same model is used for both normal and
    impaired hearing.

    This version of HASPI uses a modulation filterbank followed by an ensemble
    of neural networks to compute the estimated intelligibility.

    **NB** - The original HASPI model derivation included a bug: although the
    'shift' parameter used in band centre frequency calculations was set to
    '0.02', it was never actually applied. To replicate this behaviour,
    ear_model is called with 'shift' set to None. See
    `Issue #105 <https://github.com/claritychallenge/clarity/issues/105>`_
    for further details.

    Args:
        reference (np.ndarray): Clear input reference speech signal with no
            noise or distortion. If a hearing loss is specified, no
            amplification should be provided.
        reference_sample_rate (float): Sampling rate in Hz for the reference
            signal.
        processed (np.ndarray): Output signal with noise, distortion, HA gain,
            and/or processing.
        processed_sample_rate (float): Sampling rate in Hz for the processed
            signal.
        audiogram (Audiogram): Listener audiogram. HASPI uses the hearing loss
            at the 6 audiometric frequencies [250, 500, 1000, 2000, 4000,
            6000] Hz; audiograms measured at other frequencies are
            interpolated onto these.
        level1 (float): Optional input specifying the level in dB SPL that
            corresponds to a signal RMS = 1. Default is 65 dB SPL.
        f_lp (float): Low-pass cutoff frequency in Hz used when filtering and
            subsampling the envelope. Default is 320 Hz.
        itype (int): Purpose selector passed to ear_model; the default of 0
            selects the intelligibility model.

    Returns:
        tuple(Intel, raw):
            Intel (float): Intelligibility estimated by passing the cepstral
                coefficients through a modulation filterbank followed by an
                ensemble of neural networks.
            raw (np.ndarray): Vector of 10 cepstral correlation modulation
                filterbank outputs, averaged over basis functions 2-6.

    Updates:
        James M. Kates, 5 August 2013.
        Translated from MATLAB to Python by Zuzanna Podwinska, March 2022.
""" if not audiogram.has_frequencies(HASPI_AUDIOGRAM_FREQUENCIES): logging.warning( "Audiogram does not have all HASPI frequency measurements" "Measurements will be interpolated" ) # Adjust audiogram to match the standard frequencies audiogram = audiogram.resample(HASPI_AUDIOGRAM_FREQUENCIES) # Auditory model for intelligibility # Reference is no processing, normal hearing reference_env, _, processed_env, _, _, _, fsamp = ear_model( reference, reference_sample_rate, processed, processed_sample_rate, audiogram.levels, itype, level1, # shift=0.02 # See comment in docstring shift=None, ) # Envelope modulation features # LP filter and subsample the envelope fsub = 8.0 * f_lp # subsample to span 2 octaves above the cutoff frequency reference_lp, processed_lp = env_filter( reference_env, processed_env, f_lp, fsub, fsamp ) # Compute the cepstral coefficients as a function of subsampled time nbasis = 6 # Use 6 basis functions thr = 2.5 # Silence threshold in dB SL dither = 0.1 # Dither in dB RMS to add to envelope signals reference_cep, processed_cep = cepstral_correlation_coef( reference_lp, processed_lp, thr, dither, nbasis ) # Cepstral coefficients filtered at each modulation rate # Band center frequencies [2, 6, 10, 16, 25, 40, 64, 100, 160, 256] Hz # Band edges [0, 4, 8, 12.5, 20.5, 30.5, 52.4, 78.1, 128, 200, 328] Hz reference_mod, processed_mod, _ = fir_modulation_filter( reference_cep, processed_cep, fsub ) # Cross-correlation between the cepstral coefficients for the degraded and # ref signals at each modulation rate, averaged over basis functions 2-6 average_correlation_matrix = modulation_cross_correlation( reference_mod, processed_mod ) # Intelligibility prediction # Get the neural network parameters and the weights for an ensemble of 10 networks ( neural_net_params, weights_hidden, weights_out, normalization_factor, ) = get_neural_net() # Average the neural network outputs for the modulation filterbank values model = nn_feed_forward_ensemble( average_correlation_matrix, neural_net_params, weights_hidden, weights_out ) model = model / normalization_factor # Return the intelligibility estimate and raw modulation filter outputs return model[0], average_correlation_matrix


def haspi_v2_be(  # pylint: disable=too-many-arguments
    reference_left: ndarray,
    reference_right: ndarray,
    processed_left: ndarray,
    processed_right: ndarray,
    sample_rate: float,
    listener: Listener,
    level: float = 100.0,
) -> float:
    """Better ear HASPI.

    Calculates HASPI for the left and right ear and returns the better score.

    Args:
        reference_left (np.ndarray): Left channel of the reference signal.
        reference_right (np.ndarray): Right channel of the reference signal.
        processed_left (np.ndarray): Left channel of the processed signal.
        processed_right (np.ndarray): Right channel of the processed signal.
        sample_rate (float): Sampling rate in Hz shared by all four signals.
        listener (Listener): Listener whose left and right audiograms are used.
        level (float): Level in dB SPL corresponding to RMS=1. Default is 100.

    Returns:
        float: beHASPI score, the larger of the left- and right-ear HASPI
            scores.

    Updates:
        Zuzanna Podwinska, March 2022.
    """
    score_left, _ = haspi_v2(
        reference_left,
        sample_rate,
        processed_left,
        sample_rate,
        listener.audiogram_left,
        level,
    )
    score_right, _ = haspi_v2(
        reference_right,
        sample_rate,
        processed_right,
        sample_rate,
        listener.audiogram_right,
        level,
    )

    return max(score_left, score_right)
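

# Hedged usage sketch (not part of the original module): a self-contained
# smoke test for haspi_v2_be. The keyword construction of Audiogram and
# Listener is assumed from clarity.utils.audiogram, and the random signals are
# stand-ins for real speech, so the printed score is not meaningful.
if __name__ == "__main__":
    rng = np.random.default_rng(seed=0)
    sample_rate = 24000.0
    n_samples = int(2 * sample_rate)

    # Clean stereo "reference" and a lightly degraded "processed" pair
    reference = rng.standard_normal((2, n_samples))
    processed = reference + 0.05 * rng.standard_normal((2, n_samples))

    # Flat mild-to-moderate loss at the six HASPI audiometric frequencies
    audiogram = Audiogram(
        levels=np.array([30, 30, 35, 40, 45, 50]),
        frequencies=HASPI_AUDIOGRAM_FREQUENCIES,
    )
    listener = Listener(audiogram_left=audiogram, audiogram_right=audiogram)

    score = haspi_v2_be(
        reference[0],
        reference[1],
        processed[0],
        processed[1],
        sample_rate,
        listener,
    )
    print(f"beHASPI score: {score:.3f}")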