"""Run the dummy enhancement."""
from __future__ import annotations
import json
import logging
from pathlib import Path
# pylint: disable=import-error
import hydra
import numpy as np
import torch
from numpy import ndarray
from omegaconf import DictConfig
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from clarity.utils.audiogram import Listener
from clarity.utils.file_io import read_signal
from clarity.utils.flac_encoder import FlacEncoder
from clarity.utils.signal_processing import (
clip_signal,
denormalize_signals,
normalize_signal,
resample,
to_16bit,
)
from clarity.utils.source_separation_support import get_device, separate_sources
from recipes.cad_icassp_2024.baseline.evaluate import (
apply_gains,
apply_ha,
make_scene_listener_list,
remix_stems,
)
logger = logging.getLogger(__name__)
def save_flac_signal(
signal: np.ndarray,
filename: Path,
    signal_sample_rate: int,
    output_sample_rate: int,
do_clip_signal: bool = False,
do_soft_clip: bool = False,
do_scale_signal: bool = False,
) -> None:
"""
Function to save output signals.
- The output signal will be resample to ``output_sample_rate``
- The output signal will be clipped to [-1, 1] if ``do_clip_signal`` is True
and use soft clipped if ``do_soft_clip`` is True. Note that if
``do_clip_signal`` is False, ``do_soft_clip`` will be ignored.
Note that if ``do_clip_signal`` is True, ``do_scale_signal`` will be ignored.
- The output signal will be scaled to [-1, 1] if ``do_scale_signal`` is True.
If signal is scale, the scale factor will be saved in a TXT file.
Note that if ``do_clip_signal`` is True, ``do_scale_signal`` will be ignored.
- The output signal will be saved as a FLAC file.
Args:
signal (np.ndarray) : Signal to save
filename (Path) : Path to save signal
signal_sample_rate (int) : Sample rate of the input signal
output_sample_rate (int) : Sample rate of the output signal
do_clip_signal (bool) : Whether to clip signal
do_soft_clip (bool) : Whether to apply soft clipping
do_scale_signal (bool) : Whether to scale signal
"""
# Resample signal to expected output sample rate
if signal_sample_rate != output_sample_rate:
signal = resample(signal, signal_sample_rate, output_sample_rate)
if do_scale_signal:
# Scale stem signal
max_value = np.max(np.abs(signal))
signal = signal / max_value
# Save scale factor
with open(filename.with_suffix(".txt"), "w", encoding="utf-8") as file:
file.write(f"{max_value}")
elif do_clip_signal:
# Clip the signal
signal, n_clipped = clip_signal(signal, do_soft_clip)
if n_clipped > 0:
logger.warning(f"Writing {filename}: {n_clipped} samples clipped")
# Convert signal to 16-bit integer
signal = to_16bit(signal)
# Create flac encoder object to compress and save the signal
FlacEncoder().encode(signal, output_sample_rate, filename)
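# A minimal usage sketch of save_flac_signal (hypothetical values: `remix`
# stands for a stereo float array already in memory, and the sample rates are
# illustrative, not prescribed by this module):
#
#     save_flac_signal(
#         signal=remix,
#         filename=Path("enhanced_signals/example_remix.flac"),
#         signal_sample_rate=44100,
#         output_sample_rate=32000,
#         do_clip_signal=True,
#         do_soft_clip=True,
#     )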
# pylint: disable=unused-argument
def decompose_signal(
model: torch.nn.Module,
model_sample_rate: int,
signal: ndarray,
signal_sample_rate: int,
device: torch.device,
sources_list: list[str],
listener: Listener,
normalise: bool = True,
) -> dict[str, ndarray]:
"""
Decompose signal into 8 stems.
The listener is ignored by the baseline system as it
is not performing personalised decomposition.
Instead, it performs a standard music decomposition using a pre-trained
model trained on the MUSDB18 dataset.
Args:
model (torch.nn.Module): Torch model.
model_sample_rate (int): Sample rate of the model.
signal (ndarray): Signal to be decomposed.
signal_sample_rate (int): Sample frequency.
device (torch.device): Torch device to use for processing.
sources_list (list): List of strings used to index dictionary.
listener (Listener).
normalise (bool): Whether to normalise the signal.
Returns:
Dictionary: Indexed by sources with the associated model as values.
"""
if signal.shape[0] > signal.shape[1]:
signal = signal.T
if signal_sample_rate != model_sample_rate:
signal = resample(signal, signal_sample_rate, model_sample_rate)
if normalise:
signal, ref = normalize_signal(signal)
sources = separate_sources(
model,
torch.from_numpy(signal.astype(np.float32)),
model_sample_rate,
device=device,
)
# only one element in the batch
sources = sources[0]
if normalise:
sources = denormalize_signals(sources, ref)
sources = np.transpose(sources, (0, 2, 1))
return dict(zip(sources_list, sources))
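# A minimal usage sketch of decompose_signal, assuming the HDemucs pipeline
# loaded in enhance() below; `mixture` is a hypothetical float array read from
# a mixture WAV file and `listener` a Listener object:
#
#     separation_model = HDEMUCS_HIGH_MUSDB.get_model()
#     stems = decompose_signal(
#         model=separation_model,
#         model_sample_rate=HDEMUCS_HIGH_MUSDB.sample_rate,
#         signal=mixture,
#         signal_sample_rate=44100,
#         device=torch.device("cpu"),
#         sources_list=separation_model.sources,
#         listener=listener,  # ignored by the baseline
#         normalise=True,
#     )
#     vocals = stems["vocals"]  # [samples, channels] array for one stem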
def process_remix_for_listener(
signal: ndarray,
enhancer: NALR,
compressor: Compressor,
listener: Listener,
apply_compressor: bool = False,
) -> ndarray:
"""Process the stems from sources.
Args:
stems (dict) : Dictionary of stems
sample_rate (float) : Sample rate of the signal
enhancer (NALR) : NAL-R prescription hearing aid
compressor (Compressor) : Compressor
listener: Listener object
apply_compressor (bool) : Whether to apply the compressor
Returns:
ndarray: Processed signal.
"""
left_output = apply_ha(
enhancer, compressor, signal[:, 0], listener.audiogram_left, apply_compressor
)
right_output = apply_ha(
enhancer, compressor, signal[:, 1], listener.audiogram_right, apply_compressor
)
return np.stack([left_output, right_output], axis=1)
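# A minimal usage sketch (hypothetical: `remix` is a [samples, 2] stereo
# array; `enhancer` and `compressor` are built as in enhance() below):
#
#     processed = process_remix_for_listener(
#         signal=remix,
#         enhancer=enhancer,        # NALR(**config.nalr)
#         compressor=compressor,    # Compressor(**config.compressor)
#         listener=listener,
#         apply_compressor=False,
#     )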
@hydra.main(config_path="", config_name="config", version_base=None)
def enhance(config: DictConfig) -> None:
"""
Run the music enhancement.
The system decomposes the music into vocal, drums, bass, and other stems.
Then, the NAL-R prescription procedure is applied to each stem.
Args:
config (dict): Dictionary of configuration options for enhancing music.
Returns 8 stems for each song:
- left channel vocal, drums, bass, and other stems
- right channel vocal, drums, bass, and other stems
"""
# Set the output directory where processed signals will be saved
enhanced_folder = Path("enhanced_signals")
enhanced_folder.mkdir(parents=True, exist_ok=True)
# Loading pretrained source separation model
if config.separator.model == "demucs":
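        # torchaudio's HDEMUCS_HIGH_MUSDB pipeline runs at 44.1 kHz and
        # separates into ["drums", "bass", "other", "vocals"]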
separation_model = HDEMUCS_HIGH_MUSDB.get_model()
model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate
sources_order = separation_model.sources
normalise = True
elif config.separator.model == "openunmix":
separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0)
model_sample_rate = separation_model.sample_rate
sources_order = ["vocals", "drums", "bass", "other"]
normalise = False
else:
raise ValueError(f"Separator model {config.separator.model} not supported.")
device, _ = get_device(config.separator.device)
separation_model.to(device)
# Load listener audiograms and songs
listener_dict = Listener.load_listener_dict(config.path.listeners_file)
    # Load gains, scenes, scene-listener pairs, and song metadata
with Path(config.path.gains_file).open("r", encoding="utf-8") as file:
gains = json.load(file)
with Path(config.path.scenes_file).open("r", encoding="utf-8") as file:
scenes = json.load(file)
with Path(config.path.scene_listeners_file).open("r", encoding="utf-8") as file:
scenes_listeners = json.load(file)
with Path(config.path.music_file).open("r", encoding="utf-8") as file:
songs = json.load(file)
enhancer = NALR(**config.nalr)
compressor = Compressor(**config.compressor)
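    # config.nalr and config.compressor are passed straight through as
    # constructor keyword arguments; see the recipe's config file (loaded by
    # hydra above) for the actual settings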
# Select a batch to process
scene_listener_pairs = make_scene_listener_list(
scenes_listeners, config.evaluate.small_test
)
scene_listener_pairs = scene_listener_pairs[
config.evaluate.batch :: config.evaluate.batch_size
]
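    # The slice strides over the pair list: e.g., with batch=1 and
    # batch_size=10, pairs 1, 11, 21, ... are processed, so running jobs
    # with batch=0..9 covers every scene-listener pair exactly once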
# Decompose each song into left and right vocal, drums, bass, and other stems
# and process each stem for the listener
previous_song = ""
num_scenes = len(scene_listener_pairs)
for idx, scene_listener_pair in enumerate(scene_listener_pairs, 1):
scene_id, listener_id = scene_listener_pair
scene = scenes[scene_id]
song_name = f"{scene['music']}-{scene['head_loudspeaker_positions']}"
logger.info(
f"[{idx:03d}/{num_scenes:03d}] "
f"Processing {scene_id}: {song_name} for listener {listener_id}"
)
# Get the listener's audiogram
listener = listener_dict[listener_id]
# Read the mixture signal
# Convert to 32-bit floating point and transpose
# from [samples, channels] to [channels, samples]
if song_name != previous_song:
mixture_signal = read_signal(
filename=Path(config.path.music_dir)
/ songs[song_name]["Path"]
/ "mixture.wav",
sample_rate=config.sample_rate,
allow_resample=True,
)
stems: dict[str, ndarray] = decompose_signal(
model=separation_model,
model_sample_rate=model_sample_rate,
signal=mixture_signal,
signal_sample_rate=config.sample_rate,
device=device,
sources_list=sources_order,
listener=listener,
normalise=normalise,
)
stems = apply_gains(stems, config.sample_rate, gains[scene["gain"]])
enhanced_signal = remix_stems(stems, mixture_signal, model_sample_rate)
            # Remember the song so it is not decomposed and remixed again
            # for the next scene-listener pair using the same song
            previous_song = song_name
enhanced_signal = process_remix_for_listener(
signal=enhanced_signal,
enhancer=enhancer,
compressor=compressor,
listener=listener,
apply_compressor=config.apply_compressor,
)
        filename = enhanced_folder / f"{scene_id}_{listener.id}_remix.flac"
filename.parent.mkdir(parents=True, exist_ok=True)
save_flac_signal(
signal=enhanced_signal,
filename=filename,
signal_sample_rate=config.sample_rate,
output_sample_rate=config.remix_sample_rate,
do_clip_signal=True,
do_soft_clip=config.soft_clip,
)
logger.info("Done!")
# pylint: disable = no-value-for-parameter
if __name__ == "__main__":
enhance()