Source code for recipes.cec1.e009_sheffield.train

import json
import logging
from pathlib import Path

import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from clarity.dataset.cec1_dataset import CEC1Dataset
from clarity.engine.losses import SNRLoss, STOILevelLoss
from clarity.engine.system import System
from clarity.enhancer.dnn.mc_conv_tasnet import ConvTasNet
from clarity.enhancer.dsp.filter import AudiometricFIR
from clarity.predictor.torch_msbg import MSBGHearingModel

logger = logging.getLogger(__name__)


class DenModule(System):
    """Lightning system that trains the denoising (enhancement) module."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ear_idx = None
        self.down_sample = None

    def common_step(self, batch, batch_nb, train=True):
        proc, ref = batch
        # Keep only the reference channel for the ear being trained
        ref = ref[:, self.ear_idx, :]
        if self.config.downsample_factor != 1:
            if self.down_sample is None:
                raise RuntimeError("Downsample transform not loaded")
            proc = self.down_sample(proc)
            ref = self.down_sample(ref)
        enhanced = self.model(proc).squeeze(1)
        loss = self.loss_func(enhanced, ref)
        return loss
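
# DenModule optimises the multichannel Conv-TasNet against a plain SNR
# objective on the target-ear reference. A minimal sanity check of the loss
# call (hypothetical shapes: a batch of 4 one-second signals at a 16 kHz
# downsampled rate; the loss is assumed to be a scalar to be minimised):
#
#     enhanced = torch.randn(4, 16000)
#     ref = torch.randn(4, 16000)
#     loss = SNRLoss()(enhanced, ref)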
class AmpModule(System):
    """Lightning system that trains the amplification module on top of a
    frozen denoising module, with simulated hearing as the loss domain."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hl_ear = None
        self.nh_ear = None
        self.down_sample = None
        self.up_sample = None
        self.ear_idx = None
        self.den_model = None

    def common_step(self, batch, batch_nb, train=True):
        if self.hl_ear is None or self.nh_ear is None or self.den_model is None:
            raise RuntimeError("Hearing model not loaded")
        proc, ref = batch
        ref = ref[:, self.ear_idx, :]
        if self.config.downsample_factor != 1:
            if self.down_sample is None or self.up_sample is None:
                raise RuntimeError("Resampling transforms not loaded")
            proc = self.down_sample(proc)
            ref = self.down_sample(ref)
        # Denoise, then amplify; the denoiser stays fixed during this stage
        enhanced = self.model(self.den_model(proc)).squeeze(1)
        if self.config.downsample_factor != 1:
            enhanced = torch.clamp(self.up_sample(enhanced), -1, 1)
            ref = torch.clamp(self.up_sample(ref), -1, 1)
        # Normal-hearing ear "hears" the reference; the hearing-loss ear
        # "hears" the enhanced signal
        sim_ref = self.nh_ear(ref)
        sim_enhanced = self.hl_ear(enhanced)
        loss = self.loss_func(sim_enhanced, sim_ref)
        return loss
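
# The two MSBG ear models wired into AmpModule implement the perceptual
# target of this recipe: the reference is heard through a normal-hearing
# ear (all-zero audiogram) while the enhanced output is heard through the
# listener's hearing-loss ear, and STOILevelLoss compares the two simulated
# signals. A minimal sketch of that pairing (hypothetical hearing-loss
# values; the audiometric frequencies follow the CEC1 listener metadata):
#
#     cfs = np.array([250, 500, 1000, 2000, 3000, 4000, 6000, 8000])
#     hl = np.array([20, 20, 30, 40, 50, 55, 60, 65])
#     nh_ear = MSBGHearingModel(audiogram=np.zeros_like(hl), audiometric=cfs, sr=44100)
#     hl_ear = MSBGHearingModel(audiogram=hl, audiometric=cfs, sr=44100)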
def train_den(cfg, ear):
    """Train the denoising module for one ear ("left" or "right")."""
    exp_dir = Path(cfg.path.exp_folder) / f"{ear}_den"
    exp_dir.mkdir(parents=True, exist_ok=True)
    if (exp_dir / "best_model.pth").exists():
        logger.info("Enhancement module already exists")
        return

    train_set = CEC1Dataset(**cfg.train_dataset)
    train_loader = DataLoader(dataset=train_set, **cfg.train_loader)
    dev_set = CEC1Dataset(**cfg.dev_dataset)
    dev_loader = DataLoader(dataset=dev_set, **cfg.dev_loader)

    den_model = ConvTasNet(**cfg.mc_conv_tasnet)
    optimizer = torch.optim.Adam(
        params=den_model.parameters(), **cfg.den_trainer.optimizer
    )
    loss_func = SNRLoss()

    den_module = DenModule(
        model=den_model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=dev_loader,
        config=cfg,
    )
    den_module.ear_idx = 0 if ear == "left" else 1
    if cfg.downsample_factor != 1:
        den_module.down_sample = torchaudio.transforms.Resample(
            orig_freq=cfg.sample_rate,
            new_freq=cfg.sample_rate // cfg.downsample_factor,
            resampling_method="sinc_interp_hann",
        )

    # Callbacks: keep the five best checkpoints by validation loss
    callbacks = []
    checkpoint_dir = exp_dir / "checkpoints"
    checkpoint = ModelCheckpoint(
        str(checkpoint_dir), monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    callbacks.append(checkpoint)

    # Set device
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=cfg.den_trainer.epochs,
        callbacks=callbacks,
        default_root_dir=str(exp_dir),
        gpus=gpus,
        limit_train_batches=1.0,  # Reduce below 1.0 for quick experiments
        gradient_clip_val=cfg.den_trainer.gradient_clip_val,
    )
    trainer.fit(den_module)

    # Record the best checkpoints and export the best model weights
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with (exp_dir / "best_k_models.json").open("w", encoding="utf-8") as fp:
        json.dump(best_k, fp, indent=0)
    state_dict = torch.load(checkpoint.best_model_path)
    den_module.load_state_dict(state_dict=state_dict["state_dict"])
    den_module.cpu()
    torch.save(den_module.model.state_dict(), str(exp_dir / "best_model.pth"))
def train_amp(cfg, ear):
    """Train the amplification module for one ear, using the trained
    denoising module and the listener's MSBG hearing model."""
    exp_dir = Path(cfg.path.exp_folder) / f"{ear}_amp"
    exp_dir.mkdir(parents=True, exist_ok=True)
    if (exp_dir / "best_model.pth").exists():
        logger.info("Amplification module already exists")
        return

    train_set = CEC1Dataset(**cfg.train_dataset)
    train_loader = DataLoader(dataset=train_set, **cfg.train_loader)
    dev_set = CEC1Dataset(**cfg.dev_dataset)
    dev_loader = DataLoader(dataset=dev_set, **cfg.dev_loader)

    # Load the trained denoising module
    den_model = ConvTasNet(**cfg.mc_conv_tasnet)
    den_model_path = exp_dir.parent / f"{ear}_den" / "best_model.pth"
    den_model.load_state_dict(torch.load(den_model_path))

    # Amplification module; only its parameters are passed to the optimizer,
    # so the denoiser stays frozen during this stage
    amp_model = AudiometricFIR(**cfg.fir)
    optimizer = torch.optim.Adam(
        params=amp_model.parameters(), **cfg.amp_trainer.optimizer
    )
    loss_func = STOILevelLoss(**cfg.amp_trainer.stoilevel_loss)

    amp_module = AmpModule(
        model=amp_model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=dev_loader,
        config=cfg,
    )
    amp_module.ear_idx = 0 if ear == "left" else 1
    amp_module.den_model = den_model
    if cfg.downsample_factor != 1:
        amp_module.down_sample = torchaudio.transforms.Resample(
            orig_freq=cfg.sample_rate,
            new_freq=cfg.sample_rate // cfg.downsample_factor,
            resampling_method="sinc_interp_hann",
        )
        amp_module.up_sample = torchaudio.transforms.Resample(
            orig_freq=cfg.sample_rate // cfg.downsample_factor,
            new_freq=cfg.sample_rate,
            resampling_method="sinc_interp_hann",
        )

    # Build normal-hearing and hearing-loss ears from the listener metadata
    with open(cfg.listener.metafile, encoding="utf-8") as fp:
        listeners_file = json.load(fp)
    audiogram_cfs = listeners_file[cfg.listener.id]["audiogram_cfs"]
    audiogram_lvl_l = listeners_file[cfg.listener.id]["audiogram_levels_l"]
    audiogram_lvl_r = listeners_file[cfg.listener.id]["audiogram_levels_r"]
    audiogram = audiogram_lvl_l if ear == "left" else audiogram_lvl_r
    amp_module.nh_ear = MSBGHearingModel(
        audiogram=np.zeros_like(audiogram),
        audiometric=audiogram_cfs,
        sr=cfg.sample_rate,
    )
    amp_module.hl_ear = MSBGHearingModel(
        audiogram=audiogram, audiometric=audiogram_cfs, sr=cfg.sample_rate
    )

    # Callbacks: keep the five best checkpoints by validation loss
    callbacks = []
    checkpoint_dir = exp_dir / "checkpoints"
    checkpoint = ModelCheckpoint(
        str(checkpoint_dir), monitor="val_loss", mode="min", save_top_k=5, verbose=True
    )
    callbacks.append(checkpoint)

    # Set device
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=cfg.amp_trainer.epochs,
        callbacks=callbacks,
        default_root_dir=str(exp_dir),
        gpus=gpus,
        limit_train_batches=1.0,  # Reduce below 1.0 for quick experiments
        gradient_clip_val=cfg.amp_trainer.gradient_clip_val,
        num_sanity_val_steps=cfg.amp_trainer.num_sanity_val_steps,
    )
    trainer.fit(amp_module)

    # Record the best checkpoints and export the best model weights
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with (exp_dir / "best_k_models.json").open("w", encoding="utf-8") as fp:
        json.dump(best_k, fp, indent=0)
    state_dict = torch.load(checkpoint.best_model_path)
    amp_module.load_state_dict(state_dict=state_dict["state_dict"])
    amp_module.cpu()
    torch.save(amp_module.model.state_dict(), str(exp_dir / "best_model.pth"))
@hydra.main(config_path=".", config_name="config")
def run(cfg: DictConfig) -> None:
    logger.info("Begin training left ear enhancement module.")
    train_den(cfg, ear="left")
    logger.info("Begin training right ear enhancement module.")
    train_den(cfg, ear="right")
    logger.info("Begin training left ear amplification module.")
    train_amp(cfg, ear="left")
    logger.info("Begin training right ear amplification module.")
    train_amp(cfg, ear="right")
# pylint: disable=no-value-for-parameter
if __name__ == "__main__":
    run()
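
# Example invocation (hypothetical paths and listener ID; any key in the
# Hydra config can be overridden on the command line):
#
#     python train.py path.exp_folder=./exp listener.id=L0001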