Source code for recipes.cad2.task2.ConvTasNet.train

import argparse
import json
import os
from pathlib import Path
from pprint import pprint

import pytorch_lightning as pl
import torch
import yaml
from asteroid.engine.system import System
from asteroid.utils import parse_args_as_dict, prepare_parser_from_dict
from local import (
    Compose,
    ConvTasNetStereo,
    RebalanceMusicDataset,
    augment_channelswap,
    augment_gain,
)
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

# Keys which are not in the conf.yml file can be added here.
# In the hierarchical dictionary created when parsing, the key `key` can be
# found at dic['main_args'][key]
# By default train.py will use all available GPUs. The `id` option in run.sh
# will limit the number of available GPUs for train.py.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--exp_dir", default="exp/tmp", help="Full path to save best validation model"
)
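
# For illustration only (not part of the recipe; the values below are made up):
# with a conf.yml such as
#
#   data:
#     sample_rate: 44100
#   training:
#     batch_size: 4
#
# prepare_parser_from_dict() builds one argparse group per top-level key, and
# parse_args_as_dict() then returns a hierarchical dictionary like
#
#   {"data": {"sample_rate": 44100},
#    "training": {"batch_size": 4},
#    "main_args": {"exp_dir": "exp/tmp"}}
#
# so keys added directly to `parser` above, such as --exp_dir, land under
# conf["main_args"], as noted in the comment above.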


def main(conf):
    source_augmentations = Compose([augment_gain, augment_channelswap])

    dataset_kwargs = {
        "root_path": Path(conf["data"]["root_path"]),
        "sample_rate": conf["data"]["sample_rate"],
        "target": conf["data"]["target"],
    }

    train_set = RebalanceMusicDataset(
        split="train",
        music_tracks_file=f"{conf['data']['music_tracks_file']}/music.train.json",
        samples_per_track=conf["data"]["samples_per_track"],
        segment_length=conf["data"]["segment_length"],
        random_segments=True,
        random_track_mix=True,
        source_augmentations=source_augmentations,
        **dataset_kwargs,
    )

    val_set = RebalanceMusicDataset(
        music_tracks_file=f"{conf['data']['music_tracks_file']}/music.valid.json",
        split="valid",
        segment_length=None,
        **dataset_kwargs,
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        pin_memory=True,
        drop_last=False,
    )

    model = ConvTasNetStereo(
        **conf["convtasnet"], samplerate=conf["data"]["sample_rate"]
    )

    optimizer = torch.optim.Adam(model.parameters(), conf["optim"]["lr"])

    # Define scheduler
    scheduler = None
    if conf["training"]["half_lr"]:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = torch.nn.L1Loss()
    # loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

    # Define callbacks
    callbacks = []
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(
        checkpoint_dir, monitor="val_loss", mode="min", save_top_k=1, verbose=True
    )
    callbacks.append(checkpoint)
    if conf["training"]["early_stop"]:
        callbacks.append(
            EarlyStopping(monitor="val_loss", mode="min", patience=20, verbose=True)
        )

    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        strategy="auto",
        devices="auto",
        gradient_clip_val=5.0,
        accumulate_grad_batches=conf["training"]["aggregate"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    state_dict = torch.load(checkpoint.best_model_path)
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))

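# Reloading sketch (an assumption about the serialized format, not code from
# this recipe): serialize() on asteroid-style models packs the weights together
# with the arguments needed to rebuild the model, so the saved "best_model.pth"
# can typically be restored along these lines:
#
#   pkg = torch.load(os.path.join("exp/tmp", "best_model.pth"), map_location="cpu")
#   model = ConvTasNetStereo.from_pretrained(pkg)  # hypothetical reload helper
#
# Check the local ConvTasNetStereo implementation for the exact reload API.
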
if __name__ == "__main__":
    # We start by opening the config file conf.yml as a dictionary, from
    # which we can create parsers. Each top-level key in the dictionary defined
    # by the YAML file creates a group in the parser.
    with open("local/conf.yml") as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)

    # Arguments are then parsed into a hierarchical dictionary (instead of a
    # flat one, as returned by argparse) to facilitate calls to the different
    # asteroid methods (see main).
    # plain_args is the direct output of parser.parse_args() and contains all
    # the attributes in a non-hierarchical structure. It can be useful to have
    # as well, so it is returned here, but it is not used.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    pprint(arg_dic)
    main(arg_dic)
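
# Example invocation (a sketch; the exact flags depend on the keys present in
# local/conf.yml, since prepare_parser_from_dict exposes each leaf key of the
# YAML file as a command-line option):
#
#   python train.py --exp_dir exp/convtasnet
#
# A leaf key from conf.yml, such as training.batch_size, would then typically
# be overridable as:
#
#   python train.py --exp_dir exp/convtasnet --batch_size 8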