"""
An FIR-based torch implementation of the approximated MSBG hearing loss model
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Final
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from scipy.fftpack import fft
from scipy.interpolate import interp1d
from scipy.signal import ellip, firwin, firwin2, freqz
from torch import nn
from clarity.evaluator.msbg.msbg_utils import (
DF_ED,
FF_ED,
HZ,
ITU_ERP_DRP,
ITU_HZ,
MIDEAR,
)
from clarity.evaluator.msbg.smearing import make_smear_mat3
EPS = 1e-8
# Old MSBG MATLAB code: set the RMS so that the peak of the output file causes
# no clipping, and so that equiv0dBfileSPL > 100 dB for LOUD input files
REF_RMS_DB: Final = -31.2
# what the RMS of an INPUT speech file translates to in the real world (unweighted)
CALIB_DB_SPL: Final = 65
# what a 0 dB file signal would translate to in dB SPL:
# constant for the cochlea_simulate function
EQUIV_0_DB_FILE_SPL: Final = CALIB_DB_SPL - REF_RMS_DB
# clarity msbg
AHR: Final = 20
EQUIV_0_DB_SPL: Final = 100 + AHR
class MSBGHearingModel(nn.Module):
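    """FIR-based torch approximation of the MSBG hearing loss model.

    Simulates hearing loss for a given audiogram via source-to-cochlea
    filtering, frequency smearing, and loudness recruitment.
    """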
def __init__(
self,
audiogram: np.ndarray,
audiometric: np.ndarray,
sr: int = 44100,
spl_cali: bool = True,
src_position: str = "ff",
kernel_size: int = 1025,
device: str | None = None,
) -> None:
super().__init__()
self.sr = sr
self.spl_cali = spl_cali
self.src_position = src_position
self.kernel_size = kernel_size
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
# settings for audiogram
audiogram = np.append(audiogram, audiogram[-1])
audiometric = np.append(audiometric, 16000)
audiogram = np.append(audiogram[0], audiogram)
audiometric = np.append(125, audiometric)
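        # Interpolate the extended audiogram onto a fixed grid of centre
        # frequencies from 125 Hz to 16 kHz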
audiogram_cfs = (
np.array([0.125, 0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 6, 8, 10, 12, 14, 16])
* 1000
)
interp_f = interp1d(audiometric, audiogram)
audiogram = interp_f(audiogram_cfs)
# settings for src_to_cochlea_filt
if src_position == "ff":
src_corrn = FF_ED
elif src_position == "df":
src_corrn = DF_ED
elif src_position == "ITU":
interf_itu = interp1d(ITU_HZ, ITU_ERP_DRP)
src_corrn = interf_itu(HZ)
else:
raise ValueError(f"Unknown src_position: {src_position}")
nyquist = sr / 2
ixf_useful = np.where(HZ < nyquist)[0]
hz_used = np.append(HZ[ixf_useful], nyquist)
corrn = src_corrn - MIDEAR
interf_corrn = interp1d(HZ, corrn)
last_corrn = interf_corrn(nyquist)
corrn_used = np.append(corrn[ixf_useful], last_corrn)
corrn_forward = 10 ** (0.05 * corrn_used)
corrn_backward = 10 ** (0.05 * -1 * corrn_used)
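        # Design forward and backward ear-correction FIRs with firwin2, using the
        # linear-amplitude versions of the dB corrections as target responses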
n_wdw = int(2 * np.floor((sr / 16e3) * 368 / 2))
cochlea_filter_forward = firwin2(
n_wdw + 1, hz_used / nyquist, corrn_forward, window=("kaiser", 4)
)
cochlea_filter_backward = firwin2(
n_wdw + 1, hz_used / nyquist, corrn_backward, window=("kaiser", 4)
)
self.cochlea_padding = len(cochlea_filter_forward) // 2
self.cochlea_filter_forward = (
torch.tensor(
cochlea_filter_forward, dtype=torch.float32, device=self.device
)
.unsqueeze(0)
.unsqueeze(1)
)
self.cochlea_filter_backward = (
torch.tensor(
cochlea_filter_backward, dtype=torch.float32, device=self.device
)
.unsqueeze(0)
.unsqueeze(1)
)
# Settings for smearing
catch_up = 105.0 # dBHL where impaired catches up with normal
        # The recruitment simulation comes with three degrees of auditory filter
        # broadening, each using a different set of centre frequencies. The
        # audiogram is categorised below: the impairment degree selects both the
        # smearing parameters and the degree of filter broadening.
impaired_freqs = np.where((audiogram_cfs >= 2000) & (audiogram_cfs <= 8000))[0]
impaired_degree = np.mean(audiogram[impaired_freqs])
        # The impairment degree affects the smearing simulation and now also the
        # recruitment (assuming there are not too many severe losses present)
current_dir = Path(__file__).parent
gtf_dir = current_dir / "../evaluator/msbg/msbg_hparams"
if impaired_degree > 56:
f_smear = make_smear_mat3(4, 2, sr)
gt4_bank_file = gtf_dir / "GT4FBank_Brd3.0E_Spaced2.3E_44100Fs.json"
bw_broaden_coef = 3
elif impaired_degree > 35:
gt4_bank_file = gtf_dir / "GT4FBank_Brd2.0E_Spaced1.5E_44100Fs.json"
bw_broaden_coef = 2
f_smear = make_smear_mat3(2.4, 1.6, sr)
elif impaired_degree > 15:
gt4_bank_file = gtf_dir / "GT4FBank_Brd1.5E_Spaced1.1E_44100Fs.json"
bw_broaden_coef = 1
f_smear = make_smear_mat3(1.6, 1.1, sr)
else:
gt4_bank_file = gtf_dir / "GT4FBank_Brd1.5E_Spaced1.1E_44100Fs.json"
bw_broaden_coef = 1
f_smear = make_smear_mat3(1.001, 1.001, sr)
with gt4_bank_file.open("r", encoding="utf-8") as fp:
gt4_bank = json.load(fp)
self.smear_nfft = 512
self.smear_win_len = 256
self.smear_hop_len = 64
smear_window = (
0.5
- 0.5
* np.cos(
2
* np.pi
* (np.arange(1, self.smear_win_len + 1) - 0.5)
/ self.smear_win_len
)
) / np.sqrt(1.5)
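        # Raised-cosine (Hann) analysis/synthesis window for the 75%-overlap STFT;
        # the 1 / sqrt(1.5) factor makes the overlapped squared windows sum to one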
self.smear_window = torch.tensor(
smear_window, dtype=torch.float32, device=self.device
)
self.f_smear = torch.tensor(f_smear, dtype=torch.float32, device=self.device)
""" settings for recruitment"""
cf_expansion = 0 * np.array(gt4_bank["GTn_CentFrq"])
eq_loud_db = 0 * np.array(gt4_bank["GTn_CentFrq"])
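        # Per-channel expansion exponent for recruitment:
        # catch_up / (catch_up - hearing loss at the channel centre frequency)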
for ix_cfreq in range(len(gt4_bank["GTn_CentFrq"])):
if gt4_bank["GTn_CentFrq"][ix_cfreq] < audiogram_cfs[0]:
cf_expansion[ix_cfreq] = catch_up / (catch_up - audiogram[0])
elif gt4_bank["GTn_CentFrq"][ix_cfreq] > audiogram_cfs[-1]:
cf_expansion[ix_cfreq] = catch_up / (catch_up - audiogram[-1])
else:
interp_audiogram = interp1d(audiogram_cfs, audiogram)
audiog_cf = interp_audiogram(gt4_bank["GTn_CentFrq"][ix_cfreq])
cf_expansion[ix_cfreq] = catch_up / (catch_up - audiog_cf)
eq_loud_db[ix_cfreq] = catch_up
self.n_chans = gt4_bank["NChans"]
self.gtn_denoms = torch.tensor(
gt4_bank["GTn_denoms"], dtype=torch.float32, device=self.device
)
self.gtn_nums = torch.tensor(
gt4_bank["GTn_nums"], dtype=torch.float32, device=self.device
)
self.hp_denoms = torch.tensor(
gt4_bank["HP_denoms"], dtype=torch.float32, device=self.device
)
self.hp_nums = torch.tensor(
gt4_bank["HP_nums"], dtype=torch.float32, device=self.device
)
self.ngamma = int(gt4_bank["NGAMMA"])
self.gtn_delays = gt4_bank["GTnDelays"]
self.start_2_pole_hp = gt4_bank["Start2PoleHP"]
erbn_centre_freq = gt4_bank["ERBn_CentFrq"]
chan_lpf_b = []
chan_lpf_a = []
fir_lpf = []
for ixch in range(self.n_chans):
fc_envelope = (30 / 40) * np.min([100, erbn_centre_freq[ixch]])
chan_lpf_b_ch, chan_lpf_a_ch = ellip(
2, 0.25, 35, fc_envelope / (self.sr / 2)
)
chan_lpf_b.append(chan_lpf_b_ch)
chan_lpf_a.append(chan_lpf_a_ch)
            fir_lpf_ch = firwin(
                self.kernel_size, fc_envelope / (self.sr / 2), pass_zero="lowpass"
            ) / np.sqrt(2)  # sqrt(2) for consistency with the IIR low-pass
fir_lpf.append(fir_lpf_ch)
self.chan_lpf_b = torch.tensor(
np.array(chan_lpf_b), dtype=torch.float32, device=self.device
)
self.chan_lpf_a = torch.tensor(
np.array(chan_lpf_a), dtype=torch.float32, device=self.device
)
self.fir_lpf = torch.tensor(
np.array(fir_lpf), dtype=torch.float32, device=self.device
).unsqueeze(1)
self.expansion_m1 = torch.tensor(
cf_expansion - 1, dtype=torch.float32, device=self.device
)
self.envelope_max = torch.tensor(
10 ** (0.05 * (eq_loud_db - EQUIV_0_DB_SPL)),
dtype=torch.float32,
device=self.device,
)
recombination_db = gt4_bank["Recombination_dB"]
self.recruitment_out_coef = torch.tensor(
10 ** (-0.05 * recombination_db), dtype=torch.float32, device=self.device
)
"settings for FIR Gammatone Filters"
gt_cfreq = np.array(gt4_bank["GTn_CentFrq"])
gt_bw = np.array(gt4_bank["ERBn_CentFrq"]) * 1.1019 * bw_broaden_coef
self.padding = (self.kernel_size - 1) // 2
n_lin = torch.linspace(
0, self.kernel_size - 1, self.kernel_size, device=self.device
)
window_ = 0.54 - 0.46 * torch.cos(2 * np.pi * n_lin / self.kernel_size)
n_ = (
torch.arange(
0, self.kernel_size, dtype=torch.float32, device=self.device
).view(1, -1)
/ self.sr
)
center_hz = (
torch.tensor(
gt_cfreq / self.sr, dtype=torch.float32, device=self.device
).view(-1, 1)
* self.sr
)
f_times_t = torch.matmul(center_hz, n_)
carrier = torch.cos(2 * np.pi * f_times_t)
carrier_sin = torch.sin(2 * np.pi * f_times_t)
band_hz = (
torch.tensor(gt_bw / self.sr, dtype=torch.float32, device=self.device).view(
-1, 1
)
* self.sr
)
b_times_t = torch.matmul(band_hz, n_)
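        # 4th-order gammatone impulse responses: a t**(n-1) * exp(-2*pi*b*t) envelope
        # with cosine and sine carriers, giving a quadrature pair for envelope extraction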
kernel = torch.pow(n_, 4 - 1) * torch.exp(-2 * np.pi * b_times_t)
gammatone = kernel * carrier
self.peaks = torch.argmax(gammatone, dim=1) # for gammatone delay calibration
gammatone_sin = kernel * carrier_sin
filters = (gammatone * window_).view(self.n_chans, 1, self.kernel_size)
# To get the normalised amplitude
filters = filters.squeeze(1).cpu().numpy()
fr_max = np.zeros(self.n_chans)
for i in range(self.n_chans):
fr = np.abs(fft(filters[i]))
fr_ = fr[: int(self.kernel_size / 2)]
fr_max[i] = np.max(fr_)
amp = torch.tensor(fr_max, dtype=torch.float32, device=self.device)
gammatone = gammatone / amp.unsqueeze(1)
gammatone_sin = gammatone_sin / amp.unsqueeze(1)
self.gt_fir = (gammatone * window_).view(self.n_chans, 1, self.kernel_size)
self.gt_fir_sin = (gammatone_sin * window_).view(
self.n_chans, 1, self.kernel_size
)
"settings for spl calibration"
win_sec = 0.01
self.db_relative_rms = -12
self.win_len = int(self.sr * win_sec)
def measure_rms(self, wav: torch.Tensor) -> torch.Tensor:
"""Compute RMS level of a signal.
Measures total power of all 10 msec frames that are above a specified
threshold of db_relative_rms
Args:
wav: input signal
Returns:
RMS level in dB
"""
bs = wav.shape[0]
average_rms = torch.sqrt(torch.mean(wav**2, dim=1) + EPS)
threshold_db = 20 * torch.log10(average_rms + EPS) + self.db_relative_rms
num_frames = wav.shape[1] // self.win_len
wav_reshaped = torch.reshape(
wav[:, : num_frames * self.win_len], [bs, num_frames, self.win_len]
)
db_frames = 10 * torch.log10(torch.mean(wav_reshaped**2, dim=2) + EPS)
key_frames = (
torch.where(
db_frames > threshold_db.unsqueeze(1),
torch.tensor(1, dtype=torch.float32, device=self.device),
torch.tensor(0, dtype=torch.float32, device=self.device),
)
.unsqueeze(-1)
.repeat([1, 1, self.win_len])
.reshape([bs, num_frames * self.win_len])
)
key_rms = torch.sqrt(
torch.sum((wav[:, : num_frames * self.win_len] * key_frames) ** 2, dim=1)
/ (torch.sum(key_frames, dim=1) + EPS)
+ EPS
)
return key_rms.unsqueeze(1)
def calibrate_spl(self, x: torch.Tensor) -> torch.Tensor:
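        """Rescale the signal so that the RMS of the above-threshold (speech
        active) frames is brought to the overall RMS level, both referenced to
        EQUIV_0_DB_SPL. Returns the input unchanged if ``spl_cali`` is False.
        """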
if self.spl_cali:
level_re_sample_rate = 10 * torch.log10(
torch.mean(x**2, dim=1, keepdim=True) + EPS
)
level_db_spl = EQUIV_0_DB_SPL + level_re_sample_rate
rms = self.measure_rms(x)
change_db = level_db_spl - (EQUIV_0_DB_SPL + 20 * torch.log10(rms + EPS))
x = x * 10 ** (0.05 * change_db)
return x
def src_to_cochlea_filt(
self, x: torch.Tensor, cochlea_filter: torch.Tensor
) -> torch.Tensor:
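        """Apply an outer/middle ear correction FIR (forward or backward)."""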
return F.conv1d(x, cochlea_filter, padding=self.cochlea_padding)
def smear(self, x: torch.Tensor) -> torch.Tensor:
"""Padding issue needs to be worked out"""
length = x.shape[2]
x = x.view(x.shape[0], x.shape[2])
spec = torch.stft(
x,
n_fft=self.smear_nfft,
hop_length=self.smear_hop_len,
win_length=self.smear_win_len,
window=self.smear_window,
return_complex=True,
)
mag = torch.abs(spec[:, : self.smear_nfft // 2, :])
power = torch.square(mag)
phasor = spec[:, : self.smear_nfft // 2, :] / (mag + EPS)
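        # Smear the short-time power spectrum with the smearing matrix, then
        # restore the original phase via the unit-magnitude phasor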
smeared_power = (
torch.matmul(
power.transpose(-1, -2), self.f_smear.transpose(0, 1)
).transpose(-1, -2)
+ EPS
)
smeared_power = torch.clamp(smeared_power, min=0)
smeared_spec_nyquist = torch.sqrt(smeared_power + EPS) * phasor
smeared_spec_mid = torch.zeros(
[smeared_power.shape[0], 1, smeared_power.shape[2]],
dtype=torch.float32,
device=self.device,
)
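        # Append a zero Nyquist bin so the one-sided spectrum has n_fft // 2 + 1
        # bins, as required by istft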
smeared_spec = torch.cat([smeared_spec_nyquist, smeared_spec_mid], dim=1)
smeared_wav = torch.istft(
smeared_spec,
n_fft=self.smear_nfft,
hop_length=self.smear_hop_len,
win_length=self.smear_win_len,
window=self.smear_window,
length=length,
)
return smeared_wav.unsqueeze(1)
def recruitment(self, x: torch.Tensor) -> torch.Tensor:
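        """Simulate loudness recruitment with the IIR gammatone filterbank.

        Each channel is filtered by a cascaded gammatone filter, delay
        compensated, high-pass filtered for tail control in the upper channels,
        and its smoothed envelope drives an expansion gain before the channels
        are recombined.
        """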
n_samples = x.shape[-1]
ixhp = 0
outputs = []
for ixch in range(self.n_chans):
            # Gammatone filtering
pass_n = torchaudio.functional.lfilter(
x, self.gtn_denoms[ixch, :], self.gtn_nums[ixch, :]
)
for _ixg in range(self.ngamma - 1):
pass_n = torchaudio.functional.lfilter(
pass_n, self.gtn_denoms[ixch, :], self.gtn_nums[ixch, :]
)
dly = self.gtn_delays[ixch]
pass_n_cali = torch.zeros_like(pass_n)
pass_n_cali[:, :, : n_samples - dly] = pass_n[:, :, dly:n_samples]
# Tail control
if ixch >= self.start_2_pole_hp:
ixhp += 1
pass_n_cali = torchaudio.functional.lfilter(
pass_n_cali, self.hp_denoms[ixhp - 1, :], self.hp_nums[ixhp - 1, :]
)
# Get the envelope
envelope_out = torchaudio.functional.lfilter(
torch.abs(pass_n_cali),
self.chan_lpf_a[ixch, :],
self.chan_lpf_b[ixch, :],
)
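            # A second, time-reversed pass through the same low-pass gives
            # zero-phase (filtfilt-style) envelope smoothing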
envelope_out = torch.flip(envelope_out, dims=[-1])
envelope_out = torchaudio.functional.lfilter(
envelope_out, self.chan_lpf_a[ixch, :], self.chan_lpf_b[ixch, :]
)
envelope_out = torch.flip(envelope_out, dims=[-1])
envelope_out = torch.clamp(
envelope_out, min=EPS, max=float(self.envelope_max[ixch])
)
gain = (envelope_out / self.envelope_max[ixch]) ** self.expansion_m1[ixch]
outputs.append(gain * pass_n_cali)
y = torch.stack(outputs, dim=-1).sum(dim=-1)
y = y * self.recruitment_out_coef
return y
def recruitment_fir(self, x: torch.Tensor) -> torch.Tensor:
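        """FIR approximation of the loudness recruitment simulation.

        The signal is filtered with cosine/sine gammatone FIR pairs, each
        channel envelope is taken as the magnitude of the quadrature pair and
        low-pass smoothed, and an envelope-dependent expansion gain is applied
        before the channels are recombined.
        """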
n_samples = x.shape[-1]
x = x.repeat([1, self.n_chans, 1])
real = F.conv1d(
x, self.gt_fir, bias=None, padding=self.padding, groups=self.n_chans
)
imag = F.conv1d(
x, self.gt_fir_sin, bias=None, padding=self.padding, groups=self.n_chans
)
real_cali = torch.zeros_like(real)
imag_cali = torch.zeros_like(imag)
for i in range(self.n_chans):
real_cali[:, i, : n_samples - self.peaks[i]] = real[
:, i, self.peaks[i] : n_samples
]
imag_cali[:, i, : n_samples - self.peaks[i]] = imag[
:, i, self.peaks[i] : n_samples
]
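        # Channel envelope: magnitude of the quadrature (cos/sin) pair, then
        # low-pass smoothed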
env = torch.sqrt(real_cali * real_cali + imag_cali * imag_cali + EPS)
env = F.conv1d(
env, self.fir_lpf, bias=None, padding=self.padding, groups=self.n_chans
)
env_max = self.envelope_max.unsqueeze(0).unsqueeze(-1).repeat([1, 1, n_samples])
gain = torch.clamp(env / env_max, min=EPS, max=1)
gain = gain ** self.expansion_m1.unsqueeze(0).unsqueeze(-1).repeat(
[1, 1, n_samples]
)
y = torch.sum(gain * real_cali, dim=1, keepdim=True)
y = y * self.recruitment_out_coef
return y
def forward(self, x: torch.Tensor) -> torch.Tensor:
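        """Apply the hearing loss simulation.

        The chain is: SPL calibration, source-to-cochlea filtering, frequency
        smearing, loudness recruitment (FIR version), and the backward
        cochlea-to-source filtering.
        """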
x = self.calibrate_spl(x)
x = x.unsqueeze(1)
x = self.src_to_cochlea_filt(x, self.cochlea_filter_forward)
x = self.smear(x)
# x = self.recruitment(x)
x = self.recruitment_fir(x)
y = self.src_to_cochlea_filt(x, self.cochlea_filter_backward)
return y.squeeze(1)
class torchloudnorm(nn.Module):
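    """Loudness normalisation in the style of ITU-R BS.1770 / pyloudnorm.

    The K-weighting high-shelf and high-pass stages are approximated with FIR
    filters so the chain runs as torch convolutions; gated loudness is measured
    over overlapping blocks and the signal is scaled to ``norm_lufs``.
    """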
def __init__(
self,
sample_rate: int = 44100,
norm_lufs: int = -36,
kernel_size: int = 1025,
block_size: float = 0.4,
overlap: float = 0.75,
gamma_a: int = -70,
device: str | None = None,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.norm_lufs = norm_lufs
self.kernel_size = kernel_size
self.padding = kernel_size // 2
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
        # Frequency weighting filters - account for the acoustic response
        # of the head and auditory system
pyln_high_shelf_b = np.array([1.53090959, -2.65116903, 1.16916686])
pyln_high_shelf_a = np.array([1.0, -1.66375011, 0.71265753])
# fir high_shelf
w_high_shelf, h_high_shelf = freqz(
pyln_high_shelf_b, pyln_high_shelf_a, fs=sample_rate
)
freq_high_shelf = np.append(w_high_shelf, sample_rate / 2)
gain_high_shelf = np.append(np.abs(h_high_shelf), np.abs(h_high_shelf)[-1])
fir_high_shelf = firwin2(
kernel_size, freq_high_shelf, gain_high_shelf, fs=sample_rate
)
# fir high_pass
fc_high_pass = 38.0
fir_high_pass = firwin(
kernel_size, fc_high_pass, pass_zero="highpass", fs=sample_rate
)
self.high_shelf = (
torch.tensor(fir_high_shelf, dtype=torch.float32, device=self.device)
.unsqueeze(0)
.unsqueeze(1)
)
self.high_pass = (
torch.tensor(fir_high_pass, dtype=torch.float32, device=self.device)
.unsqueeze(0)
.unsqueeze(1)
)
"rms measurement"
self.frame_size = int(block_size * sample_rate)
self.frame_shift = int(block_size * sample_rate * (1 - overlap))
self.unfold = torch.nn.Unfold(
(1, self.frame_size), stride=(1, self.frame_shift)
)
self.gamma_a = gamma_a
def apply_filter(self, x: torch.Tensor) -> torch.Tensor:
x = F.conv1d(x, self.high_shelf, padding=self.padding)
x = F.conv1d(x, self.high_pass, padding=self.padding)
return x
def integrated_loudness(self, x: torch.Tensor) -> torch.Tensor:
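        """Measure gated integrated loudness (LUFS).

        Follows the ITU-R BS.1770 scheme: block energies are gated first with
        the absolute threshold ``gamma_a`` and then with a relative threshold
        10 LU below the absolutely gated loudness.
        """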
x = self.apply_filter(x)
x_unfold = self.unfold(x.unsqueeze(2))
z = (
torch.sum(x_unfold**2, dim=1) / self.frame_size
) # mean square for each frame
el = -0.691 + 10 * torch.log10(z + EPS)
idx_a = torch.where(el > self.gamma_a, 1, 0)
z_ave_gated_a = torch.sum(z * idx_a, dim=1, keepdim=True) / (
torch.sum(idx_a, dim=1, keepdim=True) + 1e-8
)
gamma_r = -0.691 + 10 * torch.log10(z_ave_gated_a + EPS) - 10
idx_r = torch.where(el > gamma_r, 1, 0)
idx_a_r = idx_a * idx_r
z_ave_gated_a_r = torch.sum(z * idx_a_r, dim=1, keepdim=True) / (
torch.sum(idx_a_r, dim=1, keepdim=True) + 1e-8
)
lufs = -0.691 + 10 * torch.log10(z_ave_gated_a_r + EPS) # loudness
return lufs
def normalize_loudness(self, x: torch.Tensor, lufs: torch.Tensor) -> torch.Tensor:
delta_loudness = self.norm_lufs - lufs
gain = torch.pow(10, delta_loudness / 20)
return gain * x
def forward(self, x: torch.Tensor) -> torch.Tensor:
loudness = self.integrated_loudness(x.unsqueeze(1))
y = self.normalize_loudness(x, loudness)
return y
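

if __name__ == "__main__":
    # Illustrative usage sketch only (not part of the model): the audiometric
    # frequencies and hearing levels below are made-up example values, and the
    # sine tone merely stands in for a real speech signal.
    example_sr = 44100
    example_audiometric = np.array([250, 500, 1000, 2000, 3000, 4000, 6000, 8000])
    example_audiogram = np.array([20.0, 20.0, 25.0, 35.0, 40.0, 45.0, 50.0, 55.0])

    hearing_model = MSBGHearingModel(
        audiogram=example_audiogram, audiometric=example_audiometric, sr=example_sr
    )
    loudness_norm = torchloudnorm(sample_rate=example_sr)

    # One second of a 500 Hz tone, shaped [batch, n_samples]
    t = torch.arange(example_sr, dtype=torch.float32) / example_sr
    clean = (0.1 * torch.sin(2 * np.pi * 500.0 * t)).unsqueeze(0)
    clean = clean.to(hearing_model.device)

    with torch.no_grad():
        impaired = hearing_model(clean)       # simulate the hearing loss
        normalised = loudness_norm(impaired)  # loudness-normalise the output
    print(impaired.shape, normalised.shape)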