#!/usr/bin/env python3
"""
Recipe for training a Transformer ASR system on LibriSpeech, from the SpeechBrain
LibriSpeech/ASR recipe. The SpeechBrain version used in this work is:
https://github.com/speechbrain/speechbrain/tree/1eddf66eea01866d3cf9dfe61b00bb48d2062236
"""
import logging
import sys
from pathlib import Path
import speechbrain as sb
import torch
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main
logger = logging.getLogger(__name__)
# Define training procedure
class ASR(sb.core.Brain):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.acc_metric = None
self.wer_metric = None
self.train_stats = None
self.switched = None
self.optimizer = None
self.tokenizer = None
def compute_forward(self, batch, stage):
"""Forward computations from waveform batches to output probabilities."""
batch = batch.to(self.device)
wavs, wav_lens = batch.sig
tokens_bos, _ = batch.tokens_bos
# compute features
feats = self.hparams.compute_features(wavs)
current_epoch = self.hparams.epoch_counter.current
feats = self.hparams.normalize(feats, wav_lens, epoch=current_epoch)
if stage == sb.Stage.TRAIN:
if hasattr(self.hparams, "augmentation"):
feats = self.hparams.augmentation(feats)
# forward modules
src = self.hparams.CNN(feats)
enc_out, pred = self.hparams.Transformer(
src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index
)
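        # enc_out is the encoder output, used by the CTC branch and by beam
        # search; pred is the decoder output, used by the attention branch.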
# output layer for ctc log-probabilities
logits = self.hparams.ctc_lin(enc_out)
p_ctc = self.hparams.log_softmax(logits)
# output layer for seq2seq log-probabilities
pred = self.hparams.seq_lin(pred)
p_seq = self.hparams.log_softmax(pred)
# Compute outputs
hyps = None
if stage == sb.Stage.TRAIN:
hyps = None
elif stage == sb.Stage.VALID:
hyps = None
current_epoch = self.hparams.epoch_counter.current
if current_epoch % self.hparams.valid_search_interval == 0:
                # For the sake of efficiency, we only perform beam search with
                # limited capacity and no LM, to give the user some idea of how
                # the AM is doing.
hyps, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
elif stage == sb.Stage.TEST:
hyps, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
return p_ctc, p_seq, wav_lens, hyps
def compute_objectives(self, predictions, batch, stage):
"""Computes the loss (CTC+NLL) given predictions and targets."""
if self.wer_metric is None or self.acc_metric is None:
raise ValueError("wer_metric or acc_metric is None")
(p_ctc, p_seq, wav_lens, hyps) = predictions
ids = batch.id
tokens_eos, tokens_eos_lens = batch.tokens_eos
tokens, tokens_lens = batch.tokens
loss_seq = self.hparams.seq_cost(p_seq, tokens_eos, length=tokens_eos_lens)
loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
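        # Joint CTC/attention training: the final loss is a weighted sum,
        # loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_seq.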
loss = (
self.hparams.ctc_weight * loss_ctc
+ (1 - self.hparams.ctc_weight) * loss_seq
)
if stage != sb.Stage.TRAIN:
current_epoch = self.hparams.epoch_counter.current
valid_search_interval = self.hparams.valid_search_interval
if current_epoch % valid_search_interval == 0 or (stage == sb.Stage.TEST):
# Decode token terms to words
predicted_words = [
                    self.tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps
]
target_words = [wrd.split(" ") for wrd in batch.wrd]
self.wer_metric.append(ids, predicted_words, target_words)
# compute the accuracy of the one-step-forward prediction
self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
return loss
def fit_batch(self, batch):
"""Train the parameters given a single batch in input"""
# check if we need to switch optimizer
# if so change the optimizer from Adam to SGD
if self.optimizer is None:
raise ValueError("optimizer is None")
self.check_and_reset_optimizer()
predictions = self.compute_forward(batch, sb.Stage.TRAIN)
loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
# normalize the loss by gradient_accumulation step
(loss / self.hparams.gradient_accumulation).backward()
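        # Weights are only updated every `gradient_accumulation` steps, so the
        # effective batch size is batch_size * gradient_accumulation.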
if self.step % self.hparams.gradient_accumulation == 0:
# gradient clipping & early stop if loss is not fini
self.check_gradients(loss)
self.optimizer.step()
self.optimizer.zero_grad()
# anneal lr every update
self.hparams.noam_annealing(self.optimizer)
return loss.detach()
def evaluate_batch(self, batch, stage):
"""Computations needed for validation/test batches"""
with torch.no_grad():
predictions = self.compute_forward(batch, stage=stage)
loss = self.compute_objectives(predictions, batch, stage=stage)
return loss.detach()
def on_stage_start(self, stage, epoch=None):
"""Gets called at the beginning of each epoch"""
if stage != sb.Stage.TRAIN:
self.acc_metric = self.hparams.acc_computer()
self.wer_metric = self.hparams.error_rate_computer()
def on_stage_end(self, stage, stage_loss, epoch=None):
"""Gets called at the end of a epoch."""
if self.wer_metric is None or self.acc_metric is None:
raise ValueError("wer_metric or acc_metric is None")
# Compute/store important stats
stage_stats = {"loss": stage_loss}
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats
else:
stage_stats["ACC"] = self.acc_metric.summarize()
current_epoch = self.hparams.epoch_counter.current
valid_search_interval = self.hparams.valid_search_interval
if current_epoch % valid_search_interval == 0 or stage == sb.Stage.TEST:
stage_stats["WER"] = self.wer_metric.summarize("error_rate")
# log stats and save checkpoint at end-of-epoch
if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
            # report different epoch stats according to the current stage
current_epoch = self.hparams.epoch_counter.current
if current_epoch <= self.hparams.stage_one_epochs:
lr = self.hparams.noam_annealing.current_lr
steps = self.hparams.noam_annealing.n_steps
optimizer = self.optimizer.__class__.__name__
else:
lr = self.hparams.lr_sgd
steps = -1
optimizer = self.optimizer.__class__.__name__
epoch_stats = {
"epoch": epoch,
"lr": lr,
"steps": steps,
"optimizer": optimizer,
}
self.hparams.train_logger.log_stats(
stats_meta=epoch_stats,
train_stats=self.train_stats,
valid_stats=stage_stats,
)
self.checkpointer.save_and_keep_only(
meta={"ACC": stage_stats["ACC"], "epoch": epoch},
max_keys=["ACC"],
num_to_keep=1,
)
elif stage == sb.Stage.TEST:
self.hparams.train_logger.log_stats(
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
test_stats=stage_stats,
)
with open(self.hparams.wer_file, "w", encoding="utf-8") as fp:
self.wer_metric.write_stats(fp)
# save the averaged checkpoint at the end of the evaluation stage
# delete the rest of the intermediate checkpoints
# ACC is set to 1.1 so checkpointer only keeps the averaged checkpoint
self.checkpointer.save_and_keep_only(
meta={"ACC": 1.1, "epoch": epoch},
max_keys=["ACC"],
num_to_keep=1,
)
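
    # Two-stage optimization: Adam with Noam annealing is used for the first
    # `stage_one_epochs` epochs, after which training switches to SGD.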
def check_and_reset_optimizer(self):
"""reset the optimizer if training enters stage 2"""
current_epoch = self.hparams.epoch_counter.current
if not hasattr(self, "switched"):
self.switched = False
if isinstance(self.optimizer, torch.optim.SGD):
self.switched = True
if self.switched is True:
return
if current_epoch > self.hparams.stage_one_epochs:
self.optimizer = self.hparams.SGD(self.modules.parameters())
if self.checkpointer is not None:
self.checkpointer.add_recoverable("optimizer", self.optimizer)
self.switched = True
def on_fit_start(self):
"""Initialize the right optimizer on the training start"""
super().on_fit_start()
# if the model is resumed from stage two, reinitialize the optimizer
current_epoch = self.hparams.epoch_counter.current
current_optimizer = self.optimizer
if current_epoch > self.hparams.stage_one_epochs:
del self.optimizer
self.optimizer = self.hparams.SGD(self.modules.parameters())
# Load latest checkpoint to resume training if interrupted
if self.checkpointer is not None:
# do not reload the weights if training is interrupted right before
# stage 2
group = current_optimizer.param_groups[0]
if "momentum" not in group:
return
self.checkpointer.recover_if_possible(device=torch.device(self.device))
def on_evaluate_start(self, max_key=None, min_key=None):
"""perform checkpoint averge if needed"""
super().on_evaluate_start()
ckpts = self.checkpointer.find_checkpoints(max_key=max_key, min_key=min_key)
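        # Average the model weights of the retrieved checkpoints and decode
        # with the averaged model.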
ckpt = sb.utils.checkpoints.average_checkpoints(
ckpts, recoverable_name="model", device=self.device
)
self.hparams.model.load_state_dict(ckpt, strict=True)
self.hparams.model.eval()
def dataio_prepare(hparams):
"""This function prepares the datasets to be used in the brain class.
It also defines the data processing pipeline through user-defined functions."""
data_folder = hparams["data_folder"]
train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=hparams["train_csv"], replacements={"data_root": data_folder}
)
valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=hparams["valid_csv"], replacements={"data_root": data_folder}
)
valid_data = valid_data.filtered_sorted(sort_key="duration")
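    # Sorting by duration groups utterances of similar length, which reduces
    # padding and speeds up validation decoding.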
# test is separate
test_datasets = {}
for csv_file in hparams["test_csv"]:
name = Path(csv_file).stem
test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=csv_file, replacements={"data_root": data_folder}
)
test_datasets[name] = test_datasets[name].filtered_sorted(sort_key="duration")
    datasets = [train_data, valid_data] + list(test_datasets.values())
    # We get the tokenizer as we need it to encode the labels when creating
    # mini-batches.
    tokenizer = hparams["tokenizer"]
# 2. Define audio pipeline:
@sb.utils.data_pipeline.takes("wav")
@sb.utils.data_pipeline.provides("sig")
def audio_pipeline(wav):
sig = sb.dataio.dataio.read_audio(wav)
return sig
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
# 3. Define text pipeline:
@sb.utils.data_pipeline.takes("wrd")
@sb.utils.data_pipeline.provides(
"wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
)
def text_pipeline(wrd):
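        # Each yield provides one of the keys declared in @provides, in order:
        # wrd, tokens_list, tokens_bos, tokens_eos, tokens.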
yield wrd
tokens_list = tokenizer.encode_as_ids(wrd)
yield tokens_list
tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
yield tokens_bos
tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
yield tokens_eos
tokens = torch.LongTensor(tokens_list)
yield tokens
sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
# 4. Set output:
sb.dataio.dataset.set_output_keys(
datasets, ["id", "sig", "wrd", "tokens_bos", "tokens_eos", "tokens"]
)
return train_data, valid_data, test_datasets, tokenizer
def main():
# CLI:
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fp:
hparams = load_hyperpyyaml(fp, overrides)
# If distributed_launch=True then
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)
# Create experiment directory
sb.create_experiment_directory(
experiment_directory=hparams["output_folder"],
hyperparams_to_save=hparams_file,
overrides=overrides,
)
# here we create the datasets objects as well as tokenization and encoding
train_data, valid_data, test_datasets, _tokenizer = dataio_prepare(hparams)
# We download the pretrained LM from HuggingFace (or elsewhere depending on
# the path given in the YAML file). The tokenizer is loaded at the same time.
run_on_main(hparams["pretrainer"].collect_files)
hparams["pretrainer"].load_collected(device=run_opts["device"])
# Trainer initialization
asr_brain = ASR(
modules=hparams["modules"],
opt_class=hparams["Adam"],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
# adding objects to trainer:
asr_brain.tokenizer = hparams["tokenizer"]
# Training
asr_brain.fit(
asr_brain.hparams.epoch_counter,
train_data,
valid_data,
train_loader_kwargs=hparams["train_dataloader_opts"],
valid_loader_kwargs=hparams["valid_dataloader_opts"],
)
# Testing
for dataset_key, test_dataset in test_datasets.items():
        # dataset keys are test_clean, test_other, etc.
asr_brain.hparams.wer_file = (
Path(hparams["output_folder"]) / f"wer_{dataset_key}.txt"
)
asr_brain.evaluate(
test_dataset,
max_key="ACC",
test_loader_kwargs=hparams["test_dataloader_opts"],
)
if __name__ == "__main__":
main()