"""Run the dummy enhancement."""
from __future__ import annotations
# pylint: disable=too-many-locals
# pylint: disable=import-error
import json
import logging
from pathlib import Path
import hydra
import numpy as np
import pandas as pd
import torch
from numpy import ndarray
from omegaconf import DictConfig
from scipy.io import wavfile
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
from torchaudio.transforms import Fade
from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from clarity.utils.audiogram import Audiogram, Listener
from clarity.utils.flac_encoder import FlacEncoder
from clarity.utils.signal_processing import (
clip_signal,
denormalize_signals,
normalize_signal,
resample,
to_16bit,
)
from recipes.cad1.task1.baseline.evaluate import make_song_listener_list
logger = logging.getLogger(__name__)
def separate_sources(
model: torch.nn.Module,
mix: torch.Tensor | ndarray,
sample_rate: int,
segment: float = 10.0,
overlap: float = 0.1,
device: torch.device | str | None = None,
):
"""
Apply model to a given mixture.
Uses fades and overlap-add so the model can be applied to long tracks segment by segment.
Args:
model (torch.nn.Module): model to use for separation
mix (torch.Tensor | ndarray): mixture to separate, shape (batch, channels, time)
sample_rate (int): sampling rate of the mixture
segment (float): segment length in seconds
overlap (float): overlap between segments, between 0 and 1
device (torch.device, str, or None): if provided, device on which to
execute the computation, otherwise `mix.device` is assumed.
When `device` is different from `mix.device`, only local computations will
be on `device`, while the entire tracks will be stored on `mix.device`.
Returns:
torch.Tensor: estimated sources
Based on https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
"""
device = mix.device if device is None else torch.device(device)
mix = torch.as_tensor(mix, device=device)
if mix.ndim == 1:
# one track and mono audio: add batch and channel dimensions
mix = mix.unsqueeze(0).unsqueeze(0)
elif mix.ndim == 2:
# one track and stereo audio
mix = mix.unsqueeze(0)
batch, channels, length = mix.shape
chunk_len = int(sample_rate * segment * (1 + overlap))
start = 0
end = chunk_len
overlap_frames = overlap * sample_rate
fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
final = torch.zeros(batch, 4, channels, length, device=device)
while start < length - overlap_frames:
chunk = mix[:, :, start:end]
with torch.no_grad():
out = model.forward(chunk)
out = fade(out)
final[:, :, :, start:end] += out
if start == 0:
fade.fade_in_len = int(overlap_frames)
start += int(chunk_len - overlap_frames)
else:
start += chunk_len
end += chunk_len
if end >= length:
fade.fade_out_len = 0
return final.cpu().detach().numpy()
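# Usage sketch (illustrative only, not executed as part of the recipe; the
# tensor shape and duration below are assumptions):
#
#   model = HDEMUCS_HIGH_MUSDB.get_model()
#   mix = torch.rand(1, 2, 44100 * 30)  # (batch, channels, time), 30 s of stereo
#   sources = separate_sources(model, mix, sample_rate=44100, device="cpu")
#   # sources has shape (1, 4, 2, 44100 * 30): one estimate per separated source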
def get_device(device: str | None) -> tuple:
"""Get the Torch device.
Args:
device (str): device type, e.g. "cpu", "gpu0", "gpu1", etc.
Returns:
torch.device: torch.device() appropriate to the hardware available.
str: device type selected, e.g. "cpu", "cuda".
"""
if device is None:
if torch.cuda.is_available():
return torch.device("cuda"), "cuda"
return torch.device("cpu"), "cpu"
if device.startswith("gpu"):
device_index = int(device.replace("gpu", ""))
if device_index >= torch.cuda.device_count():
raise ValueError(f"GPU device index {device_index} is not available.")
return torch.device(f"cuda:{device_index}"), "cuda"
if device == "cpu":
return torch.device("cpu"), "cpu"
raise ValueError(f"Unsupported device type: {device}")
def map_to_dict(sources: ndarray, sources_list: list[str]) -> dict:
"""Map sources to a dictionary separating audio into left and right channels.
Args:
sources (ndarray): Signal to be mapped to dictionary.
sources_list (list): List of strings used to index dictionary.
Returns:
Dictionary: A dictionary of separated source audio split into channels.
"""
audios = dict(zip(sources_list, sources))
signal_stems = {}
for source in sources_list:
audio = audios[source]
signal_stems[f"left_{source}"] = audio[0]
signal_stems[f"right_{source}"] = audio[1]
return signal_stems
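# Usage sketch (illustrative only; the source names mirror the MUSDB stems):
#
#   sources = np.zeros((4, 2, 44100))  # (n_sources, n_channels, n_samples)
#   stems = map_to_dict(sources, ["drums", "bass", "other", "vocals"])
#   # stems has 8 keys: "left_drums", "right_drums", ..., "left_vocals", "right_vocals"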
# pylint: disable=unused-argument
def decompose_signal(
model: torch.nn.Module,
model_sample_rate: int,
signal: ndarray,
signal_sample_rate: int,
device: torch.device,
sources_list: list[str],
listener: Listener,
normalise: bool = True,
) -> dict[str, ndarray]:
"""
Decompose signal into 8 stems.
The left and right audiograms are ignored by the baseline system as it
does not perform personalised decomposition.
Instead, it performs a standard music decomposition using the
HDEMUCS model trained on the MUSDB18 dataset.
Args:
model (torch.nn.Module): Torch model.
model_sample_rate (int): Sample rate of the model.
signal (ndarray): Signal to be decomposed.
signal_sample_rate (int): Sample rate of the input signal.
device (torch.device): Torch device to use for processing.
sources_list (list): List of strings used to index dictionary.
listener (Listener): Listener object.
normalise (bool): Whether to normalise the signal.
Returns:
Dictionary: Indexed by left/right source names with the separated signals as values.
"""
# Resample mixture signal to model sample rate
if signal_sample_rate != model_sample_rate:
signal = resample(signal, signal_sample_rate, model_sample_rate)
if normalise:
signal, ref = normalize_signal(signal)
sources = separate_sources(
model, torch.from_numpy(signal), model_sample_rate, device=device
)
# only one element in the batch
sources = sources[0]
if normalise:
sources = denormalize_signals(sources, ref)
signal_stems = map_to_dict(sources, sources_list)
return signal_stems
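# Usage sketch (illustrative only; `mixture` and `listener` are assumed to be a
# float32 (channels, samples) array and a loaded Listener object):
#
#   model = HDEMUCS_HIGH_MUSDB.get_model()
#   stems = decompose_signal(
#       model, HDEMUCS_HIGH_MUSDB.sample_rate, mixture, 44100,
#       torch.device("cpu"), model.sources, listener, normalise=True,
#   )
#   # stems maps e.g. "left_vocals" and "right_vocals" to 1-D arrays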
def apply_baseline_ha(
enhancer: NALR,
compressor: Compressor,
signal: ndarray,
audiogram: Audiogram,
apply_compressor: bool = False,
) -> ndarray:
"""
Apply NAL-R prescription hearing aid to a signal.
Args:
enhancer: A NALR object that enhances the signal.
compressor: A Compressor object that compresses the signal.
signal: An ndarray representing the audio signal.
audiogram: An Audiogram object representing the listener's audiogram.
apply_compressor: A boolean indicating whether to include the compressor.
Returns:
An ndarray representing the processed signal.
"""
nalr_fir, _ = enhancer.build(audiogram)
proc_signal = enhancer.apply(nalr_fir, signal)
if apply_compressor:
proc_signal, _, _ = compressor.process(proc_signal)
return proc_signal
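# Usage sketch (illustrative only; the NALR and Compressor arguments come from
# the Hydra config, so no concrete parameter values are assumed here):
#
#   enhancer = NALR(**config.nalr)
#   compressor = Compressor(**config.compressor)
#   aided = apply_baseline_ha(
#       enhancer, compressor, stems["left_vocals"], listener.audiogram_left
#   )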
def process_stems_for_listener(
stems: dict,
enhancer: NALR,
compressor: Compressor,
listener: Listener,
apply_compressor: bool = False,
) -> dict:
"""Process the stems from sources.
Args:
stems (dict) : Dictionary of stems
enhancer (NALR) : NAL-R prescription hearing aid
compressor (Compressor) : Compressor
listener (Listener) : Listener object.
apply_compressor (bool) : Whether to apply the compressor
Returns:
processed_sources (dict) : Dictionary of processed stems
"""
processed_stems = {}
for stem_str in stems:
stem_signal = stems[stem_str]
# Determine the audiogram to use
audiogram = (
listener.audiogram_left
if stem_str.startswith("l")
else listener.audiogram_right
)
# Apply NALR prescription to stem_signal
proc_signal = apply_baseline_ha(
enhancer, compressor, stem_signal, audiogram, apply_compressor
)
processed_stems[stem_str] = proc_signal
return processed_stems
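# Usage sketch (illustrative only, reusing the objects from the sketch above):
#
#   processed = process_stems_for_listener(stems, enhancer, compressor, listener)
#   # "left_*" stems are fitted with the left audiogram, "right_*" with the right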
def remix_signal(stems: dict) -> ndarray:
"""
Function to remix signal. It takes the eight stems
and combines them into a stereo signal.
Args:
stems (dict) : Dictionary of stems
Returns:
(ndarray) : Remixed signal
"""
n_samples = next(iter(stems.values())).shape[0]
out_left, out_right = np.zeros(n_samples), np.zeros(n_samples)
for stem_str, stem_signal in stems.items():
if stem_str.startswith("l"):
out_left += stem_signal
else:
out_right += stem_signal
return np.stack([out_left, out_right], axis=1)
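# Usage sketch (illustrative only, with dummy stems):
#
#   stems = {
#       "left_vocals": np.ones(100), "right_vocals": np.ones(100),
#       "left_drums": np.zeros(100), "right_drums": np.zeros(100),
#   }
#   remix = remix_signal(stems)  # shape (100, 2): left sums, then right sums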
def save_flac_signal(
signal: ndarray,
filename: Path,
signal_sample_rate: int,
output_sample_rate: int,
do_clip_signal: bool = False,
do_soft_clip: bool = False,
do_scale_signal: bool = False,
) -> None:
"""
Function to save output signals.
- The output signal will be resampled to ``output_sample_rate``.
- The output signal will be scaled to [-1, 1] if ``do_scale_signal`` is True.
If the signal is scaled, the scale factor will be saved in a TXT file.
Note that if ``do_scale_signal`` is True, ``do_clip_signal`` will be ignored.
- Otherwise, the output signal will be clipped to [-1, 1] if ``do_clip_signal``
is True, with soft clipping applied if ``do_soft_clip`` is also True. Note that
if ``do_clip_signal`` is False, ``do_soft_clip`` will be ignored.
- The output signal will be saved as a FLAC file.
- The output signal will be saved as a FLAC file.
Args:
signal (np.ndarray) : Signal to save
filename (Path) : Path to save signal
signal_sample_rate (int) : Sample rate of the input signal
output_sample_rate (int) : Sample rate of the output signal
do_clip_signal (bool) : Whether to clip signal
do_soft_clip (bool) : Whether to apply soft clipping
do_scale_signal (bool) : Whether to scale signal
"""
# Resample signal to expected output sample rate
if signal_sample_rate != output_sample_rate:
signal = resample(signal, signal_sample_rate, output_sample_rate)
if do_scale_signal:
# Scale stem signal
max_value = np.max(np.abs(signal))
signal = signal / max_value
# Save scale factor
with open(filename.with_suffix(".txt"), "w", encoding="utf-8") as file:
file.write(f"{max_value}")
elif do_clip_signal:
# Clip the signal
signal, n_clipped = clip_signal(signal, do_soft_clip)
if n_clipped > 0:
logger.warning(f"Writing {filename}: {n_clipped} samples clipped")
# Convert signal to 16-bit integer
signal = to_16bit(signal)
# Create flac encoder object to compress and save the signal
FlacEncoder().encode(signal, output_sample_rate, filename)
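# Usage sketch (illustrative only; the output path and sample rates are
# placeholders, not values taken from the recipe config):
#
#   save_flac_signal(
#       signal=remix, filename=Path("out/remix.flac"),
#       signal_sample_rate=44100, output_sample_rate=32000,
#       do_clip_signal=True, do_soft_clip=True,
#   )
#   # with do_scale_signal=True a companion "out/remix.txt" stores the scale factor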
@hydra.main(config_path="", config_name="config", version_base=None)
def enhance(config: DictConfig) -> None:
"""
Run the music enhancement.
The system decomposes the music into vocal, drums, bass, and other stems.
Then, the NAL-R prescription procedure is applied to each stem.
Args:
config (DictConfig): Configuration options for enhancing music.
Returns 8 stems for each song:
- left channel vocal, drums, bass, and other stems
- right channel vocal, drums, bass, and other stems
"""
if config.separator.model not in ["demucs", "openunmix"]:
raise ValueError(f"Separator model {config.separator.model} not supported.")
enhanced_folder = Path("enhanced_signals")
enhanced_folder.mkdir(parents=True, exist_ok=True)
if config.separator.model == "demucs":
separation_model = HDEMUCS_HIGH_MUSDB.get_model()
model_sample_rate = HDEMUCS_HIGH_MUSDB.sample_rate
sources_order = separation_model.sources
normalise = True
elif config.separator.model == "openunmix":
separation_model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq", niter=0)
model_sample_rate = separation_model.sample_rate
sources_order = ["vocals", "drums", "bass", "other"]
normalise = False
else:
raise ValueError(f"Separator model {config.separator.model} not supported.")
device, _ = get_device(config.separator.device)
separation_model.to(device)
# Processing Validation Set
# Load listener audiograms and songs
listener_dict = Listener.load_listener_dict(config.path.listeners_file)
with open(config.path.music_file, encoding="utf-8") as file:
song_data = json.load(file)
songs_df = pd.DataFrame.from_dict(song_data)
song_listener_pairs = make_song_listener_list(songs_df["Track Name"], listener_dict)
# Select a batch to process
song_listener_pairs = song_listener_pairs[
config.evaluate.batch :: config.evaluate.batch_size
]
enhancer = NALR(**config.nalr)
compressor = Compressor(**config.compressor)
# Decompose each song into left and right vocal, drums, bass, and other stems
# and process each stem for the listener
prev_song_name = None
stems: dict[str, ndarray] = {}
num_song_list_pair = len(song_listener_pairs)
for idx, song_listener in enumerate(song_listener_pairs, 1):
song_name, listener_name = song_listener
logger.info(
f"[{idx:03d}/{num_song_list_pair:03d}] "
f"Processing {song_name} for {listener_name}..."
)
# Get the listener's audiogram
listener = listener_dict[listener_name]
# Find the music split directory
split_directory = (
"test"
if songs_df.loc[songs_df["Track Name"] == song_name, "Split"].iloc[0]
== "test"
else "train"
)
# Baseline Steps
# 1. Decompose the mixture signal into vocal, drums, bass, and other stems
# The song is decomposed only once: consecutive pairs that share the same
# song reuse the stems computed in the previous iteration
if prev_song_name != song_name:
# Decompose song only once
prev_song_name = song_name
sample_rate, mixture_signal = wavfile.read(
Path(config.path.music_dir)
/ split_directory
/ song_name
/ "mixture.wav"
)
mixture_signal = (mixture_signal / 32768.0).astype(np.float32).T
assert sample_rate == config.sample_rate
stems = decompose_signal(
separation_model,
model_sample_rate,
mixture_signal,
sample_rate,
device,
sources_order,
listener,
normalise,
)
# 2. Apply NAL-R prescription to each stem
# Baseline applies NALR prescription to each stem instead of using the
# listener's audiograms in the decomposition. This step can be skipped
# if the listener's audiograms are used in the decomposition
processed_stems = process_stems_for_listener(
stems,
enhancer,
compressor,
listener,
config.apply_compressor,
)
# 3. Save processed stems
for stem_str, stem_signal in processed_stems.items():
filename = (
enhanced_folder
/ f"{listener.id}"
/ f"{song_name}"
/ f"{listener.id}_{song_name}_{stem_str}.flac"
)
filename.parent.mkdir(parents=True, exist_ok=True)
save_flac_signal(
signal=stem_signal,
filename=filename,
signal_sample_rate=config.sample_rate,
output_sample_rate=config.stem_sample_rate,
do_scale_signal=True,
)
# 4. Remix Signal
enhanced = remix_signal(processed_stems)
# 5. Save enhanced (remixed) signal
filename = (
enhanced_folder
/ f"{listener.id}"
/ f"{song_name}"
/ f"{listener.id}_{song_name}_remix.flac"
)
save_flac_signal(
signal=enhanced,
filename=filename,
signal_sample_rate=config.sample_rate,
output_sample_rate=config.remix_sample_rate,
do_clip_signal=True,
do_soft_clip=config.soft_clip,
)
# pylint: disable = no-value-for-parameter
if __name__ == "__main__":
enhance()