"""Baseline enhancement for CAD2 task1."""
from __future__ import annotations
import json
import logging
from pathlib import Path
import hydra
import numpy as np
import torch
from numpy import ndarray
from omegaconf import DictConfig
from torchaudio.transforms import Fade
from clarity.enhancer.multiband_compressor import MultibandCompressor
from clarity.utils.flac_encoder import read_flac_signal, save_flac_signal
from recipes.cad2.task1.baseline.evaluate import (
make_scene_listener_list,
normalise_luft,
)
from recipes.cad2.task1.ConvTasNet.local.tasnet import ConvTasNetStereo
logging.captureWarnings(True)
logger = logging.getLogger(__name__)
def separate_sources(
model: torch.nn.Module,
mix: torch.Tensor | ndarray,
sample_rate: int,
segment: float = 10.0,
overlap: float = 0.1,
number_sources: int = 4,
device: torch.device | str | None = None,
):
"""
Apply model to a given mixture.
    Processes the mixture in overlapping chunks and cross-fades them back together (overlap-add).
Args:
model (torch.nn.Module): model to use for separation
mix (torch.Tensor): mixture to separate, shape (batch, channels, time)
sample_rate (int): sampling rate of the mixture
segment (float): segment length in seconds
overlap (float): overlap between segments, between 0 and 1
number_sources (int): number of sources to separate
device (torch.device, str, or None): if provided, device on which to
execute the computation, otherwise `mix.device` is assumed.
When `device` is different from `mix.device`, only local computations will
be on `device`, while the entire tracks will be stored on `mix.device`.
Returns:
torch.Tensor: estimated sources
Based on https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
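
    Example:
        A minimal sketch (not executed here); it assumes a two-source stereo model
        such as the ``ConvTasNetStereo`` returned by ``load_separation_model`` below
        and a stereo mixture sampled at 44.1 kHz::

            model = load_separation_model("noncausal", torch.device("cpu"))
            mixture = np.zeros((2, 30 * 44100), dtype=np.float32)  # (channels, samples)
            sources = separate_sources(
                model, mixture, sample_rate=44100, number_sources=2, device="cpu"
            )
            # sources.shape == (1, 2, 2, 30 * 44100): (batch, sources, channels, samples)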
"""
device = mix.device if device is None else torch.device(device)
mix = torch.as_tensor(mix, dtype=torch.float, device=device)
if mix.ndim == 1:
# one track and mono audio
mix = mix.unsqueeze(0).unsqueeze(0)
elif mix.ndim == 2:
# one track and stereo audio
mix = mix.unsqueeze(0)
batch, channels, length = mix.shape
chunk_len = int(sample_rate * segment * (1 + overlap))
start = 0
end = chunk_len
overlap_frames = overlap * sample_rate
fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
final = torch.zeros(batch, number_sources, channels, length, device=device)
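    # Overlap-add: the first chunk is processed without a fade-in; every chunk is
    # faded out over the overlap region and the following chunk is faded in over
    # the same samples, so adjacent chunks cross-fade linearly when summed.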
while start < length - overlap_frames:
chunk = mix[:, :, start:end]
with torch.no_grad():
out = model.forward(chunk)
out = fade(out)
final[:, :, :, start:end] += out
if start == 0:
fade.fade_in_len = int(overlap_frames)
start += int(chunk_len - overlap_frames)
else:
start += chunk_len
end += chunk_len
if end >= length:
fade.fade_out_len = 0
return final
def get_device(device: str | None) -> tuple:
"""Get the Torch device.
Args:
        device (str): device type, e.g. "cpu", "gpu0", "gpu1", etc.; ``None`` selects CUDA if available, otherwise CPU.
Returns:
        torch.device: torch.device() appropriate to the hardware available.
str: device type selected, e.g. "cpu", "cuda".
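
    Example:
        Illustrative only; the ``gpu*`` branch requires CUDA to be available::

            get_device("cpu")   # (torch.device("cpu"), "cpu")
            get_device("gpu0")  # (torch.device("cuda:0"), "cuda")
            get_device(None)    # CUDA if available, otherwise CPU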
"""
if device is None:
if torch.cuda.is_available():
return torch.device("cuda"), "cuda"
return torch.device("cpu"), "cpu"
if device.startswith("gpu"):
device_index = int(device.replace("gpu", ""))
        if device_index >= torch.cuda.device_count():
raise ValueError(f"GPU device index {device_index} is not available.")
return torch.device(f"cuda:{device_index}"), "cuda"
if device == "cpu":
return torch.device("cpu"), "cpu"
raise ValueError(f"Unsupported device type: {device}")
def load_separation_model(causality: str, device: torch.device) -> ConvTasNetStereo:
"""
Load the separation model.
Args:
causality (str): Causality of the model (causal or noncausal).
device (torch.device): Device to load the model.
Returns:
model: Separation model.
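
    Example:
        Sketch only; ``from_pretrained`` fetches the checkpoint from the
        Hugging Face Hub, so the first call needs network access::

            device, _ = get_device("cpu")
            model = load_separation_model("noncausal", device)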
"""
if causality == "causal":
model = ConvTasNetStereo.from_pretrained(
"cadenzachallenge/ConvTasNet_LyricsSeparation_Causal",
force_download=True,
).to(device)
else:
model = ConvTasNetStereo.from_pretrained(
"cadenzachallenge/ConvTasNet_LyricsSeparation_NonCausal"
).to(device)
return model
def downmix_signal(
vocals: ndarray,
accompaniment: ndarray,
beta: float,
) -> ndarray:
"""
Downmix the vocals and accompaniment to stereo.
Args:
vocals (np.ndarray): Vocal signal.
accompaniment (np.ndarray): Accompaniment signal.
beta (float): Downmix parameter.
Returns:
np.ndarray: Downmixed signal.
    Notes:
        When beta is 0, vocals and accompaniment are mixed with equal (unit) gain.
        When beta is 1, the accompaniment is removed and the vocals are doubled.
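
    Example:
        A worked sketch of the gain values this implementation applies::

            downmix_signal(vocals, accompaniment, beta=0.0)  # 1.00 * vocals + 1.00 * accompaniment
            downmix_signal(vocals, accompaniment, beta=0.5)  # 1.25 * vocals + 0.75 * accompaniment
            downmix_signal(vocals, accompaniment, beta=1.0)  # 2.00 * vocals + 0.00 * accompaniment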
"""
    # The gains are linear scale factors (not dB): the vocal gain grows from 1
    # (beta = 0) to 2 (beta = 1), while the accompaniment gain shrinks so that
    # the two gains always sum to 2.
    # vocal amplification
    vamp = beta**2 + 1
    # accompaniment amplification
    aamp = 2 - vamp
return vocals * vamp + accompaniment * aamp
@hydra.main(config_path="", config_name="config", version_base=None)
def enhance(config: DictConfig) -> None:
"""
Run the music enhancement.
The system decomposes the music into vocals and accompaniment.
    The balance between vocals and accompaniment is then set by the scene's alpha value.
    Finally, the downmixed stereo signal is amplified according to the listener's hearing loss.
Args:
        config (DictConfig): Configuration options for enhancing the music.
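
    Example:
        Normally run from the command line; Hydra loads ``config.yaml`` from this
        directory, and any field can be overridden (a sketch, exact paths depend
        on your local setup)::

            python enhance.py separator.causality=noncausal evaluate.small_test=True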
"""
if config.separator.causality not in ["causal", "noncausal"]:
raise ValueError(
f"Causality must be causal or noncausal, {config.separator.causality} was"
" provided."
)
device, _ = get_device(config.separator.device)
# Set folder to save the enhanced music
enhanced_folder = Path("enhanced_signals")
enhanced_folder.mkdir(parents=True, exist_ok=True)
# Load listener dictionary
# To load the metadata of all listeners
# listener_dict = Listener.load_listener_dict(config.path.listeners_file)
# Load alphas
with Path(config.path.alphas_file).open("r", encoding="utf-8") as file:
alphas = json.load(file)
# Load scenes
with Path(config.path.scenes_file).open("r", encoding="utf-8") as file:
scenes = json.load(file)
# Load scene-listeners
with Path(config.path.scene_listeners_file).open("r", encoding="utf-8") as file:
scenes_listeners = json.load(file)
# Load songs
with Path(config.path.musics_file).open("r", encoding="utf-8") as file:
songs = json.load(file)
# Load compressor params
with Path(config.path.enhancer_params_file).open("r", encoding="utf-8") as file:
enhancer_params = json.load(file)
# Load separation model
separation_model = load_separation_model(config.separator.causality, device)
# create hearing aid
enhancer = MultibandCompressor(
crossover_frequencies=config.enhancer.crossover_frequencies,
sample_rate=config.input_sample_rate,
)
# Make the list of scene-listener pairings to process
scene_listener_pairs = make_scene_listener_list(
scenes_listeners, config.evaluate.small_test
)
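    # Optionally split the work across parallel jobs: keep every
    # `evaluate.batch_size`-th pair starting from index `evaluate.batch`.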
scene_listener_pairs = scene_listener_pairs[
config.evaluate.batch :: config.evaluate.batch_size
]
# Process each scene-listener pair
for idx, scene_listener_ids in enumerate(scene_listener_pairs, 1):
logger.info(
f"[{idx:04d}/{len(scene_listener_pairs):04d}] Processing scene-listener"
f" pair: {scene_listener_ids}"
)
scene_id, listener_id = scene_listener_ids
scene = scenes[scene_id]
# This recipe is not using the listener metadata
# But you can load it as follows:
# listener = listener_dict[listener_id]
alpha = alphas[scene["alpha"]]
# Load the music
input_mixture, input_sample_rate = read_flac_signal(
Path(config.path.music_dir)
/ songs[scene["segment_id"]]["path"]
/ "mixture.flac"
)
start_sample = int(
songs[scene["segment_id"]]["start_time"] * config.input_sample_rate
)
end_time = int(
(songs[scene["segment_id"]]["end_time"]) * config.input_sample_rate
)
input_mixture = input_mixture[start_sample:end_time, :]
assert input_sample_rate == config.input_sample_rate
# normalise input mixture to -40 dB LUFS
input_mixture = normalise_luft(
input_mixture, config.input_sample_rate, target_luft=-40
)
# Separate the music
est_sources = separate_sources(
separation_model,
input_mixture.T,
device=device,
**config.separator.separation,
)
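        # est_sources has shape (batch=1, sources, channels, samples); dropping the
        # batch dimension leaves one stereo signal per source (vocals, accompaniment).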
vocals, accompaniment = est_sources.squeeze(0).cpu().detach().numpy()
# Get the listener's compressor params
mbc_params_listener: dict[str, dict] = {"left": {}, "right": {}}
for ear in ["left", "right"]:
mbc_params_listener[ear]["release"] = config.enhancer.release
mbc_params_listener[ear]["attack"] = config.enhancer.attack
mbc_params_listener[ear]["threshold"] = config.enhancer.threshold
mbc_params_listener["left"]["ratio"] = enhancer_params[listener_id]["cr_l"]
mbc_params_listener["right"]["ratio"] = enhancer_params[listener_id]["cr_r"]
mbc_params_listener["left"]["makeup_gain"] = enhancer_params[listener_id][
"gain_l"
]
mbc_params_listener["right"]["makeup_gain"] = enhancer_params[listener_id][
"gain_r"
]
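        # Compression ratios and makeup gains are listener-specific (read from the
        # enhancer params file); attack, release and threshold come from the config
        # and are shared across both ears.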
# Downmix to stereo
enhanced_signal = downmix_signal(vocals, accompaniment, beta=alpha)
# Apply Amplification
enhancer.set_compressors(**mbc_params_listener["left"])
left_enhanced = enhancer(signal=enhanced_signal[0, :])
enhancer.set_compressors(**mbc_params_listener["right"])
right_enhanced = enhancer(signal=enhanced_signal[1, :])
enhanced_signal = np.stack((left_enhanced[0], right_enhanced[0]), axis=1)
# Save the enhanced music
filename = enhanced_folder / f"{scene_id}_{listener_id}_A{alpha}_remix.flac"
filename.parent.mkdir(parents=True, exist_ok=True)
save_flac_signal(
signal=enhanced_signal,
filename=filename,
signal_sample_rate=config.input_sample_rate,
output_sample_rate=config.remix_sample_rate,
do_clip_signal=True,
do_soft_clip=config.soft_clip,
)
# pylint: disable = no-value-for-parameter
if __name__ == "__main__":
enhance()
logger.info("Enhancement completed.")