# Source code for clarity.evaluator.msbg.msbg

"""Implementation of the MSBG hearing loss model."""

from __future__ import annotations

import logging
import math
from typing import Final

import numpy as np
import scipy
from numpy import ndarray
from scipy.signal import firwin, lfilter

from clarity.evaluator.msbg.cochlea import Cochlea
from clarity.evaluator.msbg.msbg_utils import (
    DF_ED,
    FF_ED,
    HZ,
    ITU_ERP_DRP,
    ITU_HZ,
    MIDEAR,
    firwin2,
    gen_eh2008_speech_noise,
    gen_tone,
    measure_rms,
)
from clarity.utils.audiogram import Audiogram

# Upper cut-off frequency (Hz) of the low-pass filter applied as the final
# step of the simulation; it suppresses excessive processing noise at very
# high frequencies.
UPPER_CUTOFF_HZ: Final = 18_000


class Ear:
    """Representation of a pair of ears.

    Combines the source-position to cochlea transfer functions with an
    optional :class:`Cochlea` model (created by :meth:`set_audiogram`) so that
    a signal can be run through the MSBG hearing-loss simulation via
    :meth:`process`.
    """

    def __init__(
        self,
        src_pos: str = "ff",
        sample_rate: float = 44100.0,
        equiv_0db_spl: float = 100.0,
        ahr: float = 20.0,
    ) -> None:
        """
        Constructor for the Ear class.

        Args:
            src_pos (str): Position of the source. One of "ff", "df" or "ITU"
                (see :meth:`get_src_correction`).
            sample_rate (float): sample frequency in Hz. :meth:`process` only
                accepts 44100 Hz.
            equiv_0db_spl (float): level in dB SPL that a digital RMS level of
                0 dB is taken to correspond to. :meth:`process` adds ``ahr``
                to it before use.
            ahr (float): additional level in dB added to ``equiv_0db_spl`` in
                :meth:`process` (presumably "added headroom" -- TODO confirm).
        """
        self.sample_rate = sample_rate
        self.src_correction = self.get_src_correction(src_pos)
        self.equiv_0db_spl = equiv_0db_spl
        self.ahr = ahr
        # Stays None until set_audiogram() is called; process() then skips
        # the cochlear simulation step.
        self.cochlea: Cochlea | None = None

    def set_audiogram(self, audiogram: Audiogram) -> None:
        """Set the audiogram to be used.

        Builds a new :class:`Cochlea` from ``audiogram``. A warning is logged
        when any (non-missing) audiogram level exceeds 80 dB HL.

        Args:
            audiogram (Audiogram): the audiogram describing the hearing loss.
        """
        # Convert to float so that missing (None) entries become NaN and can
        # be dropped. Indexing with `levels is not None` used a plain Python
        # bool and therefore filtered nothing.
        levels = np.asarray(audiogram.levels, dtype=float)
        valid_levels = levels[~np.isnan(levels)]
        if valid_levels.size > 0 and np.max(valid_levels) > 80:
            logging.warning(
                "Impairment too severe: Suggest you limit audiogram max to "
                "80-90 dB HL, otherwise things go wrong/unrealistic."
            )
        self.cochlea = Cochlea(audiogram=audiogram)

    @staticmethod
    def get_src_correction(src_pos: str) -> ndarray:
        """Select relevant external field to eardrum correction.

        Args:
            src_pos (str): Position of src. One of ff, df or ITU

        Returns:
            ndarray: correction values in dB, on the ``HZ`` frequency grid.

        Raises:
            ValueError: if ``src_pos`` is not one of the supported positions.
        """
        if src_pos == "ff":
            src_correction = FF_ED
        elif src_pos == "df":
            src_correction = DF_ED
        elif src_pos == "ITU":
            # The ITU data lives on its own frequency grid: resample onto HZ.
            field = scipy.interpolate.interp1d(ITU_HZ, ITU_ERP_DRP, kind="linear")
            src_correction = field(HZ)
        else:
            logging.error(
                f"Invalid src position ({src_pos}). Must be one of ff, df or ITU"
            )
            raise ValueError("Invalid src position")
        return src_correction

    @staticmethod
    def src_to_cochlea_filt(
        input_signal: ndarray,
        src_correction: ndarray,
        sample_rate: float,
        backward: bool = False,
    ) -> ndarray:
        """Simulate middle and outer ear transfer functions.

        Made more general, Mar2012, to include diffuse field as well as ITU
        reference points, that were included in DOS-versions of recruitment
        simulator, released ca 1999-2001, and on hearing group website, Mar2012
        variable [src_pos] takes one of 3 values: 'ff', 'df' and 'ITU'
        free-field to cochlea filter forwards or backward direction, depends on
        'backward' switch. NO LONGER via 2 steps. ff to eardrum and then via
        middle ear: use same length FIR 5-12-97.

        Args:
            input_signal (ndarray): signal to process; filtered along its last
                axis.
            src_correction (np.ndarray): correction (dB) to make for src
                position as an array returned by get_src_correction(src_pos)
                where src_pos is one of ff, df or ITU
            sample_rate (float): sampling frequency in Hz
            backward (bool, optional): if true then cochlea to src
                (default: False)

        Returns:
            np.ndarray: the processed signal
        """
        logging.info("performing outer/middle ear corrections")

        # make sure that response goes only up to sample_frequency/2
        nyquist = int(sample_rate / 2.0)
        ixf_useful = np.nonzero(HZ < nyquist)

        hz_used = HZ[ixf_useful]
        hz_used = np.append(hz_used, nyquist)

        # sig from free field to cochlea: 0 dB at 1kHz
        correction = src_correction - MIDEAR
        field = scipy.interpolate.interp1d(HZ, correction, kind="linear")
        last_correction = field(nyquist)  # generate synthetic response at Nyquist

        correction_used = np.append(correction[ixf_useful], last_correction)
        if backward:  # ie. coch->src rather than src->coch
            correction_used = -correction_used
        # dB -> linear amplitude gain
        correction_used = np.power(10, (0.05 * correction_used))

        correction_used = correction_used.flatten()
        # Create filter with 23 msec window to do reasonable job down to about
        # 100 Hz. Scales with fs, fails with longer windows in fir2 in original
        # MATLAB version
        n_wdw = 2 * math.floor((sample_rate / 16e3) * 368 / 2)
        hz_used = hz_used / nyquist  # normalise frequencies to [0, 1]

        b = firwin2(n_wdw + 1, hz_used.flatten(), correction_used, window=("kaiser", 4))
        output_signal = scipy.signal.lfilter(b, 1, input_signal)

        return output_signal

    def make_calibration_signal(
        self, ref_rms_db: float, n_channels: int = 1
    ) -> tuple[ndarray, ndarray]:
        """Build the calibration segments placed around the signal.

        The pre-calibration segment is silence, a 520 Hz tone burst, silence,
        a speech-shaped noise burst and silence; the post-calibration segment
        is 50 ms of silence. This method does not modify any signal; the
        caller concatenates the segments.

        Args:
            ref_rms_db (float): reference rms level in dB
            n_channels (int): number of channels (rows) to replicate each
                segment over (default: 1).

        Returns:
            tuple[ndarray, ndarray] - pre and post calibration signals, each of
            shape (n_channels, n_samples).
        """
        # Calibration noise and tone with same RMS as original speech,
        # Tone at nearest channel centre frequency to 500 Hz
        # For testing, ref_rms_dB must be equal to -31.2
        noise_burst = gen_eh2008_speech_noise(
            duration=2, sample_rate=self.sample_rate, level=ref_rms_db
        )
        tone_burst = gen_tone(
            freq=520,
            duration=0.5,
            sample_rate=self.sample_rate,
            level=ref_rms_db,
        )
        silence = np.zeros(int(0.05 * self.sample_rate))  # 50 ms duration
        pre_calibration = np.concatenate(
            (silence, tone_burst, silence, noise_burst, silence)
        )

        # Repeat signals for the desired number of channels
        post_calibration = np.tile(silence[np.newaxis, ...], (n_channels, 1))
        pre_calibration = np.tile(pre_calibration[np.newaxis, ...], (n_channels, 1))

        return (pre_calibration, post_calibration)

    def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarray]:
        """Run the hearing loss simulation.

        Steps: rescale to the target level, optionally wrap with calibration
        segments, filter source position -> cochlea, run the cochlea model
        (only if an audiogram has been set), filter cochlea -> source position,
        and finally low-pass filter at ``UPPER_CUTOFF_HZ``.

        Args:
            signal (ndarray): signal to process, shape either N, Nx1, Nx2
            add_calibration (bool): prepend calibration tone and speech-shaped
                noise (default: False)

        Returns:
            list[ndarray]: the processed signal, one 1-D array per channel.

        Raises:
            ValueError: if ``self.sample_rate`` is not 44100 Hz.
        """
        signal = signal.T  # signals as rows
        if len(signal.shape) == 1:
            signal = signal[np.newaxis, ...]

        sample_rate = 44100  # This is the only sampling frequency that can be used
        if sample_rate != self.sample_rate:
            logging.error(
                "Warning: only a sampling frequency of 44.1kHz can be used by MSBG."
            )
            raise ValueError("Invalid sampling frequency, valid value is 44100")

        # Previously a non-f-string referencing an undefined name, so the
        # placeholder was logged literally.
        logging.info(
            "Processing %d channel(s) of %d samples", signal.shape[0], signal.shape[1]
        )

        # Need to know file RMS, and then call that a certain level in SPL:
        # needs some form of pre-measuring.
        signal_rms_level_db = 10 * np.log10(np.mean(np.array(signal) ** 2))

        equiv_0db_spl = self.equiv_0db_spl + self.ahr

        level_db_spl = equiv_0db_spl + signal_rms_level_db
        calib_db_spl = level_db_spl
        target_spl = level_db_spl
        # Equals signal_rms_level_db: the calibration is generated at the
        # signal's own overall RMS level.
        ref_rms_db = calib_db_spl - equiv_0db_spl

        # Measure RMS where 3rd arg is dB_rel_rms (how far below)
        calculated_rms, idx, _rel_db_thresh, _active = measure_rms(
            signal[0], sample_rate, -12
        )

        # Rescale input data and check level after rescaling
        # This is to ensure that the following processing steps are applied correctly
        change_db = target_spl - (equiv_0db_spl + 20 * np.log10(calculated_rms))
        signal = signal * np.power(10, 0.05 * change_db)
        new_rms_db = equiv_0db_spl + 10 * np.log10(
            np.mean(np.power(signal[0][idx], 2.0))
        )
        logging.info(
            "Rescaling: "
            f"leveldBSPL was {level_db_spl:3.1f} dB SPL, now {new_rms_db:3.1f} dB SPL. "
            f" Target SPL is {target_spl:3.1f} dB SPL."
        )

        # Add calibration signal at target SPL dB
        if add_calibration:
            pre_calibration, post_calibration = self.make_calibration_signal(
                ref_rms_db, n_channels=signal.shape[0]
            )
            signal = np.concatenate((pre_calibration, signal, post_calibration), axis=1)

        # Transform from src pos to cochlea, simulate cochlea, transform back to src pos
        signal = Ear.src_to_cochlea_filt(signal, self.src_correction, sample_rate)
        if self.cochlea is not None:
            signal = np.array([self.cochlea.simulate(x, equiv_0db_spl) for x in signal])
        signal = Ear.src_to_cochlea_filt(
            signal, self.src_correction, sample_rate, backward=True
        )

        # Implement low-pass filter at top end of audio range: flat to Cutoff freq,
        # tails below -80 dB. Suitable lpf for signals later converted to MP3, flat to
        # 15 kHz. Small window to design low-pass FIR, to cut off high freq processing
        # noise low-pass to something sensible, prevents exaggeration of > 15 kHz
        winlen = 2 * math.floor(0.0015 * sample_rate) + 1
        # firwin's default normalisation makes Nyquist == 1.
        lpf44d1 = firwin(
            winlen, UPPER_CUTOFF_HZ / int(sample_rate / 2), window=("kaiser", 8)
        )
        signal_list = [lfilter(lpf44d1, 1, x) for x in signal]

        return signal_list