# Source code for tests.recipes.cad_icassp_2024.baseline.test_enhance

"""Tests for the enhance module"""

# pylint: disable=import-error
from pathlib import Path

import numpy as np
import pytest
import torch
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB

from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from clarity.utils.audiogram import Audiogram, Listener
from clarity.utils.flac_encoder import read_flac_signal
from recipes.cad_icassp_2024.baseline.enhance import (
    decompose_signal,
    process_remix_for_listener,
    save_flac_signal,
)

# Paths are resolved against the current working directory, so the tests
# locate their reference data only when run from the directory that
# contains the "tests" folder.
BASE_DIR = Path.cwd()
RESOURCES = BASE_DIR.joinpath("tests", "resources", "recipes", "cad_icassp_2024")


def test_save_flac_signal(tmp_path):
    """Write a seeded random signal to FLAC, read it back and check the result.

    The sums of the written and the decoded signal are compared against
    stored reference values, and the sample rate returned by the reader
    must equal the one used for writing.
    """
    np.random.seed(2024)
    sample_rate = 44100
    duration = 0.5
    n_samples = int(sample_rate * duration)
    original = np.random.rand(n_samples)

    flac_path = Path(tmp_path).joinpath("signal.flac")
    # The same rate is passed twice, exactly as the call requires here
    # (presumably input and output sample rate -- confirm against
    # save_flac_signal's signature).
    save_flac_signal(original, flac_path, sample_rate, sample_rate)
    decoded, decoded_rate = read_flac_signal(flac_path)

    assert np.sum(original) == pytest.approx(11040.050741283)
    # The decoded sum has its own reference value, slightly different from
    # the written one (presumably because of quantisation in the encoder).
    assert np.sum(decoded) == pytest.approx(11039.716)
    assert decoded_rate == sample_rate
@pytest.mark.parametrize(
    "separation_model",
    [
        pytest.param("demucs"),
        pytest.param("openunmix", marks=pytest.mark.slow),
    ],
)
def test_decompose_signal(separation_model):
    """Decompose a random stereo signal and compare it with stored references.

    The test runs once per separation model ("demucs" and, marked as slow,
    "openunmix"). Every entry returned by ``decompose_signal`` is compared
    with the entry of the same key in the reference file for that model.

    Args:
        separation_model: Name of the separation model to load.

    Raises:
        ValueError: If ``separation_model`` is not one of the known names.
    """
    np.random.seed(2024)

    # Load the separation model together with its sample rate and the
    # order in which it emits the sources.
    if separation_model == "openunmix":
        separator = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq")
        separator_rate = separator.sample_rate
        source_names = ["vocals", "drums", "bass", "other"]
    elif separation_model == "demucs":
        separator = HDEMUCS_HIGH_MUSDB.get_model()
        separator_rate = HDEMUCS_HIGH_MUSDB.sample_rate
        source_names = separator.sources
    else:
        raise ValueError("Invalid separation model")

    cpu = torch.device("cpu")
    separator.to(cpu)

    # Mock input: array of shape (1, 2, n_samples) cast to float32.
    sample_rate = 44100
    duration = 0.5
    n_samples = int(sample_rate * duration)
    mixture = np.random.uniform(size=(1, 2, n_samples)).astype(np.float32)

    # Listener using the same audiogram for both ears.
    centre_freqs = np.array([250, 500, 1000, 2000, 4000, 6000, 8000, 9000, 10000])
    flat_audiogram = Audiogram(levels=np.ones(9), frequencies=centre_freqs)
    listener = Listener(flat_audiogram, flat_audiogram)

    output = decompose_signal(
        separator,
        separator_rate,
        mixture,
        sample_rate,
        cpu,
        source_names,
        listener,
    )

    reference_file = (
        RESOURCES / f"test_enhance.test_decompose_signal_{separation_model}.npy"
    )
    # The reference is stored as a pickled object; indexing with ``[()]``
    # unwraps it from the array returned by ``np.load``.
    expected_results = np.load(reference_file, allow_pickle=True)[()]

    for source_key, estimate in output.items():
        np.testing.assert_array_almost_equal(estimate, expected_results[source_key])
@pytest.mark.parametrize(
    "apply_compressor",
    [
        True,
        False,
    ],
)
def test_process_remix_for_listener(apply_compressor):
    """Check the sum of the processed remix against a stored reference.

    A seeded random two-row signal is passed through
    ``process_remix_for_listener`` with and without the compressor; the sum
    of the output must approximately match the sum of the matching
    reference array.

    Args:
        apply_compressor: Forwarded to ``process_remix_for_listener``; also
            selects which reference file is loaded.
    """
    np.random.seed(2024)
    sample_rate = 44100
    duration = 0.5
    remix = np.random.uniform(size=(2, int(duration * sample_rate)))

    # Listener using the same audiogram for both ears.
    flat_audiogram = Audiogram(
        levels=np.ones(9),
        frequencies=np.array([250, 500, 1000, 2000, 4000, 6000, 8000, 9000, 10000]),
    )
    listener = Listener(audiogram_left=flat_audiogram, audiogram_right=flat_audiogram)

    nalr = NALR(nfir=220, sample_rate=16000)
    compressor = Compressor(
        threshold=0.35,
        attenuation=0.1,
        attack=50,
        release=1000,
        rms_buffer_size=0.064,
    )

    output = process_remix_for_listener(
        remix, nalr, compressor, listener, apply_compressor=apply_compressor
    )

    # One reference file exists per parametrized case.
    reference_name = (
        "test_enhance.test_process_remix_for_listener_w_compressor.npy"
        if apply_compressor
        else "test_enhance.test_process_remix_for_listener_wo_compressor.npy"
    )
    expected_results = np.load(RESOURCES / reference_name, allow_pickle=True)[()]

    assert np.sum(output) == pytest.approx(np.sum(expected_results))