Source code for tests.recipes.cec1.e009_sheffield.test_train

"""Tests for cec1 e009 train module"""

import logging
from pathlib import Path

import hydra
import numpy as np
import pytest
import torch

from recipes.cec1.e009_sheffield.train import train_amp, train_den


@pytest.mark.slow
def test_run(tmp_path):
    """Smoke-test the denoising and amplification training stages.

    Runs ``train_den`` and then ``train_amp`` for the left ear using a heavily
    reduced configuration: one epoch per trainer, a small model,
    single-process data loading, batch size 1 and one-second training samples.
    It then checks that each stage wrote its checkpoint, ``best_k_models.json``
    and ``best_model.pth`` into the experiment folder.

    Args:
        tmp_path: pytest-provided temporary directory, used as the
            experiment output folder (``path.exp_folder``).
    """
    np.random.seed(0)
    torch.manual_seed(0)
    # Silence Lightning's progress/info output during the test.
    logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

    # Another test may have left Hydra initialised; start from a clean state.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    # Use `initialize` as a context manager so the global Hydra state is
    # restored when this test finishes instead of leaking into later tests.
    with hydra.initialize(
        config_path="../../../../recipes/cec1/e009_sheffield",
        job_name="test_cec1_e009",
    ):
        hydra_cfg = hydra.compose(
            config_name="config",
            # Override settings to make a fast training test
            overrides=[
                "path.cec1_root=tests/test_data/recipes/cec1/e009_sheffield",
                f"path.exp_folder={tmp_path}",
                # Disable multiprocessing for testing (faster)
                "train_loader.num_workers=0",
                "dev_loader.num_workers=0",
                "test_loader.num_workers=0",
                "train_loader.batch_size=1",
                "train_dataset.wav_sample_len=1.0",
                "den_trainer.epochs=1",
                "amp_trainer.epochs=1",
                "fir.nfir=32",
                "mc_conv_tasnet.H=64",
                "mc_conv_tasnet.B=32",
                # The validation sanity check step is slow, so disable it
                "amp_trainer.num_sanity_val_steps=0",
            ],
        )

        train_den(hydra_cfg, ear="left")
        # Set only after train_den has run. Presumably this is meant to
        # affect train_amp alone (TODO confirm), so it is not passed as a
        # compose override.
        hydra_cfg.downsample_factor = 40
        train_amp(hydra_cfg, ear="left")

    output_dir = Path(tmp_path)
    expected_files = [
        "left_amp/checkpoints/epoch=0-step=1.ckpt",
        "left_amp/best_k_models.json",
        "left_amp/best_model.pth",
        "left_den/checkpoints/epoch=0-step=1.ckpt",
        "left_den/best_k_models.json",
        "left_den/best_model.pth",
    ]
    # Report every missing artefact at once rather than stopping at the first.
    missing = [name for name in expected_files if not (output_dir / name).exists()]
    assert not missing, f"Missing training outputs in {output_dir}: {missing}"