"""Tests for the enhance module"""
# pylint: disable=import-error
from pathlib import Path
import numpy as np
import pytest
import torch
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB
from clarity.enhancer.compressor import Compressor
from clarity.enhancer.nalr import NALR
from clarity.utils.audiogram import Audiogram, Listener
# pylint: disable=import-error, no-name-in-module
from recipes.cad1.task1.baseline.enhance import (
    apply_baseline_ha,
    clip_signal,
    decompose_signal,
    get_device,
    map_to_dict,
    process_stems_for_listener,
    separate_sources,
    to_16bit,
)
BASE_DIR = Path.cwd()
RESOURCES = BASE_DIR / "tests" / "resources" / "recipes" / "cad1" / "task1"
[docs]
def test_map_to_dict():
    """Test that the map_to_dict returns the expected mapping"""
    sources = np.array([[1, 2], [3, 4], [5, 6]])
    sources_list = ["a", "b", "c"]
    output = map_to_dict(sources, sources_list)
    expected_output = {
        "left_a": 1,
        "right_a": 2,
        "left_b": 3,
        "right_b": 4,
        "left_c": 5,
        "right_c": 6,
    }
    assert output == expected_output 
[docs]
@pytest.mark.xfail(reason="Github issue downloading from torch hub")
@pytest.mark.parametrize(
    "separation_model,normalise",
    [
        (pytest.param("demucs"), True),
        (pytest.param("openunmix", marks=pytest.mark.slow), True),
    ],
)
def test_decompose_signal(separation_model, normalise):
    """Takes a signal and decomposes it into VDBO sources using the HDEMUCS model"""
    np.random.seed(123456789)
    # Load Separation Model
    separation_model = separation_model.values[0]
    if separation_model == "demucs":
        model = HDEMUCS_HIGH_MUSDB.get_model().double()
    elif separation_model == "openunmix":
        model = torch.hub.load("sigsep/open-unmix-pytorch", "umxhq").double()
    else:
        raise ValueError(f"Unknown separation model: {separation_model}")
    device = torch.device("cpu")
    model.to(device)
    # Create a mock signal to decompose
    sample_rate = 44100
    duration = 0.5
    signal = np.random.uniform(size=(1, 2, int(sample_rate * duration)))
    # Call the decompose_signal function and check that the output has the expected keys
    cfs = np.array([250, 500, 1000, 2000, 4000, 6000, 8000, 9000, 10000])
    audiogram = Audiogram(levels=np.ones(9), frequencies=cfs)
    listener = Listener(audiogram, audiogram)
    output = decompose_signal(
        model=model,
        model_sample_rate=sample_rate,
        signal=signal,
        signal_sample_rate=sample_rate,
        device=device,
        sources_list=["drums", "bass", "other", "vocals"],
        listener=listener,
        normalise=normalise,
    )
    expected_results = np.load(
        RESOURCES / f"test_enhance.test_decompose_signal_{separation_model}.npy",
        allow_pickle=True,
    )[()]
    for key, item in output.items():
        np.testing.assert_array_almost_equal(item, expected_results[key]) 
[docs]
def test_apply_baseline_ha():
    """Test the behaviour of the CAD1 - Task1 - baseline hearing aid"""
    np.random.seed(987654321)
    # Create mock inputs
    signal = np.random.normal(size=44100)
    listener_audiogram = Audiogram(
        levels=np.ones(9),
        frequencies=np.array([250, 500, 1000, 2000, 4000, 6000, 8000, 9000, 10000]),
    )
    # Create mock objects for enhancer and compressor
    enhancer = NALR(nfir=220, sample_rate=44100)
    compressor = Compressor(
        threshold=0.35, attenuation=0.1, attack=50, release=1000, rms_buffer_size=0.064
    )
    # Call the apply_nalr function and check that the output is as expected
    output = apply_baseline_ha(enhancer, compressor, signal, listener_audiogram)
    expected_results = np.load(
        RESOURCES / "test_enhance.test_apply_baseline_ha.npy",
        allow_pickle=True,
    )
    np.testing.assert_array_almost_equal(output, expected_results) 
[docs]
def test_process_stems_for_listener():
    """Takes 2 stems and applies the baseline processing using a listeners audiograms"""
    np.random.seed(12357)
    # Create mock inputs
    stems = {
        "l_source1": np.random.normal(size=16000),
        "r_source1": np.random.normal(size=16000),
    }
    audiogram = Audiogram(
        levels=np.ones(9),
        frequencies=np.array([250, 500, 1000, 2000, 4000, 6000, 8000, 9000, 10000]),
    )
    listener = Listener(audiogram_left=audiogram, audiogram_right=audiogram)
    # Create mock objects for enhancer and compressor
    enhancer = NALR(nfir=220, sample_rate=16000)
    compressor = Compressor(
        threshold=0.35, attenuation=0.1, attack=50, release=1000, rms_buffer_size=0.064
    )
    # Call the process_stems_for_listener function and check output is as expected
    output_stems = process_stems_for_listener(
        stems, enhancer, compressor, listener=listener
    )
    expected_results = np.load(
        RESOURCES / "test_enhance.test_process_stems_for_listener.npy",
        allow_pickle=True,
    )[()]
    for key, item in output_stems.items():
        np.testing.assert_array_almost_equal(item, expected_results[key]) 
[docs]
def test_separate_sources():
    """Test that the separate_sources function returns the expected output"""
    np.random.seed(123456789)
    # Create a dummy model
    class DummyModel(torch.nn.Module):  # pylint: disable=too-few-public-methods
        """Dummy source separation model"""
        def __init__(self, sources):
            """dummy init"""
            super().__init__()
            self.sources = sources
        def forward(self, x):
            """dummy forward"""
            return torch.Tensor(
                np.random.uniform(size=(x.shape[0], len(self.sources), *x.shape[1:]))
            )
    # Set up some dummy input data
    batch_size = 1
    num_channels = 1
    length = 1
    sample_rate = 16000
    sources = ["vocals", "drums", "bass", "other"]
    mix = np.random.randn(batch_size, num_channels, length * sample_rate)
    device = torch.device("cpu")
    # Create a dummy model
    model = DummyModel(sources)
    # Call separate_sources
    output = separate_sources(model, mix, sample_rate, device=device)
    expected_results = np.load(
        RESOURCES / "test_enhance.test_separate_sources.npy",
        allow_pickle=True,
    )
    # Check that the output has the correct shape
    assert output.shape == expected_results.shape
    np.testing.assert_array_almost_equal(output, expected_results) 
[docs]
def test_get_device():
    """Test the correct device selection given the inputs"""
    # Test default case (no argument passed)
    device, device_type = get_device(None)
    assert (
        device == torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    assert device_type == "cuda" if torch.cuda.is_available() else "cpu" 
[docs]
def test_to_16bit():
    # Generate a random signal
    signal = np.random.uniform(low=-1.0, high=1.0, size=50)
    signal_16bit = to_16bit(signal)
    assert np.all(np.abs(signal_16bit) <= 32768) 
[docs]
def test_clip_signal():
    # Generate a random signal
    np.random.seed(0)
    signal = np.random.uniform(low=-2.0, high=2.0, size=50)
    # Test with soft clipping
    clipped_signal, n_clipped = clip_signal(signal, soft_clip=True)
    assert max(np.abs(clipped_signal)) <= 1.0
    assert n_clipped == 0
    # Test without soft clipping
    clipped_signal, n_clipped = clip_signal(signal, soft_clip=False)
    assert max(np.abs(clipped_signal)) <= 1.0
    assert n_clipped == 22