# Source code for clarity.evaluator.haspi.ebm

"""HASPI EBM module"""

from __future__ import annotations

from math import floor
from typing import TYPE_CHECKING

import numpy as np
from scipy.signal import convolve, convolve2d

if TYPE_CHECKING:
    from numpy import ndarray


def env_filter(
    reference_db: ndarray,
    processed_db: ndarray,
    filter_cutoff: float,
    freq_sub_sample: float,
    freq_samp: float,
) -> tuple[ndarray, ndarray]:
    """
    Smooth and decimate the dB SL envelopes from the auditory periphery model.

    Both envelopes are lowpass filtered with the same FIR filter and then
    subsampled. The filter is a normalised von Hann (raised cosine) window, so
    all of its taps are non-negative and the filtering cannot produce negative
    envelope values.

    Args:
        reference_db (np.ndarray): envelope in dB SL of the reference signal in
            each auditory band
        processed_db (np.ndarray): envelope in dB SL of the degraded signal in
            each auditory band
        filter_cutoff (float): lowpass cutoff frequency for the filtered
            envelope, in Hz
        freq_sub_sample (float): subsampling frequency in Hz for the lowpass
            filtered envelopes
        freq_samp (float): sampling rate in Hz of the two input envelopes

    Returns:
        tuple:
            reference_env - lowpass filtered and subsampled reference envelope,
                one frequency band per column
            processed_env - lowpass filtered and subsampled degraded envelope

    Raises:
        ValueError: if ``freq_sub_sample`` exceeds ``freq_samp`` or if
            ``filter_cutoff`` exceeds half of ``freq_sub_sample``.

    Updates:
        James M. Kates, 12 September 2019.
        Translated from MATLAB to Python by Zuzanna Podwinska, March 2022.
    """
    # Reject inconsistent filter design parameters before doing any work
    if freq_sub_sample > freq_samp:
        raise ValueError("upsampling rate too high.")
    if filter_cutoff > 0.5 * freq_sub_sample:
        raise ValueError("LP cutoff frequency too high.")

    # Each frequency band must be a column: transpose matrices that are wider
    # than they are tall
    row_count = reference_db.shape[0]
    col_count = reference_db.shape[1]
    if col_count > row_count:
        reference_db, processed_db = reference_db.T, processed_db.T
    sample_count = reference_db.shape[0]

    # Filter duration in ms for -3 dB at the cutoff; the 0.7 factor is an
    # empirical adjustment of the length
    duration_ms = 0.7 * (1000 * (1 / filter_cutoff))
    half_length = floor(round(0.001 * duration_ms * freq_samp) / 2)
    tap_count = int(2 * half_length)  # always an even number of taps

    # MATLAB's hanning() omits the zero-weighted end samples, unlike the numpy
    # and scipy Hann windows, so the rising half is generated explicitly and
    # mirrored to obtain the symmetric window.
    phase = 2 * np.pi * np.arange(1, half_length + 1) / (tap_count + 1)
    rising_half = 0.5 * (1 - np.cos(phase))
    kernel = np.concatenate((rising_half, rising_half[::-1]))
    kernel = kernel / np.sum(kernel)
    column_kernel = kernel[:, np.newaxis]

    def smooth(envelope: ndarray) -> ndarray:
        # Filter every band along time, then trim the output back to the
        # original number of samples
        filtered = convolve2d(envelope, column_kernel, "full")
        return filtered[half_length : half_length + sample_count, :]

    reference_env = smooth(reference_db)
    processed_env = smooth(processed_db)

    # Keep every `stride`-th sample of the smoothed envelopes
    stride = floor(freq_samp / freq_sub_sample)
    keep = np.arange(0, sample_count, stride)
    return reference_env[keep, :], processed_env[keep, :]
def cepstral_correlation_coef(
    reference_db: ndarray,
    processed_db: ndarray,
    thresh_cep: float,
    thresh_nerve: float,
    nbasis: int,
) -> tuple[ndarray, ndarray]:
    """
    Compute zero-mean cepstral coefficients of the reference and processed
    log envelopes.

    Time samples whose reference level does not exceed ``thresh_cep`` are
    treated as silence and dropped from both signals. Noise is then added to
    the remaining samples, and at each time sample the log spectrum in dB SL is
    projected onto a set of unit-norm half-cosine basis functions. Finally the
    mean over time is removed from every coefficient sequence.

    Args:
        reference_db (np.ndarray): subsampled reference signal envelope in
            dB SL in each band
        processed_db (np.ndarray): subsampled distorted output signal envelope
        thresh_cep (float): threshold in dB SPL to include a sample in the
            calculation
        thresh_nerve (float): additive noise RMS for IHC firing (in dB)
        nbasis (int): number of cepstral basis functions to use

    Returns:
        tuple:
            reference_cep - cepstral coefficient matrix for the reference
                signal, shape (nsamp, nbasis)
            processed_cep - cepstral coefficient matrix for the output signal,
                shape (nsamp, nbasis)
            Each column is a separate basis function, from low to high.

    Raises:
        ValueError: if at most one time sample lies above the threshold.

    Updates:
        James M. Kates, 23 April 2015.
        Gammawarp version to fit the basis functions, 11 February 2019.
        Additive noise for IHC firing rates, 24 April 2019.
        Translated from MATLAB to Python by Zuzanna Podwinska, March 2022.
    """
    band_count = reference_db.shape[1]

    # Half-cosine (mel cepstrum) basis, one unit-norm column per basis function
    band_index = np.arange(0, band_count)
    basis_matrix = np.zeros((band_count, nbasis))
    for order in range(nbasis):
        half_cosine = np.cos(order * np.pi * band_index / (band_count - 1))
        basis_matrix[:, order] = half_cosine / np.sqrt(np.sum(half_cosine**2))

    # Band-averaged level of the reference, computed in the linear domain and
    # converted back to dB
    mean_linear = np.sum(10 ** (reference_db / 20), 1) / band_count
    level_db = 20 * np.log10(mean_linear)

    # Keep only the time samples sufficiently above the quiescent rate
    active = np.where(level_db > thresh_cep)[0]
    if len(active) <= 1:
        raise ValueError("Signal below threshold")

    # Drop the silent samples, then add low-level noise for IHC firing jitter
    reference_db = add_noise(reference_db[active, :], thresh_nerve)
    processed_db = add_noise(processed_db[active, :], thresh_nerve)

    # Project the retained samples onto the cepstral basis
    reference_cep = reference_db @ basis_matrix
    processed_cep = processed_db @ basis_matrix

    # Subtract the time average of every coefficient so that a later
    # cross-correlation is a cross-covariance, independent of the absolute
    # signal level in dB.
    for coefficients in (reference_cep, processed_cep):
        for order in range(nbasis):
            column = coefficients[:, order]
            coefficients[:, order] = column - np.mean(column)

    return reference_cep, processed_cep
def add_noise(reference_db: ndarray, thresh_db: float) -> ndarray:
    """
    Add independent random Gaussian noise to the subsampled signal envelope
    in each auditory frequency band.

    Args:
        reference_db (np.ndarray): subsampled envelope in dB re: auditory
            threshold
        thresh_db (float): additive noise RMS level (in dB)

    Returns:
        np.ndarray: envelope with threshold noise added, in dB re: auditory
            threshold

    Updates:
        James M. Kates, 23 April 2019.
        Translated from MATLAB to Python by Zuzanna Podwinska, March 2022.
    """
    # One standard-normal draw per envelope sample, taken from numpy's global
    # random state
    unit_noise = np.random.standard_normal(reference_db.shape)

    # Scale the draws by the requested level and add them to the envelope
    return reference_db + thresh_db * unit_noise
def _modulation_band_filter(
    envelope_band: ndarray,
    cosine,
    sine,
    coefficient: ndarray,
    transient_duration: int,
    nsamp: int,
) -> ndarray:
    """Demodulate one envelope band, lowpass filter it and re-modulate it."""
    # Complex demodulation, then LP filter
    demodulated = convolve(
        (envelope_band * cosine - 1j * envelope_band * sine),
        coefficient,
    )
    # Truncate the filter transients so the output is aligned with the input
    demodulated = demodulated[transient_duration : transient_duration + nsamp]
    # Modulate back up to the carrier frequency
    return np.real(demodulated) * cosine - np.imag(demodulated) * sine


def fir_modulation_filter(
    reference_envelope: ndarray,
    processed_envelope: ndarray,
    freq_sub_sampling: float,
    center_frequencies: ndarray | None = None,
) -> tuple[ndarray, ndarray, ndarray]:
    """
    Apply a FIR modulation filterbank to the reference and processed envelopes.

    Each column of ``reference_envelope`` and ``processed_envelope`` is a
    separate filter band or cepstral coefficient basis function. The lowest
    modulation band is a plain lowpass filter; the remaining bands use complex
    demodulation followed by a lowpass filter. The onset and offset transients
    of the FIR convolutions are removed to temporally align the modulation
    filter outputs.

    Only the modulation bands whose edges lie below the envelope Nyquist rate
    (``0.5 * freq_sub_sampling``) are used, so fewer bands than centre
    frequencies may be returned.

    Args:
        reference_envelope (np.ndarray): matrix containing the subsampled
            reference envelope values. Each column is a different frequency
            band or cepstral basis function arranged from low to high.
        processed_envelope (np.ndarray): matrix containing the subsampled
            processed envelope values
        freq_sub_sampling (float): envelope sub-sampling rate in Hz
        center_frequencies (np.ndarray | None): modulation filter centre
            frequencies. Defaults to the ten bands
            [2, 6, 10, 16, 25, 40, 64, 100, 160, 256].

    Returns:
        tuple:
            reference_modulation - array of shape (nchan, nmod, nsamp) holding
                the reference signal output of the modulation filterbank,
                where nchan is the number of columns of ``reference_envelope``,
                nmod the number of modulation filters used and nsamp the
                number of samples in each envelope sequence.
            processed_modulation - array of the same shape holding the
                processed signal output of the modulation filterbank.
            center_frequencies - vector of the modulation filter centre
                frequencies that were used.

    Raises:
        ValueError: if ``freq_sub_sampling`` is so low that no modulation band
            lies below the Nyquist rate.

    Updates:
        James M. Kates, 14 February 2019.
        Two matrix version of gwarp_ModFiltWindow, 19 February 2019.
        Translated from MATLAB to Python by Zuzanna Podwinska, March 2022.
    """
    # Input signal properties
    nsamp = reference_envelope.shape[0]
    nchan = reference_envelope.shape[1]

    # Modulation filter band cf and edges, 10 bands by default.
    # Band spacing uses the factor of 1.6 used by Dau.
    if center_frequencies is None:
        center_frequencies = np.array([2, 6, 10, 16, 25, 40, 64, 100, 160, 256])
    else:
        # Accept any sequence; the arithmetic below needs array semantics
        center_frequencies = np.asarray(center_frequencies)
    nmod = len(center_frequencies)

    # The first band edges are fixed; the rest follow from the centre
    # frequencies. Only as many fixed edges as fit are written, so short
    # centre-frequency vectors are handled as well.
    edge = np.zeros(nmod + 1)
    fixed_edges = [0, 4, 8]
    fixed_count = min(len(fixed_edges), nmod + 1)
    edge[:fixed_count] = fixed_edges[:fixed_count]
    for k in range(3, nmod + 1):
        edge[k] = (center_frequencies[k - 1] ** 2) / edge[k - 1]

    # Allowable filters based on envelope subsampling rate: keep only the band
    # edges below the Nyquist rate
    fn_yq = 0.5 * freq_sub_sampling
    edge = edge[edge < fn_yq]
    nmod = len(edge) - 1
    if nmod < 1:
        raise ValueError(
            "freq_sub_sampling too low: no modulation filter band lies below "
            "the Nyquist rate."
        )
    center_frequencies = center_frequencies[:nmod]

    # Assign FIR filter lengths. Setting t0=0.2 gives a filter Q of about 1.25
    # to match Ewert et al. (2002), and t0=0.33 corresponds to Q=2 (Dau et al
    # 1997a). Moritz et al. (2015) used t0=0.29. General relation Q=6.25*t0,
    # compromise with t0=0.24 which gives Q=1.5
    t_0 = 0.24  # Filter length in seconds for the lowest modulation bands
    t = np.zeros(nmod)  # pylint: disable=invalid-name
    t[:2] = t_0  # slice assignment is safe even when only one band remains
    if nmod > 2:
        # Constant-Q filters above 10 Hz
        t[2:] = t_0 * center_frequencies[2] / center_frequencies[2:]

    # Half filter lengths in samples; doubling forces even filter lengths
    nhalf = np.floor(t * freq_sub_sampling / 2).astype(int)
    nfir = 2 * nhalf

    # Time index shared by all carriers
    n = np.arange(1, nsamp + 1)

    # Convolve the input and output envelopes with the modulation filters
    reference_modulation = np.zeros((nchan, nmod, nsamp))
    processed_modulation = np.zeros((nchan, nmod, nsamp))
    for k in range(nmod):  # Loop over the modulation filters
        # Lowpass window for this band, normalised to sum to one
        coefficient = np.hanning(int(nfir[k]) + 1)
        coefficient /= np.sum(coefficient)
        transient_duration = int(nhalf[k])  # Transient duration for the filter

        # Cosine and sine for complex modulation; the lowest band is a plain
        # lowpass filter, i.e. a constant carrier
        if k == 0:
            _cosine = 1
            _sine = 0
        else:
            carrier_phase = np.pi * n * center_frequencies[k] / fn_yq
            _cosine = np.sqrt(2) * np.cos(carrier_phase)
            _sine = np.sqrt(2) * np.sin(carrier_phase)

        for m in range(nchan):  # Loop over the input signal vectors
            reference_modulation[m, k, :] = _modulation_band_filter(
                reference_envelope[:, m],
                _cosine,
                _sine,
                coefficient,
                transient_duration,
                nsamp,
            )
            processed_modulation[m, k, :] = _modulation_band_filter(
                processed_envelope[:, m],
                _cosine,
                _sine,
                coefficient,
                transient_duration,
                nsamp,
            )

    return reference_modulation, processed_modulation, center_frequencies
def modulation_cross_correlation(
    reference_modulation: ndarray, processed_modulation: ndarray
) -> ndarray:
    """
    Compute the normalised cross-correlations between the reference and
    processed modulation filterbank outputs.

    For every channel (frequency band or cepstral basis function) and every
    modulation filter, the mean is removed from both sequences and the
    absolute value of their normalised cross-covariance is computed. The
    result is averaged over channels 1 to 5 (zero-based), i.e. basis
    functions 2-6. The input arrays are not modified.

    Args:
        reference_modulation (np.ndarray): reference signal output of the
            modulation filterbank, indexed as [channel, modulation filter],
            each entry holding a sequence of nsamp samples.
        processed_modulation (np.ndarray): processed signal output of the
            modulation filterbank, with the same layout.

    Returns:
        np.ndarray: modulation correlations averaged over basis functions 2-6,
            a vector with one entry per modulation filter.

    Updates:
        James M. Kates, 21 February 2019.
        Translated from MATLAB to Python by Zuzanna Podwinska, March 2022.
    """
    # Processing parameters
    nchan = reference_modulation.shape[0]  # Number of basis functions
    nmod = reference_modulation.shape[1]  # Number of modulation filters
    small = 1e-30  # Zero threshold

    # Compute the cross-covariance matrix
    covariance_matrix = np.zeros((nchan, nmod))
    for m in range(nmod):
        for j in range(nchan):
            # Centre into new arrays: an in-place subtraction on the indexed
            # view would overwrite the caller's data and would fail for
            # integer inputs.
            x_raw = reference_modulation[j, m]  # Reference band j, mod freq m
            x_j = x_raw - np.mean(x_raw)
            xsum = np.sum(x_j**2)

            y_raw = processed_modulation[j, m]  # Processed band j, mod freq m
            y_j = y_raw - np.mean(y_raw)
            ysum = np.sum(y_j**2)

            # Cross-correlate the reference and processed signals; sequences
            # with (near) zero energy are assigned a correlation of zero
            if (xsum < small) or (ysum < small):
                covariance_matrix[j, m] = 0
            else:
                covariance_matrix[j, m] = np.abs(np.sum(x_j * y_j)) / np.sqrt(
                    xsum * ysum
                )

    # Average over basis functions 2-6 (rows 1..5)
    return np.mean(covariance_matrix[1:6], 0)