Source code for clarity.enhancer.gha.gha_interface
from __future__ import annotations
import logging
import os
import pathlib
import subprocess
import tempfile
from pathlib import Path
import numpy as np
from jinja2 import Environment, FileSystemLoader
from clarity.enhancer.gha.gha_utils import format_gaintable, get_gaintable
from clarity.utils.audiogram import Listener
from clarity.utils.file_io import read_signal, write_signal
[docs]
class GHAHearingAid:
    """Baseline hearing aid that is run through the OpenMHA ``mha`` command.

    The class builds a gain table for a listener, merges the front and rear
    microphone recordings into one 4-channel file, renders an OpenMHA
    configuration from a Jinja2 template and invokes ``mha`` on it.
    """

    def __init__(
        self,
        sample_rate: float = 44100,
        ahr: int = 20,
        audf: list[int] | None = None,
        cfg_file: str = "prerelease_combination4_smooth",
        noise_gate_levels: list[int] | None = None,
        noise_gate_slope: int = 0,
        cr_level: int = 0,
        max_output_level: int = 100,
        equiv_0db_spl: int = 100,
        test_nbits: int = 16,
    ) -> None:
        """Store the hearing-aid configuration.

        Args:
            sample_rate: Sampling rate of the signals; the configuration
                rendering step only accepts 44100.
            ahr: Amplification headroom in dB, added to ``equiv_0db_spl`` to
                obtain the output peak level.
            audf: Frequencies at which the listener audiograms are resampled.
                Defaults to 250 Hz .. 8 kHz (8 values).
            cfg_file: Stem of the template; the file used is
                ``cfg_files/<cfg_file>_template.cfg`` next to this module.
            noise_gate_levels: Forwarded to ``get_gaintable``.
            noise_gate_slope: Forwarded to ``get_gaintable``.
            cr_level: Forwarded to ``get_gaintable``.
            max_output_level: Forwarded to ``get_gaintable``.
            equiv_0db_spl: Used as the input peak level of the configuration.
            test_nbits: Stored on the instance; not read by this class.
        """
        # Defaults are created per instance so that no list is shared
        # between hearing-aid objects.
        if audf is None:
            audf = [250, 500, 1000, 2000, 3000, 4000, 6000, 8000]
        if noise_gate_levels is None:
            noise_gate_levels = [38, 38, 36, 37, 32, 26, 23, 22, 8]

        self.sample_rate = sample_rate
        self.ahr = ahr
        self.audf = audf
        self.cfg_file = cfg_file
        self.noise_gate_levels = noise_gate_levels
        self.noise_gate_slope = noise_gate_slope
        self.cr_level = cr_level
        self.max_output_level = max_output_level
        self.equiv_0db_spl = equiv_0db_spl
        self.test_nbits = test_nbits

    def create_configured_cfgfile(
        self, input_file, output_file, formatted_sGt, cfg_template_file
    ) -> str:
        """Render the OpenMHA configuration for one input/output pair.

        The Jinja2 template is filled in with the input and output file
        names, the input and output peak levels and the gain table.

        Args:
            input_file (str): File to process.
            output_file (str): File in which to store the processed signal.
            formatted_sGt: Gain table already formatted for insertion into
                the cfg file (the value produced by ``format_gaintable``).
            cfg_template_file (str | Path): Path of the configuration
                template.

        Returns:
            str: The rendered configuration text. Nothing is written to
            disk here; the caller decides where to store it.

        Raises:
            ValueError: If ``self.sample_rate`` is not 44100.
        """
        if self.sample_rate != 44100:
            message = "Current GHA configuration requires 44.1kHz sampling rate."
            logging.error(message)
            raise ValueError(message)

        cfg_template_file = pathlib.Path(cfg_template_file)

        # The output peak level is the input peak level plus the headroom.
        logging.info("Adding %s dB headroom", self.ahr)
        peaklevel_in = int(self.equiv_0db_spl)
        peaklevel_out = int(self.equiv_0db_spl + self.ahr)

        # The loader is rooted at the template's directory so the template
        # is looked up by its bare file name.
        env = Environment(loader=FileSystemLoader(cfg_template_file.parent))
        template = env.get_template(cfg_template_file.name)
        return template.render(
            io_in=input_file,
            io_out=output_file,
            # One entry per channel of the 4-channel input file.
            peaklevel_in=(
                f"[{peaklevel_in} {peaklevel_in} {peaklevel_in} {peaklevel_in}]"
            ),
            peaklevel_out=f"[{peaklevel_out} {peaklevel_out}]",
            gtdata=formatted_sGt,
        )

    def process_files(
        self, infile_names: list[str], outfile_name: str, listener: Listener
    ) -> None:
        """Process a set of input signals and generate an output.

        Args:
            infile_names (list[str]): List of input wav files. One stereo wav
                file for each hearing device channel.
            outfile_name (str): File in which to store the output signal.
            listener (Listener): Listener whose left and right audiograms
                are used to compute the gain table.

        Raises:
            ValueError: If the inputs are rejected by ``create_HA_inputs`` or
                ``create_configured_cfgfile``, or if a channel of the
                processed output contains no energy.
            subprocess.CalledProcessError: If ``mha`` exits with a non-zero
                status.
        """
        logging.info("Processing %s with listener %s", outfile_name, listener.id)
        logging.info(
            "Audiogram severity is %s (left) and %s (right)",
            listener.audiogram_left.severity,
            listener.audiogram_right.severity,
        )
        audiogram_left = listener.audiogram_left.resample(self.audf)
        audiogram_right = listener.audiogram_right.resample(self.audf)

        # Get gain table with noisegate correction
        gaintable = get_gaintable(
            audiogram_left,
            audiogram_right,
            self.noise_gate_levels,
            self.noise_gate_slope,
            self.cr_level,
            self.max_output_level,
        )
        formatted_sGt = format_gaintable(gaintable, noisegate_corr=True)

        cfg_template = Path(__file__).parent / f"cfg_files/{self.cfg_file}_template.cfg"

        # Merge CH1 and CH3 files. This is the baseline configuration.
        # CH2 is ignored.
        fd_merged, merged_filename = tempfile.mkstemp(
            prefix="clarity-merged-", suffix=".wav"
        )
        # Only the file name is needed; close the unused handle immediately.
        os.close(fd_merged)
        try:
            self.create_HA_inputs(infile_names, merged_filename)

            # Create the openMHA config file from the template
            fd_cfg, cfg_filename = tempfile.mkstemp(
                prefix="clarity-openmha-", suffix=".cfg"
            )
            # Again, only the file name is needed.
            os.close(fd_cfg)
            try:
                cfg_text = self.create_configured_cfgfile(
                    merged_filename, outfile_name, formatted_sGt, cfg_template
                )
                with open(cfg_filename, "w", encoding="utf-8") as cfg_handle:
                    cfg_handle.write(cfg_text)

                # Process file using configured cfg file.
                # "-q" suppresses OpenMHA output - remove it when testing.
                # "--log=logfile.txt" names the OpenMHA log file; the path is
                # relative (presumably to the current working directory).
                subprocess.run(
                    [
                        "mha",
                        "-q",
                        "--log=logfile.txt",
                        f"?read:{cfg_filename}",
                        "cmd=start",
                        "cmd=stop",
                        "cmd=quit",
                    ],
                    check=True,
                )
            finally:
                # Remove the config even when rendering or mha failed.
                os.remove(cfg_filename)
        finally:
            # Remove the merged input even when an earlier step failed.
            os.remove(merged_filename)

        # Check output signal has energy in every channel
        sig = read_signal(
            outfile_name, sample_rate=self.sample_rate, allow_resample=False
        )
        if len(np.shape(sig)) == 1:
            # Treat a 1-D signal as a single-channel column.
            sig = np.expand_dims(sig, axis=1)
        if not np.all(np.sum(np.abs(sig), axis=0)):
            raise ValueError("Channel empty.")

        # Rewriting as floating point
        write_signal(outfile_name, sig, self.sample_rate, floating_point=True)
        logging.info("OpenMHA processing complete")

    def create_HA_inputs(self, infile_names: list[str], merged_filename: str) -> None:
        """Create input signal for baseline hearing aids.

        The baseline hearing aid takes a 4-channel wav file as input. This is
        constructed from the left and right signals of the front (CH1) and
        rear (CH3) microphones that are available in the Clarity data.

        Args:
            infile_names (list[str]): Names of the files to read. Index 0
                must be the CH1 file and index 2 the CH3 file; index 1 is
                not read.
            merged_filename (str): Name of the file to write.

        Raises:
            ValueError: If fewer than three names are given or the input
                channel names are inconsistent.
        """
        if len(infile_names) < 3:
            raise ValueError(
                "HA-input signal error: expected input files for CH1, CH2 and CH3."
            )

        ch1_name, ch3_name = infile_names[0], infile_names[2]
        # The channel number is the last character before the extension.
        # Looking at the stem keeps the check independent of the length of
        # the extension and safe for very short names.
        if Path(ch1_name).stem[-1:] != "1" or Path(ch3_name).stem[-1:] != "3":
            raise ValueError("HA-input signal error: channel mismatch!")

        signal_CH1 = read_signal(
            ch1_name, sample_rate=self.sample_rate, allow_resample=False
        )
        signal_CH3 = read_signal(
            ch3_name, sample_rate=self.sample_rate, allow_resample=False
        )

        merged_signal = np.zeros((len(signal_CH1), 4))
        # channel index 0 = front microphone on the left hearing aid
        merged_signal[:, 0] = signal_CH1[:, 0]
        # channel index 1 = front microphone on the right hearing aid
        merged_signal[:, 1] = signal_CH1[:, 1]
        # channel index 2 = rear microphone on the left hearing aid
        merged_signal[:, 2] = signal_CH3[:, 0]
        # channel index 3 = rear microphone on the right hearing aid
        merged_signal[:, 3] = signal_CH3[:, 1]

        write_signal(
            merged_filename,
            merged_signal,
            self.sample_rate,
            floating_point=True,
            strict=True,
        )