"""Tests for the CPC2 predict functions."""
import warnings
from csv import DictReader
from pathlib import Path
import hydra
import numpy as np
import pandas as pd
import pytest
from clarity.utils.file_io import write_jsonl
from recipes.cpc2.baseline.predict import (
    LogisticModel,
    make_disjoint_train_set,
    predict,
)

# pylint: disable=redefined-outer-name


@pytest.fixture
def model():
"""Return a LogisticModel instance."""
model = LogisticModel()
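    # Fit to points symmetric about (2, 50): inputs 0..4 map onto scores 0..100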
    model.fit(np.array([0, 1, 2, 3, 4]), np.array([0, 25, 50, 75, 100]))
    return model


@pytest.mark.parametrize(
    "model, value", [(model, 0.0), (model, 1.0), (model, 2.0)], indirect=["model"]
)
def test_logistic_model_symmetry(model: LogisticModel, value):
"""Test the LogisticModel is symmetric."""
    symmetric_value = 4 - value
    assert model.predict(value) + model.predict(symmetric_value) == pytest.approx(
        100.0, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
    )


@pytest.mark.parametrize(
    "model, value",
    [(model, -100.0), (model, -200.0), (model, 100.0), (model, 200.0)],
    indirect=["model"],
)
def test_logistic_model_extremes(model, value):
"""Test the LogisticModel class ."""
    # logistic_model must asymptote to 0 and 100 for extreme values
    if value > 10:
        assert model.predict(value) == pytest.approx(
            100, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
        )
    elif value < -10:
        assert model.predict(value) == pytest.approx(
            0, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
        )


@pytest.mark.parametrize(
    "data_1, data_2, expected",
    [
        (
            {"signal": "S100", "system": "E100", "listener": "L100"},
            {"signal": "S100", "system": "E100", "listener": "L100"},
            0,
        ),
        (
            {"signal": "S100", "system": "E100", "listener": "L100"},
            {"signal": "S100", "system": "E101", "listener": "L100"},
            0,
        ),
        (
            {"signal": "S100", "system": "E100", "listener": "L100"},
            {"signal": "S101", "system": "E101", "listener": "L101"},
            1,
        ),
    ],
)
def test_make_disjoint_train_set_empty(data_1, data_2, expected):
"""Test the make_disjoint_train_set function."""
test_df1 = pd.DataFrame([data_1])
test_df2 = pd.DataFrame([data_2])
disjoint = make_disjoint_train_set(test_df1, test_df2)
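    # Only rows with no overlap in signal, system or listener should be kept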
    assert disjoint.shape[0] == pytest.approx(
        expected, rel=pytest.rel_tolerance, abs=pytest.abs_tolerance
    )


@pytest.fixture()
def hydra_cfg():
"""Fixture for hydra config."""
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    hydra.initialize(
        config_path="../../../../recipes/cpc2/baseline",
        job_name="test_cpc2",
    )
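    # Compose the baseline config, overriding paths to point at the sample test data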
    cfg = hydra.compose(
        config_name="config",
        overrides=[
            "path.clarity_data_dir=tests/test_data/recipes/cpc2",
            "dataset=CEC1.train.sample",
        ],
    )
    return cfg


def test_predict(hydra_cfg):
"""Test predict function."""
    expected_results = [
        ("S08547_L0001_E001", 0.0),
        ("S08564_L0001_E001", 0.0),
        ("S08564_L0002_E002", 31.481621447245452),
        ("S08564_L0003_E003", 31.481621447245452),
    ]
    haspi_scores = [
        {"signal": "S08547_L0001_E001", "haspi": 0.8},
        {"signal": "S08564_L0001_E001", "haspi": 0.8},
        {"signal": "S08564_L0002_E002", "haspi": 0.8},
        {"signal": "S08564_L0003_E003", "haspi": 0.8},
    ]
    haspi_score_file = "CEC1.train.sample.haspi.jsonl"
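    # Write the HASPI scores to the file that predict() will read as input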
    write_jsonl(haspi_score_file, haspi_scores)
    # Run predict, ignoring warnings caused by the unrealistic test data
    warnings.simplefilter("ignore", category=RuntimeWarning)
    predict(hydra_cfg)
    # Check output
    expected_output_file = "CEC1.train.sample.predict.csv"
    with open(expected_output_file, encoding="utf-8") as f:
        results = list(DictReader(f))
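    # Index predicted scores by signal ID for lookup in the assertions below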
    results_index = {
        entry["signal_ID"]: float(entry["intelligibility_score"]) for entry in results
    }
    # TODO: Scores are not checked because they can be very different
    # depending on the machine. This doesn't really matter for now as I believe
    # it's just a consequence of using just 4 samples in the testing data.
    # The fitting functions are tested separately.
    for signal, _expected_score in expected_results:
        assert signal in results_index
        # print(results_index[signal], _expected_score)
        # assert results_index[signal] == pytest.approx(_expected_score)

    # Clean up
    Path(expected_output_file).unlink()
    Path(haspi_score_file).unlink()