Source code for recipes.cad2.task2.ConvTasNet.local.cad2task2_dataloader

from __future__ import annotations

import json
from pathlib import Path

import librosa
import numpy as np
import soundfile as sf
import torch
from torch.utils import data


class Compose:
    """Composes several augmentation transforms.

    Args:
        transforms: list of augmentation transforms to compose.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, audio):
        for transform in self.transforms:
            audio = transform(audio)
        return audio


def augment_gain(audio, low=0.25, high=1.25):
    """Applies a random gain to each source between `low` and `high`."""
    if torch.FloatTensor(1).uniform_() < 0.5:
        gain_1 = low + torch.rand(1) * (high - low)
        gain_2 = low + torch.rand(1) * (high - low)
        return torch.stack([audio[0] * gain_1, audio[1] * gain_2])
    return audio


def augment_channelswap(audio):
    """Randomly swap channels of stereo sources."""
    if audio.shape[0] == 2 and torch.FloatTensor(1).uniform_() < 0.5:
        return torch.flip(audio, [0])
    return audio


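# Illustrative sketch (not part of the original module): the two augmentations above
# operate on a (channels, samples) tensor and can be chained with `Compose` before
# being passed to the dataset as `source_augmentations`, e.g.
#
#     source_augmentations = Compose([augment_gain, augment_channelswap])
#     augmented = source_augmentations(torch.rand(2, 44100))  # 1 s of stereo audio

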
def get_audio_durations(track_path: Path | str) -> float:
    """Return the duration of an audio file in seconds."""
    if isinstance(track_path, str):
        track_path = Path(track_path)
    return librosa.get_duration(path=track_path.as_posix())


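# Usage sketch (hypothetical file name, not part of the original module):
#
#     duration = get_audio_durations("audio/quartet_01_violin.wav")  # e.g. 12.5 seconds

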
class RebalanceMusicDataset(data.Dataset):
    """Dataset to process the EnsembleSet and CadenzaWoodwind datasets for the CAD2 Task2 baseline.

    The dataset is composed of a target source and a random number of
    accompaniment sources.

    Args:
        root_path (str): Path to the root directory of the dataset.
        music_tracks_file (str): Path to the JSON file containing the music tracks.
        target (str): Target source to be extracted.
        samples_per_track (int): Number of samples to extract from each track.
        segment_length (float): Length of the segment to extract, in seconds.
        random_segments (bool): If True, extract random segments from the tracks.
        random_track_mix (bool): If True, mix random accompaniment tracks.
        split (str): Split of the dataset to use.
        source_augmentations (callable): Augmentations applied to each loaded source.
        sample_rate (int): Sample rate of the audio files.
    """

    dataset_name = "EnsembleSet & CadenzaWoodwind"

    def __init__(
        self,
        root_path: Path | str,
        music_tracks_file: Path | str,
        target: str,
        samples_per_track: int = 1,
        segment_length: float | None = 5.0,
        random_segments=False,
        random_track_mix=False,
        split: str = "train",
        source_augmentations=lambda audio: audio,
        sample_rate: int = 44100,
    ):
        self.instruments = [
            "Bassoon",
            "Cello",
            "Clarinet",
            "Flute",
            "Oboe",
            "Sax",
            "Viola",
            "Violin",
        ]
        self.repeated_instruments: list[str] = []  # ["Violin", "Viola", "Flute", "Sax"]

        if isinstance(root_path, str):
            root_path = Path(root_path)
        self.root_path: Path = root_path
        if split == "train":
            self.root_path = root_path / "train"

        self.music_tracks_file = Path(music_tracks_file)
        self.target = target
        self.samples_per_track = samples_per_track
        self.segment_length = segment_length
        self.random_segments = random_segments
        self.random_track_mix = random_track_mix
        self.split = split
        self.sample_rate = sample_rate
        self.source_augmentations = source_augmentations

        with open(self.music_tracks_file) as f:
            self.tracks = json.load(f)
        self.tracks_list = list(self.tracks.keys())

        # Tracks that contain the target instrument
        self.target_tracks = []
        for k, v in self.tracks.items():
            for type_source, source in v.items():
                if type_source == "mixture":
                    continue
                if self.target in source["instrument"]:
                    self.target_tracks.append(k)

        if self.split == "train":
            self.min_src = 1
            self.max_src = 3
            # Tracks that can provide each accompaniment instrument
            self.accompaniment_tracks: dict = {
                i: [] for i in self.instruments if i != self.target
            }
            for k, v in self.tracks.items():
                for type_source, source in v.items():
                    if type_source == "mixture":
                        continue
                    if self.target not in source["instrument"]:
                        self.accompaniment_tracks[
                            source["instrument"].split("_")[0]
                        ].append(k)

    def __len__(self):
        return len(self.target_tracks) * self.samples_per_track

    def __getitem__(self, idx):
        # assemble the mixture of target and interferers
        audio_sources = {}
        track_idx = idx // self.samples_per_track

        # Load the target source
        target_name = self.target_tracks[track_idx]
        if self.split == "test":
            self.track_name = target_name

        accompaniment_tracks = []
        target_track = None
        for type_source, source in self.tracks[target_name].items():
            if type_source == "mixture":
                continue
            if self.target in source["instrument"]:
                target_track = source
            else:
                accompaniment_tracks.append(source)

        segment_length = self.segment_length
        if self.segment_length is None:
            segment_length = target_track["duration"]

        target_duration = target_track["duration"]
        start = target_track["start"]
        if self.random_segments:
            start = np.floor(
                np.random.uniform(0, int(target_duration) - segment_length)
            )

        start_sample = int(start * self.sample_rate)
        stop_sample = int((start + segment_length) * self.sample_rate)

        target_signal, _ = sf.read(
            self.root_path / target_track["track"],
            always_2d=True,
            start=start_sample,
            stop=stop_sample,
        )
        # convert to torch tensor
        target_signal = torch.tensor(target_signal.T, dtype=torch.float)
        # target_signal = F.pad(target_signal, (target_len - target_signal.size(1), 0))
        target_signal = self.source_augmentations(target_signal)

        target_len = int(segment_length * self.sample_rate)
        audio_sources["target"] = torch.zeros(2, target_len)
        audio_sources["target"] = target_signal
        audio_sources["accompaniment"] = torch.zeros_like(target_signal)

        # ***************************************
        # Load accompaniments
        # If random mixture, change `accompaniment_tracks` variable with
        # random instruments and tracks
        if self.random_track_mix:
            num_accomp = np.random.randint(self.min_src, self.max_src + 1)
            accompaniment_instruments = np.random.choice(
                [
                    x
                    for x in self.instruments + self.repeated_instruments
                    if self.target not in x
                ],
                num_accomp,
                replace=False,
            )
            accompaniment_tracks = []
            for accomp_instrument in accompaniment_instruments:
                a_t = np.random.choice(self.accompaniment_tracks[accomp_instrument])
                for item in self.tracks[a_t].values():
                    if accomp_instrument in item["instrument"]:
                        accompaniment_tracks.append(item)

        # Load accompaniment files
        for accomp_track in accompaniment_tracks:
            accomp_duration = accomp_track["duration"]
            # new random segment
            start = accomp_track["start"]
            if self.random_segments:
                start = np.floor(
                    np.random.uniform(0, accomp_duration - segment_length)
                )
            start_sample = int(start * self.sample_rate)
            stop_sample = int((start + segment_length) * self.sample_rate)
            try:
                accomp_signal, _ = sf.read(
                    self.root_path / accomp_track["track"],
                    always_2d=True,
                    start=start_sample,
                    stop=stop_sample,
                )
                accomp_signal = torch.tensor(accomp_signal.T, dtype=torch.float)
                accomp_signal = self.source_augmentations(accomp_signal)
                audio_sources["accompaniment"] += accomp_signal
            except Exception as err:
                print(accomp_track, flush=True)
                print(f"start {start}", flush=True)
                print(f"len {segment_length}", flush=True)
                print(err)

        # prepare mixture
        audio_mix = torch.stack(list(audio_sources.values())).sum(0)
        audio_sources = torch.stack(
            [audio_sources["target"], audio_sources["accompaniment"]],
            dim=0,
        )
        return audio_mix, audio_sources

    def get_infos(self):
        """Get dataset infos (for publishing models).

        Returns:
            dict, dataset infos with keys `dataset`, `task` and `licenses`.
        """
        infos = dict()
        infos["dataset"] = self.dataset_name
        infos["task"] = "separation"
        infos["licenses"] = [ensembleset_license, woodwing_licence]
        return infos

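
# Illustrative usage sketch (not part of the original module; paths and file names
# below are assumptions): the class above is a standard `torch.utils.data.Dataset`,
# so it can be wrapped in a DataLoader for training. With a fixed `segment_length`,
# the batched shapes are mixture (batch, channels, samples) and
# sources (batch, 2, channels, samples), where index 0 is the target and index 1
# is the summed accompaniment.
#
#     train_set = RebalanceMusicDataset(
#         root_path="cadenza_data/cad2/task2/audio",        # assumed layout
#         music_tracks_file="metadata/music.train.json",    # assumed file name
#         target="Clarinet",
#         segment_length=5.0,
#         random_segments=True,
#         random_track_mix=True,
#         source_augmentations=Compose([augment_gain, augment_channelswap]),
#     )
#     loader = data.DataLoader(train_set, batch_size=4, shuffle=True)
#     mixture, sources = next(iter(loader))
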
ensembleset_license = dict(
    title="EnsembleSet",
    title_link="https://zenodo.org/records/6519024",
    author="Saurjya Sarkar, Emmanouil Benetos, Mark Sandler",
    licence="CC BY-NC 4.0",
    licence_link="https://creativecommons.org/licenses/by-nc/4.0/",
    non_commercial=False,
)

woodwing_licence = dict(
    title="CadenzaWoodwind",
    title_link="",
    author="Alex Miller, Trevor Cox, Gerardo Roa Dabike",
    licence="CC NC 4.0",
    licence_link="https://creativecommons.org/licenses/by/4.0/",
    non_commercial=False,
)


if __name__ == "__main__":
    _root_path = Path(
        "/media/gerardoroadabike/Extreme"
        " SSD1/Challenges/CAD2/cadenza_data/cad2/task2/audio"
    )
    _music_tracks_file = Path(
        "/media/gerardoroadabike/Extreme"
        " SSD1/Challenges/CAD2/cadenza_data/cad2/task2/metadata/music.valid.json"
    )

    for target_source in [
        "Bassoon",
        # "Cello",
        # "Clarinet",
        # "Flute",
        # "Oboe",
        # "Sax",
        # "Viola",
        # "Violin",
    ]:
        dataset = RebalanceMusicDataset(
            _root_path,
            _music_tracks_file,
            target_source,
            split="valid",
            random_track_mix=False,
            random_segments=False,
            segment_length=None,
        )
        for x, y in dataset:
            print(x.shape, y.shape)