Source code for recipes.cpc1.e032_sheffield.train_asr

#!/usr/bin/env python3
"""
Recipe for training a Transformer ASR system with LibriSpeech, adapted from the
SpeechBrain LibriSpeech/ASR recipe. The SpeechBrain version used in this work is:
https://github.com/speechbrain/speechbrain/tree/1eddf66eea01866d3cf9dfe61b00bb48d2062236
"""

import logging
import sys
from pathlib import Path

import speechbrain as sb
import torch
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main

# Module-level logger for this recipe.
logger = logging.getLogger(__name__)

# Module-level tokenizer handle. It is None here and stays None unless some
# code rebinds this global at run time.
# TODO(review): prefer passing the tokenizer explicitly (e.g. on the Brain
# instance) over relying on this global.
tokenizer = None


# Define training procedure
class ASR(sb.core.Brain):
    """Joint CTC/attention Transformer ASR trainer built on ``sb.core.Brain``.

    Training is split into two stages. The optimizer class passed to the
    constructor (``opt_class``; Adam in ``main``) is used while
    ``epoch_counter.current <= hparams.stage_one_epochs``. After that it is
    replaced by ``hparams.SGD``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accuracy / WER trackers. They are created in on_stage_start for
        # VALID and TEST only, so they are None during the first TRAIN stage.
        self.acc_metric = None
        self.wer_metric = None
        # Statistics of the last TRAIN stage, logged together with VALID stats.
        self.train_stats = None
        # True once the stage-two SGD optimizer is in use; None = not yet known.
        self.switched = None
        # Must be set before fit_batch runs (it raises otherwise); presumably
        # set when fitting starts from ``opt_class`` -- confirm in SpeechBrain.
        self.optimizer = None
        # Tokenizer used to turn decoded token ids into words. It is assigned
        # by the caller after construction (main does this).
        self.tokenizer = None

    def compute_forward(self, batch, stage):
        """Forward computations from waveform batches to output probabilities.

        Arguments
        ---------
        batch : PaddedBatch
            Batch providing ``sig`` (waveforms and relative lengths) and
            ``tokens_bos`` (decoder input tokens).
        stage : sb.Stage
            Current stage; controls augmentation and decoding.

        Returns
        -------
        tuple
            ``(p_ctc, p_seq, wav_lens, hyps)``. ``hyps`` is None unless a
            search was run: always in TEST, and in VALID only on epochs that
            are a multiple of ``hparams.valid_search_interval``.
        """
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        tokens_bos, _ = batch.tokens_bos
        current_epoch = self.hparams.epoch_counter.current

        # compute features
        feats = self.hparams.compute_features(wavs)
        feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch)
        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "augmentation"):
            feats = self.hparams.augmentation(feats)

        # forward modules
        src = self.hparams.CNN(feats)
        enc_out, pred = self.hparams.Transformer(
            src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index
        )

        # output layer for ctc log-probabilities
        logits = self.hparams.ctc_lin(enc_out)
        p_ctc = self.hparams.log_softmax(logits)

        # output layer for seq2seq log-probabilities
        pred = self.hparams.seq_lin(pred)
        p_seq = self.hparams.log_softmax(pred)

        # Decode only in the stages/epochs where WER is actually computed.
        hyps = None
        if stage == sb.Stage.VALID:
            if current_epoch % self.hparams.valid_search_interval == 0:
                # for the sake of efficiency, we only perform beamsearch with
                # limited capacity and no LM to give user some idea of how the
                # AM is doing
                hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
        elif stage == sb.Stage.TEST:
            hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)

        return p_ctc, p_seq, wav_lens, hyps

    def _get_tokenizer(self):
        """Return the tokenizer used to decode hypotheses into words.

        Prefers ``self.tokenizer`` and falls back to ``hparams.tokenizer``.

        Raises
        ------
        ValueError
            If neither is available.
        """
        tok = self.tokenizer
        if tok is None:
            tok = getattr(self.hparams, "tokenizer", None)
        if tok is None:
            raise ValueError("tokenizer is None")
        return tok

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets.

        For VALID/TEST stages this also updates the accuracy metric and, when
        hypotheses were decoded, the WER metric.

        Returns
        -------
        torch.Tensor
            ``ctc_weight * loss_ctc + (1 - ctc_weight) * loss_seq``.

        Raises
        ------
        ValueError
            If the metrics are missing in a VALID/TEST stage, or no tokenizer
            is available for decoding.
        """
        p_ctc, p_seq, wav_lens, hyps = predictions

        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        loss_seq = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
        loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
        loss = (
            self.hparams.ctc_weight * loss_ctc
            + (1 - self.hparams.ctc_weight) * loss_seq
        )

        if stage != sb.Stage.TRAIN:
            # Metrics are only created for VALID/TEST (on_stage_start). Checking
            # them for TRAIN as well would fail on the very first epoch.
            if self.wer_metric is None or self.acc_metric is None:
                raise ValueError("wer_metric or acc_metric is None")
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if current_epoch % valid_search_interval == 0 or stage == sb.Stage.TEST:
                # Decode token terms to words. Use the instance tokenizer: the
                # module-level ``tokenizer`` global is never assigned here.
                tok = self._get_tokenizer()
                predicted_words = [
                    tok.decode_ids(utt_seq).split(" ") for utt_seq in hyps
                ]
                target_words = [wrd.split(" ") for wrd in batch.wrd]
                self.wer_metric.append(ids, predicted_words, target_words)

            # compute the accuracy of the one-step-forward prediction
            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input.

        Gradients are accumulated over ``hparams.gradient_accumulation``
        steps before each optimizer update.

        Raises
        ------
        ValueError
            If no optimizer has been set up.
        """
        if self.optimizer is None:
            raise ValueError("optimizer is None")
        # check if we need to switch optimizer (Adam -> SGD for stage two)
        self.check_and_reset_optimizer()

        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)

        # normalize the loss by gradient_accumulation step
        (loss / self.hparams.gradient_accumulation).backward()

        if self.step % self.hparams.gradient_accumulation == 0:
            # check_gradients clips gradients and returns False when the loss
            # is not finite; skip the update instead of applying bad gradients.
            if self.check_gradients(loss):
                self.optimizer.step()
            self.optimizer.zero_grad()

            # anneal lr every update
            self.hparams.noam_annealing(self.optimizer)

        return loss.detach()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches."""
        with torch.no_grad():
            predictions = self.compute_forward(batch, stage=stage)
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch=None):
        """Create fresh accuracy and WER trackers for VALID/TEST stages."""
        if stage != sb.Stage.TRAIN:
            self.acc_metric = self.hparams.acc_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Summarize a finished stage, log it and manage checkpoints.

        TRAIN: stores the loss so it can be logged with the next VALID stage.
        VALID: on the main process, logs stats and keeps the best-ACC
        checkpoint. TEST: logs stats, writes the WER metric's stats to
        ``hparams.wer_file`` and saves the evaluated model as the only
        checkpoint.

        Raises
        ------
        ValueError
            If the metrics were not created for a VALID/TEST stage.
        """
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
            return

        # Metrics only exist for VALID/TEST, so they are validated here rather
        # than for every stage (the first TRAIN stage has none).
        if self.wer_metric is None or self.acc_metric is None:
            raise ValueError("wer_metric or acc_metric is None")

        current_epoch = self.hparams.epoch_counter.current
        stage_stats["ACC"] = self.acc_metric.summarize()
        if (
            current_epoch % self.hparams.valid_search_interval == 0
            or stage == sb.Stage.TEST
        ):
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # log stats and save checkpoint at end-of-epoch
        if stage == sb.Stage.VALID:
            if not sb.utils.distributed.if_main_process():
                return
            # report different epoch stages according to the current stage
            if current_epoch <= self.hparams.stage_one_epochs:
                lr = self.hparams.noam_annealing.current_lr
                steps = self.hparams.noam_annealing.n_steps
            else:
                lr = self.hparams.lr_sgd
                steps = -1
            epoch_stats = {
                "epoch": epoch,
                "lr": lr,
                "steps": steps,
                "optimizer": self.optimizer.__class__.__name__,
            }
            self.hparams.train_logger.log_stats(
                stats_meta=epoch_stats,
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.wer_file, "w", encoding="utf-8") as fp:
                self.wer_metric.write_stats(fp)

            # save the averaged checkpoint at the end of the evaluation stage
            # delete the rest of the intermediate checkpoints
            # ACC is set to 1.1 so checkpointer only keeps the averaged checkpoint
            self.checkpointer.save_and_keep_only(
                meta={"ACC": 1.1, "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )

    def check_and_reset_optimizer(self):
        """Reset the optimizer to ``hparams.SGD`` once training enters stage 2."""
        if self.switched:
            return
        if isinstance(self.optimizer, torch.optim.SGD):
            # Already on SGD (e.g. installed by on_fit_start after resuming).
            self.switched = True
            return
        if self.hparams.epoch_counter.current > self.hparams.stage_one_epochs:
            self.optimizer = self.hparams.SGD(self.modules.parameters())
            if self.checkpointer is not None:
                self.checkpointer.add_recoverable("optimizer", self.optimizer)
            self.switched = True

    def on_fit_start(self):
        """Initialize the right optimizer on the training start."""
        super().on_fit_start()

        # if the model is resumed from stage two, reinitialize the optimizer
        current_epoch = self.hparams.epoch_counter.current
        current_optimizer = self.optimizer
        if current_epoch <= self.hparams.stage_one_epochs:
            return

        del self.optimizer
        self.optimizer = self.hparams.SGD(self.modules.parameters())

        # Load latest checkpoint to resume training if interrupted
        if self.checkpointer is not None:
            # Track the live SGD optimizer (as check_and_reset_optimizer does),
            # so it -- not the discarded stage-one optimizer -- is restored
            # below and saved in later checkpoints.
            self.checkpointer.add_recoverable("optimizer", self.optimizer)
            # do not reload the weights if training is interrupted right
            # before stage 2 (the saved optimizer state has no "momentum")
            if current_optimizer is None:
                return
            group = current_optimizer.param_groups[0]
            if "momentum" not in group:
                return
            self.checkpointer.recover_if_possible(device=torch.device(self.device))

    def on_evaluate_start(self, max_key=None, min_key=None):
        """Average the selected checkpoints' model weights before evaluation.

        Checkpoints are selected with ``max_key`` / ``min_key``. If there is no
        checkpointer, or no checkpoint matches, the current weights are kept
        and a warning is logged.
        """
        super().on_evaluate_start()

        ckpts = []
        if self.checkpointer is not None:
            ckpts = self.checkpointer.find_checkpoints(max_key=max_key, min_key=min_key)
        if ckpts:
            ckpt = sb.utils.checkpoints.average_checkpoints(
                ckpts, recoverable_name="model", device=self.device
            )
            self.hparams.model.load_state_dict(ckpt, strict=True)
        else:
            logger.warning(
                "No checkpoints found for averaging; evaluating current weights."
            )
        self.hparams.model.eval()
def dataio_prepare(hparams):
    """Create the train/valid/test datasets and attach their data pipelines.

    Arguments
    ---------
    hparams : dict
        Loaded hyperparameters. Reads ``data_folder``, ``train_csv``,
        ``valid_csv``, ``test_csv`` (iterable of CSV paths), ``tokenizer``,
        ``bos_index`` and ``eos_index``.

    Returns
    -------
    tuple
        ``(train_data, valid_data, test_datasets, tokenizer)``.
        ``test_datasets`` maps each test CSV's file stem to its dataset.
        Validation and test sets are sorted by ``duration``.
    """
    data_root = hparams["data_folder"]

    def load_csv(path):
        # "data_root" placeholders inside the CSV are replaced by data_folder.
        return sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=path, replacements={"data_root": data_root}
        )

    train_data = load_csv(hparams["train_csv"])
    valid_data = load_csv(hparams["valid_csv"]).filtered_sorted(sort_key="duration")

    # Each test set stays separate, keyed by the stem of its CSV file.
    test_datasets = {
        Path(test_csv).stem: load_csv(test_csv).filtered_sorted(sort_key="duration")
        for test_csv in hparams["test_csv"]
    }

    all_datasets = [train_data, valid_data, *test_datasets.values()]

    # The tokenizer encodes the labels while mini-batches are created.
    # (A module-level ``tokenizer`` name also exists. TODO: fix the design.)
    label_tokenizer = hparams["tokenizer"]

    # Audio pipeline: wav path -> waveform signal.
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        return sb.dataio.dataio.read_audio(wav)

    sb.dataio.dataset.add_dynamic_item(all_datasets, audio_pipeline)

    # Text pipeline: transcript -> token ids with and without BOS/EOS markers.
    @sb.utils.data_pipeline.takes("wrd")
    @sb.utils.data_pipeline.provides(
        "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
    )
    def text_pipeline(wrd):
        yield wrd
        token_ids = label_tokenizer.encode_as_ids(wrd)
        yield token_ids
        yield torch.LongTensor([hparams["bos_index"]] + token_ids)
        yield torch.LongTensor(token_ids + [hparams["eos_index"]])
        yield torch.LongTensor(token_ids)

    sb.dataio.dataset.add_dynamic_item(all_datasets, text_pipeline)

    # Keys each batch exposes.
    sb.dataio.dataset.set_output_keys(
        all_datasets, ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"]
    )
    return train_data, valid_data, test_datasets, label_tokenizer
def main():
    """Train the Transformer ASR model, then evaluate it on every test set.

    The command line is parsed by ``sb.parse_arguments``. It expects the
    hyperparameter YAML file followed by optional overrides.
    """
    # ``tokenizer`` is the module-level name; it is bound below so that code
    # reading the global does not see the placeholder None.
    global tokenizer  # pylint: disable=global-statement

    # CLI:
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file, encoding="utf-8") as fp:
        hparams = load_hyperpyyaml(fp, overrides)

    # If distributed_launch=True then
    # create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # here we create the datasets objects as well as tokenization and encoding.
    # The returned tokenizer used to be discarded (``_tokenizer``), which left
    # the module-level global as None.
    train_data, valid_data, test_datasets, tokenizer = dataio_prepare(hparams)

    # We download the pretrained LM from HuggingFace (or elsewhere depending on
    # the path given in the YAML file). The tokenizer is loaded at the same time.
    run_on_main(hparams["pretrainer"].collect_files)
    hparams["pretrainer"].load_collected(device=run_opts["device"])

    # Trainer initialization
    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["Adam"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # adding objects to trainer:
    asr_brain.tokenizer = hparams["tokenizer"]

    # Training
    asr_brain.fit(
        asr_brain.hparams.epoch_counter,
        train_data,
        valid_data,
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )

    # Testing
    for dataset_key, test_dataset in test_datasets.items():
        # dataset_keys are the test CSV stems (e.g. test_clean, test_other)
        asr_brain.hparams.wer_file = (
            Path(hparams["output_folder"]) / f"wer_{dataset_key}.txt"
        )
        asr_brain.evaluate(
            test_dataset,
            max_key="ACC",
            test_loader_kwargs=hparams["test_dataloader_opts"],
        )
if __name__ == "__main__":
    # Run the recipe only when executed as a script, not when imported.
    main()