diff --git a/egs/wham/wavesplit/dataloading.py b/egs/wham/wavesplit/dataloading.py new file mode 100644 index 000000000..65620f9e4 --- /dev/null +++ b/egs/wham/wavesplit/dataloading.py @@ -0,0 +1,153 @@ +import torch +from torch.utils import data +import json +import os +import numpy as np +import soundfile as sf + +DATASET = "WHAM" +# WHAM tasks +enh_single = {"mixture": "mix_single", "sources": ["s1"], "infos": ["noise"], "default_nsrc": 1} +enh_both = {"mixture": "mix_both", "sources": ["mix_clean"], "infos": ["noise"], "default_nsrc": 1} +sep_clean = {"mixture": "mix_clean", "sources": ["s1", "s2"], "infos": [], "default_nsrc": 2} +sep_noisy = {"mixture": "mix_both", "sources": ["s1", "s2"], "infos": ["noise"], "default_nsrc": 2} + +WHAM_TASKS = { + "enhance_single": enh_single, + "enhance_both": enh_both, + "sep_clean": sep_clean, + "sep_noisy": sep_noisy, +} +# Aliases. +WHAM_TASKS["enh_single"] = WHAM_TASKS["enhance_single"] +WHAM_TASKS["enh_both"] = WHAM_TASKS["enhance_both"] + + +class WHAMID(data.Dataset): + """Dataset class for WHAM source separation and speech enhancement tasks. + + Args: + json_dir (str): The path to the directory containing the json files. + task (str): One of ``'enh_single'``, ``'enh_both'``, ``'sep_clean'`` or + ``'sep_noisy'``. + + * ``'enh_single'`` for single speaker speech enhancement. + * ``'enh_both'`` for multi speaker speech enhancement. + * ``'sep_clean'`` for two-speaker clean source separation. + * ``'sep_noisy'`` for two-speaker noisy source separation. + + sample_rate (int, optional): The sampling rate of the wav files. + segment (float, optional): Length of the segments used for training, + in seconds. If None, use full utterances (e.g. for test). + nondefault_nsrc (int, optional): Number of sources in the training + targets. + If None, defaults to one for enhancement tasks and two for + separation tasks. 
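+
+    Example (illustrative sketch; the json directory path is a placeholder and
+    assumes ``local/preprocess_wham.py`` has already produced the json files):
+        >>> dataset = WHAMID("data/wav8k/min/tr", task="sep_clean",
+        ...                  sample_rate=8000, segment=4.0)
+        >>> mixture, sources, spk_ids = dataset[0]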
+ """ + + def __init__(self, json_dir, task, sample_rate=8000, segment=4.0, nondefault_nsrc=None): + super(WHAMID, self).__init__() + if task not in WHAM_TASKS.keys(): + raise ValueError( + "Unexpected task {}, expected one of " "{}".format(task, WHAM_TASKS.keys()) + ) + # Task setting + self.json_dir = json_dir + self.task = task + self.task_dict = WHAM_TASKS[task] + self.sample_rate = sample_rate + self.seg_len = None if segment is None else int(segment * sample_rate) + if not nondefault_nsrc: + self.n_src = self.task_dict["default_nsrc"] + else: + assert nondefault_nsrc >= self.task_dict["default_nsrc"] + self.n_src = nondefault_nsrc + self.like_test = self.seg_len is None + # Load json examples + ex_json = os.path.join(json_dir, self.task_dict["mixture"] + ".json") + + with open(ex_json, "r") as f: + examples = json.load(f) + + # Filter out short utterances only when segment is specified + self.examples = [] + orig_len = len(examples) + drop_utt, drop_len = 0, 0 + if not self.like_test: + for ex in examples: # Go backward + if ex["length"] < self.seg_len: + drop_utt += 1 + drop_len += ex["length"] + else: + self.examples.append(ex) + + print( + "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format( + drop_utt, drop_len / sample_rate / 36000, orig_len, self.seg_len + ) + ) + + # count total number of speakers + speakers = set() + for ex in self.examples: + for spk in ex["spk_id"]: + speakers.add(spk[:3]) + + print("Total number of speakers {}".format(len(list(speakers)))) + + # convert speakers id into integers + indx = 0 + spk2indx = {} + for spk in list(speakers): + spk2indx[spk] = indx + indx += 1 + self.spk2indx = spk2indx + + for ex in self.examples: + new = [] + for spk in ex["spk_id"]: + new.append(spk2indx[spk[:3]]) + ex["spk_id"] = new + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx): + """Gets a mixture/sources pair. 
+        Returns:
+            mixture, vstack([source_arrays]), spk_ids
+        """
+        c_ex = self.examples[idx]
+        # Random start
+        if c_ex["length"] == self.seg_len or self.like_test:
+            rand_start = 0
+        else:
+            rand_start = np.random.randint(0, c_ex["length"] - self.seg_len)
+        if self.like_test:
+            stop = None
+        else:
+            stop = rand_start + self.seg_len
+        # Load mixture
+        x, _ = sf.read(c_ex["mix"], start=rand_start, stop=stop, dtype="float32")
+        # Load sources
+        source_arrays = []
+        for src in c_ex["sources"]:
+            s, _ = sf.read(src, start=rand_start, stop=stop, dtype="float32")
+            source_arrays.append(s)
+        sources = torch.from_numpy(np.vstack(source_arrays))
+
+        # Randomly permute the two sources (simple label-permutation augmentation).
+        # Permute a local copy of the labels: mutating c_ex["spk_id"] in place
+        # would desynchronize the stored labels from the freshly loaded sources
+        # on later epochs.
+        spk_ids = list(c_ex["spk_id"])
+        if np.random.random() > 0.5:
+            sources = torch.stack((sources[1], sources[0]))
+            spk_ids = [spk_ids[1], spk_ids[0]]
+
+        return torch.from_numpy(x), sources, torch.tensor(spk_ids, dtype=torch.long)
+
+
+if __name__ == "__main__":
+    a = WHAMID(
+        "/media/sam/bx500/wavesplit/asteroid/egs/wham/wavesplit/data/wav8k/min/tt", "sep_clean"
+    )
+
+    for i in a:
+        print(i[-1])
diff --git a/egs/wham/wavesplit/local/conf.yml b/egs/wham/wavesplit/local/conf.yml
new file mode 100644
index 000000000..dbf2cd8ec
--- /dev/null
+++ b/egs/wham/wavesplit/local/conf.yml
@@ -0,0 +1,25 @@
+# Network config
+masknet:
+  n_src: 2
+
+# Training config
+training:
+  epochs: 200
+  batch_size: 4
+  num_workers: 6
+  half_lr: yes
+  early_stop: yes
+  gradient_clipping: 5
+# Optim config
+optim:
+  optimizer: adam
+  lr: 0.001
+# Data config
+data:
+  train_dir: data/wav8k/min/tr/
+  valid_dir: data/wav8k/min/cv/
+  task: sep_clean
+  nondefault_nsrc:
+  sample_rate: 8000
+  mode: min
+  segment: 1.0
diff --git a/egs/wham/wavesplit/local/convert_sphere2wav.sh b/egs/wham/wavesplit/local/convert_sphere2wav.sh
new file mode 100644
index 000000000..8870bf096
--- /dev/null
+++ b/egs/wham/wavesplit/local/convert_sphere2wav.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+# MIT Copyright (c) 2018 Kaituo XU
+
+
+sphere_dir=tmp
+wav_dir=tmp
+
+. utils/parse_options.sh || exit 1;
+
+
+echo "Download sph2pipe_v2.5 into egs/tools"
+mkdir -p ../../tools
+wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools
+cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd -
+
+echo "Convert sphere format to wav format"
+sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe
+
+if [ ! -x $sph2pipe ]; then
+  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
+  exit 1;
+fi
+
+tmp=data/local/
+mkdir -p $tmp
+
+[ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list
+
+if [ ! -d $wav_dir ]; then
+  while read line; do
+    wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'`
+    echo $wav
+    mkdir -p `dirname $wav`
+    $sph2pipe -f wav $line > $wav
+  done < $tmp/sph.list > $tmp/wav.list
+else
+  echo "$wav_dir already exists. If the wav files need to be regenerated, remove it first."
+fi
diff --git a/egs/wham/wavesplit/local/prepare_data.sh b/egs/wham/wavesplit/local/prepare_data.sh
new file mode 100755
index 000000000..590a1999e
--- /dev/null
+++ b/egs/wham/wavesplit/local/prepare_data.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+wav_dir=tmp
+out_dir=tmp
+python_path=python
+
+. utils/parse_options.sh
+
+## Download WHAM noises
+mkdir -p $out_dir
+echo "Download WHAM noises into $out_dir"
+# If the download stalls for more than 20s, it is relaunched from its previous state.
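+# wget flags: -c resumes a partial download, --tries=0 retries indefinitely,
+# and --read-timeout=20 aborts (and retries) a connection that stalls for 20s.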
+wget -c --tries=0 --read-timeout=20 https://storage.googleapis.com/whisper-public/wham_noise.zip -P $out_dir
+
+echo "Download WHAM scripts into $out_dir"
+wget https://storage.googleapis.com/whisper-public/wham_scripts.tar.gz -P $out_dir
+mkdir -p $out_dir/wham_scripts
+tar -xzvf $out_dir/wham_scripts.tar.gz -C $out_dir/wham_scripts
+mv $out_dir/wham_scripts.tar.gz $out_dir/wham_scripts
+
+mkdir -p logs
+unzip $out_dir/wham_noise.zip -d $out_dir >> logs/unzip_wham.log
+
+echo "Run python scripts to create the WHAM mixtures"
+# Requires: numpy, scipy, pandas and pysoundfile
+cd $out_dir/wham_scripts/wham_scripts
+$python_path create_wham_from_scratch.py \
+	--wsj0-root $wav_dir \
+	--wham-noise-root $out_dir/wham_noise \
+	--output-dir $out_dir
+cd -
diff --git a/egs/wham/wavesplit/local/preprocess_wham.py b/egs/wham/wavesplit/local/preprocess_wham.py
new file mode 100644
index 000000000..73aa6f85a
--- /dev/null
+++ b/egs/wham/wavesplit/local/preprocess_wham.py
@@ -0,0 +1,93 @@
+import argparse
+import json
+import os
+import soundfile as sf
+import glob
+
+
+def preprocess_task(task, in_dir, out_dir):
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+
+    if task == "mix_both":
+        mix_both = glob.glob(os.path.join(in_dir, "mix_both", "*.wav"))
+        examples = []
+        for mix in mix_both:
+            filename = os.path.basename(mix)
+            spk1_id = filename.split("_")[0][:3]
+            spk2_id = filename.split("_")[2][:3]
+            length = len(sf.SoundFile(mix))
+
+            noise = os.path.join(in_dir, "noise", filename)
+            s1 = os.path.join(in_dir, "s1", filename)
+            s2 = os.path.join(in_dir, "s2", filename)
+
+            ex = {
+                "mix": mix,
+                "sources": [s1, s2],
+                "noise": noise,
+                "spk_id": [spk1_id, spk2_id],
+                "length": length,
+            }
+            examples.append(ex)
+
+        with open(os.path.join(out_dir, "mix_both.json"), "w") as f:
+            json.dump(examples, f, indent=4)
+
+    elif task == "mix_clean":
+        mix_clean = glob.glob(os.path.join(in_dir, "mix_clean", "*.wav"))
+        examples = []
+        for mix in mix_clean:
+            filename = os.path.basename(mix)
+            spk1_id = filename.split("_")[0][:3]
+            spk2_id = filename.split("_")[2][:3]
+            length = len(sf.SoundFile(mix))
+
+            s1 = os.path.join(in_dir, "s1", filename)
+            s2 = os.path.join(in_dir, "s2", filename)
+
+            ex = {"mix": mix, "sources": [s1, s2], "spk_id": [spk1_id, spk2_id], "length": length}
+            examples.append(ex)
+
+        with open(os.path.join(out_dir, "mix_clean.json"), "w") as f:
+            json.dump(examples, f, indent=4)
+
+    elif task == "mix_single":
+        mix_single = glob.glob(os.path.join(in_dir, "mix_single", "*.wav"))
+        examples = []
+        for mix in mix_single:
+            filename = os.path.basename(mix)
+            spk1_id = filename.split("_")[0][:3]
+            length = len(sf.SoundFile(mix))
+
+            s1 = os.path.join(in_dir, "s1", filename)
+
+            ex = {"mix": mix, "sources": [s1], "spk_id": [spk1_id], "length": length}
+            examples.append(ex)
+
+        with open(os.path.join(out_dir, "mix_single.json"), "w") as f:
+            json.dump(examples, f, indent=4)
+    else:
+        raise ValueError("Unexpected task {}".format(task))
+
+
+def preprocess(inp_args):
+    tasks = ["mix_both", "mix_clean", "mix_single"]
+    for split in ["tr", "cv", "tt"]:
+        for task in tasks:
+            preprocess_task(
+                task, os.path.join(inp_args.in_dir, split), os.path.join(inp_args.out_dir, split)
+            )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("WHAM data preprocessing")
+    parser.add_argument(
+        "--in_dir", type=str, default=None, help="Directory path of wham including tr, cv and tt"
+    )
+    parser.add_argument(
+        "--out_dir", type=str, default=None, help="Directory path to put output files"
+    )
+    args = parser.parse_args()
+    print(args)
+    preprocess(args)
diff --git a/egs/wham/wavesplit/losses.py b/egs/wham/wavesplit/losses.py
new file mode 100644
index 000000000..769b168ed
--- /dev/null
+++ b/egs/wham/wavesplit/losses.py
@@ -0,0 +1,240 @@
+from torch import nn
+import torch
+import numpy as np
+from torch.nn import functional as F
+from itertools import permutations
+from asteroid.losses.sdr import MultiSrcNegSDR
+import math
+
+
+class ClippedSDR(nn.Module):
+    """Negative SI-SDR, clamped from below at ``clip_value``."""
+
+    def __init__(self, clip_value=-30):
+        super(ClippedSDR, self).__init__()
+
+        self.snr = MultiSrcNegSDR("sisdr")
+        self.clip_value = float(clip_value)
+
+    def forward(self, est_targets, targets):
+        return torch.clamp(self.snr(est_targets, targets), min=self.clip_value)
+
+
+class SpeakerVectorLoss(nn.Module):
+    """Speaker vector loss with ``distance``, ``local`` and ``global`` variants."""
+
+    def __init__(
+        self,
+        n_speakers,
+        embed_dim=512,
+        learnable_emb=True,
+        loss_type="global",
+        weight=2,
+        distance_reg=0.3,
+        gaussian_reg=0.2,
+        return_oracle=False,
+    ):
+        super(SpeakerVectorLoss, self).__init__()
+
+        self.learnable_emb = learnable_emb
+        assert loss_type in ["distance", "global", "local"]
+        self.loss_type = loss_type
+        self.weight = float(weight)
+        self.distance_reg = float(distance_reg)
+        self.gaussian_reg = float(gaussian_reg)
+        self.return_oracle = return_oracle
+
+        spk_emb = torch.eye(max(n_speakers, embed_dim))  # one-hot initialization
+        spk_emb = spk_emb[:n_speakers, :embed_dim]
+
+        if learnable_emb:
+            self.spk_embeddings = nn.Parameter(spk_emb)
+        else:
+            self.register_buffer("spk_embeddings", spk_emb)
+
+        if self.loss_type != "distance":
+            self.alpha = nn.Parameter(torch.Tensor([1.0]))
+            self.beta = nn.Parameter(torch.Tensor([0.0]))
+
+    @staticmethod
+    def _l_dist_speaker(c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
+
+        spk_labels = spk_labels.unsqueeze(-1).repeat(1, 1, spk_embeddings.size(-1))
+        utt_embeddings = spk_embeddings.unsqueeze(0).repeat(spk_labels.size(0), 1, 1).gather(
+            1, spk_labels
+        ).unsqueeze(-1) * spk_mask.unsqueeze(2)
+
+        distance = ((c_spk_vec_perm - utt_embeddings) ** 2).sum(2)
+
+        intra_spk = ((c_spk_vec_perm[:, 0] - c_spk_vec_perm[:, 1]) ** 2).sum(1)
+        return distance.sum(1) + F.relu(1.0 - intra_spk)  # ok for two speakers
+
+    def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
+
+        utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
+        alpha = torch.clamp(self.alpha, 1e-8)
+
+        distance = alpha * ((c_spk_vec_perm - utt_embeddings) ** 2).sum(2) + self.beta
+        distances = (
+            alpha * ((c_spk_vec_perm.unsqueeze(1) - utt_embeddings.unsqueeze(2)) ** 2).sum(3)
+            + self.beta
+        )
+        distances = torch.exp(-distances).sum(1)
+
+        return (distance + torch.log(distances)).sum(1)
+
+    def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
+
+        utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
+        alpha = torch.clamp(self.alpha, 1e-8)
+
+        distance_utt = alpha * ((c_spk_vec_perm - utt_embeddings) ** 2).sum(2) + self.beta
+
+        B, src, embed_dim, frames = c_spk_vec_perm.size()
+        spk_embeddings = spk_embeddings.reshape(1, spk_embeddings.shape[0], embed_dim, 1).expand(
+            B, -1, -1, frames
+        )
+        distances = (
+            alpha * ((c_spk_vec_perm.unsqueeze(1) - spk_embeddings.unsqueeze(2)) ** 2).sum(3)
+            + self.beta
+        )
+        distances = torch.exp(-distances).sum(1)
+
+        return (distance_utt + torch.log(distances)).sum(1)
+
+        # exp normalize trick
+        # with torch.no_grad():
+        #    b = torch.max(distances, dim=1, keepdim=True)[0]
+        # out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
+        # return out.sum(1)
+
+    def forward(self, speaker_vectors, spk_mask, spk_labels):
+
+        if self.gaussian_reg:
+            noise = torch.randn(
+                self.spk_embeddings.size(), device=speaker_vectors.device
+            ) * math.sqrt(self.gaussian_reg)
+            # Note: the noise is created on the right device already; casting it
+            # with .to(spk_labels) would convert it to the labels' integer dtype.
+            spk_embeddings = self.spk_embeddings + noise
+        else:
+            spk_embeddings = self.spk_embeddings
+
+        if self.learnable_emb or self.gaussian_reg:
+            # re-project on the unit sphere
+            spk_embeddings = (
+                spk_embeddings / torch.sum(spk_embeddings ** 2, -1, keepdim=True).sqrt()
+            )
+
+        if self.distance_reg:
+            pairwise_dist = (
+                (torch.abs(spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1)))
+                .mean(-1)
+                .fill_diagonal_(np.inf)
+            )
+            distance_reg = -torch.sum(torch.min(torch.log(pairwise_dist), dim=-1)[0])
+
+        # speaker_vectors: [B, n_src, dim, frames]
+        # spk_mask: [B, n_src, frames] boolean mask
+        # spk_labels: [B, n_src] integer speaker ids for the current utterance
+        B, n_src, embed_dim, frames = speaker_vectors.size()
+
+        perms = list(permutations(range(n_src)))
+        if self.loss_type == "distance":
+            loss_set = torch.stack(
+                [
+                    self._l_dist_speaker(
+                        speaker_vectors[:, perm], spk_embeddings, spk_labels, spk_mask
+                    )
+                    for perm in perms
+                ],
+                dim=1,
+            )
+        elif self.loss_type == "local":
+            loss_set = torch.stack(
+                [
+                    self._l_local_speaker(
+                        speaker_vectors[:, perm], spk_embeddings, spk_labels, spk_mask
+                    )
+                    for perm in perms
+                ],
+                dim=1,
+            )
+        else:
+            loss_set = torch.stack(
+                [
+                    self._l_global_speaker(
+                        speaker_vectors[:, perm], spk_embeddings, spk_labels, spk_mask
+                    )
+                    for perm in perms
+                ],
+                dim=1,
+            )
+
+        # Indexes and values of min losses for each batch element
+        min_loss, min_loss_idx = torch.min(loss_set, dim=1)
+
+        # Reorder the speaker vectors frame-wise according to the best permutation.
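+        # Shape sketch of the gather below: loss_set is [B, n_perms, frames],
+        # so min_loss_idx is [B, frames]. perms is expanded to
+        # [n_perms, n_src, B, frames] and gathered along dim 0 with the
+        # expanded indices, yielding the best permutation for every frame.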
+        perms = min_loss.new_tensor(perms, dtype=torch.long)
+        perms = perms[..., None, None].expand(-1, -1, B, frames)
+        min_loss_idx = min_loss_idx[None, None, ...].expand(1, n_src, -1, -1)
+        min_loss_perm = torch.gather(perms, dim=0, index=min_loss_idx)[0]
+        min_loss_perm = (
+            min_loss_perm.transpose(0, 1).reshape(B, n_src, 1, frames).expand(-1, -1, embed_dim, -1)
+        )
+        # Total loss
+        spk_loss = self.weight * min_loss.mean()
+        if self.distance_reg:
+            spk_loss += self.distance_reg * distance_reg
+        reordered_sources = torch.gather(speaker_vectors, dim=1, index=min_loss_perm)
+
+        if self.return_oracle:
+            spk_embeddings = spk_embeddings.to(spk_labels.device)
+            utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
+            return spk_loss, reordered_sources, utt_embeddings
+
+        return spk_loss, reordered_sources
+
+
+if __name__ == "__main__":
+
+    n_speakers = 101
+    emb_speaker = 256
+
+    spk_embedding_1 = torch.rand(emb_speaker) * 0.01
+    spk_embedding_2 = torch.rand(emb_speaker) * 0.01
+    spk_embedding_1[0] = 1.0
+    spk_embedding_2[1] = 1.0
+
+    tmp = []
+    for i in range(200):
+        if np.random.random() >= 0.5:
+            tmp.append(torch.stack((spk_embedding_1, spk_embedding_2)))
+        else:
+            tmp.append(torch.stack((spk_embedding_2, spk_embedding_1)))
+
+    speaker_vectors = torch.stack(tmp).permute(1, 2, 0).unsqueeze(0)
+
+    speaker_labels = torch.from_numpy(np.array([[1, 0]]))
+    oracle = (
+        torch.stack((spk_embedding_2, spk_embedding_1))
+        .unsqueeze(0)
+        .unsqueeze(-1)
+        .repeat(1, 1, 1, 200)
+    )
+
+    # testing exp normalize average
+    # distances = torch.ones((1, 101, 4000))
+    # with torch.no_grad():
+    #    b = torch.max(distances, dim=1, keepdim=True)[0]
+    # out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
+    # out2 = - torch.log(torch.exp(-distances).sum(1))
+
+    loss_spk = SpeakerVectorLoss(
+        n_speakers, emb_speaker, loss_type="distance", distance_reg=0, gaussian_reg=0
+    )
+
+    # All-ones mask: no silent frames in this test. Zeros would mark frames
+    # where a speaker is inactive.
+    speaker_mask = torch.ones((1, 2, 200))
+    spk_loss, reordered = loss_spk(speaker_vectors, speaker_mask, speaker_labels)
+    print(spk_loss)
+    np.testing.assert_array_almost_equal(reordered.numpy(), oracle.numpy())
diff --git a/egs/wham/wavesplit/model.py b/egs/wham/wavesplit/model.py
new file mode 100644
index 000000000..e79405144
--- /dev/null
+++ b/egs/wham/wavesplit/model.py
@@ -0,0 +1,281 @@
+from torch import nn
+import torch
+from asteroid.masknn import norms
+from kmeans_pytorch import kmeans
+
+
+class Conv1DBlock(nn.Module):
+    def __init__(self, in_chan, hid_chan, kernel_size, padding, dilation, norm_type="gLN"):
+        super(Conv1DBlock, self).__init__()
+
+        conv_norm = norms.get(norm_type)
+        depth_conv1d = nn.Conv1d(in_chan, hid_chan, kernel_size, padding=padding, dilation=dilation)
+        torch.nn.init.kaiming_uniform_(depth_conv1d.weight)
+
+        self.out = nn.Sequential(depth_conv1d, nn.PReLU(), conv_norm(hid_chan))
+
+    def forward(self, x):
+        """Input shape: [batch, feats, seq]."""
+        return self.out(x)
+
+
+class SepConv1DBlock(nn.Module):
+    def __init__(
+        self,
+        in_chan,
+        hid_chan,
+        spk_vec_chan,
+        kernel_size,
+        padding,
+        dilation,
+        norm_type="gLN",
+        use_FiLM=True,
+    ):
+        super(SepConv1DBlock, self).__init__()
+
+        self.use_FiLM = use_FiLM
+        conv_norm = norms.get(norm_type)
+        self.depth_conv1d = nn.Conv1d(
+            in_chan, hid_chan, kernel_size, padding=padding, dilation=dilation
+        )
+        torch.nn.init.kaiming_uniform_(self.depth_conv1d.weight)
+
+        self.out = nn.Sequential(nn.PReLU(), conv_norm(hid_chan))
+
+        # FiLM conditioning: the speaker vector provides a per-channel scale
+        # (mul_lin) and shift (add_lin); without FiLM, only the shift is used.
+        if self.use_FiLM:
+            self.mul_lin = nn.Linear(spk_vec_chan, hid_chan)
+            torch.nn.init.kaiming_uniform_(self.mul_lin.weight)
+        self.add_lin = nn.Linear(spk_vec_chan, hid_chan)
+        torch.nn.init.kaiming_uniform_(self.add_lin.weight)
+
+    def apply_conditioning(self, spk_vec, squeezed):
+        spk_vec = spk_vec.unsqueeze(-1)
+        bias = self.add_lin(spk_vec.transpose(1, -1)).transpose(1, -1)
+        if self.use_FiLM:
+            mul = self.mul_lin(spk_vec.transpose(1, -1)).transpose(1, -1)
+            return mul * squeezed + bias
+        else:
+            # bias is already [batch, hid_chan, 1] and broadcasts over frames;
+            # an extra unsqueeze would make it 4-D and break broadcasting.
+            return squeezed + bias
+
+    def forward(self, x, spk_vec):
+        """Input shape: [batch, feats, seq]."""
+        conditioned = self.apply_conditioning(spk_vec, self.depth_conv1d(x))
+        return self.out(conditioned)
+
+
+class SpeakerStack(nn.Module):
+    # Essentially a plain Conv-TasNet masker; may be factored out in future releases.
+
+    def __init__(
+        self, n_src, embed_dim=512, n_blocks=14, n_repeats=1, kernel_size=3, norm_type="gLN"
+    ):
+
+        super(SpeakerStack, self).__init__()
+        self.embed_dim = embed_dim
+        self.n_src = n_src
+        self.n_blocks = n_blocks
+        self.n_repeats = n_repeats
+        self.kernel_size = kernel_size
+        self.norm_type = norm_type
+
+        # Succession of Conv1DBlock with exponentially increasing dilation.
+        self.TCN = nn.ModuleList()
+        for r in range(n_repeats):
+            for x in range(n_blocks):
+                padding = (kernel_size - 1) * 2 ** x // 2
+                if r == 0 and x == 0:
+                    in_chan = 1
+                else:
+                    in_chan = embed_dim
+                self.TCN.append(
+                    Conv1DBlock(
+                        in_chan,
+                        embed_dim,
+                        kernel_size,
+                        padding=padding,
+                        dilation=2 ** x,
+                        norm_type=norm_type,
+                    )
+                )
+        mask_conv = nn.Conv1d(embed_dim, n_src * embed_dim, 1)
+        self.mask_net = nn.Sequential(mask_conv)
+
+    def forward(self, mixture_w):
+        """
+        Args:
+            mixture_w (:class:`torch.Tensor`): Tensor of shape
+                [batch, n_filters, n_frames]
+
+        Returns:
+            :class:`torch.Tensor`:
+                L2-normalized speaker vectors of shape
+                [batch, n_src, embed_dim, n_frames]
+        """
+        batch, _, n_frames = mixture_w.size()
+        output = mixture_w
+        for i in range(len(self.TCN)):
+            if i == 0:
+                output = self.TCN[i](output)
+            else:
+                residual = self.TCN[i](output)
+                output = output + residual
+        emb = self.mask_net(output)
+
+        emb = emb.view(batch, self.n_src, self.embed_dim, n_frames)
+        emb = emb / torch.sqrt(torch.sum(emb ** 2, 2, keepdim=True))
+        return emb
+
+
+class SeparationStack(nn.Module):
+    def __init__(
+        self,
+        src,
+        embed_dim=512,
+        spk_vec_dim=512,
+        n_blocks=10,
+        n_repeats=4,
+        kernel_size=3,
+        norm_type="gLN",
+        return_all_layers=True,
+    ):
+
+        super(SeparationStack, self).__init__()
+        self.n_blocks = n_blocks
+        self.n_repeats = n_repeats
+        self.kernel_size = kernel_size
+        self.norm_type = norm_type
+        self.src = src
+        self.embed_dim = embed_dim
+        self.return_all = return_all_layers
+        self.TCN = nn.ModuleList()
+        if not self.return_all:
+            self.out = nn.Conv1d(embed_dim, self.src, 1)
+
+        for r in range(n_repeats):
+            for x in range(n_blocks):
+                if r == 0 and x == 0:
+                    in_chan = 1
+                else:
+                    in_chan = embed_dim
+                padding = (kernel_size - 1) * 2 ** x // 2
+                if not self.return_all:
+                    self.TCN.append(
+                        SepConv1DBlock(
+                            in_chan,
+                            embed_dim,
+                            spk_vec_dim * self.src,
+                            kernel_size,
+                            padding=padding,
+                            dilation=2 ** x,
+                            norm_type=norm_type,
+                        )
+                    )
+                else:
+                    conv = nn.Conv1d(embed_dim, self.src, 1)
+                    torch.nn.init.kaiming_uniform_(conv.weight)
+                    self.TCN.append(
+                        nn.ModuleList(
+                            [
+                                SepConv1DBlock(
+                                    in_chan,
+                                    embed_dim,
+                                    spk_vec_dim * self.src,
+                                    kernel_size,
+                                    padding=padding,
+                                    dilation=2 ** x,
+                                    norm_type=norm_type,
+                                ),
+                                conv,
+                            ]
+                        )
+                    )
+
+    def forward(self, mixture_w, spk_vectors):
+        """
+        Args:
+            mixture_w (:class:`torch.Tensor`): Tensor of shape
+                [batch, n_filters, n_frames]
+
+        Returns:
+            A list of per-layer estimates of shape [batch, n_src, n_frames]
+            if ``return_all_layers``, else a single [batch, n_src, n_frames]
+            tensor from the last layer.
+        """
+        output = mixture_w
+        outputs = []
+        for i in range(len(self.TCN)):
+            if i == 0:
+                if self.return_all:
+                    block, out_conv = self.TCN[i]
+                    output = block(output, spk_vectors)
+                    outputs.append(out_conv(output))
+                else:
+                    output = self.TCN[i](output, spk_vectors)
+            else:
+                if self.return_all:
+                    block, out_conv = self.TCN[i]
+                    residual = block(output, spk_vectors)
+                    output = output + residual
+                    outputs.append(out_conv(output))
+                else:
+                    residual = self.TCN[i](output, spk_vectors)
+                    output = output + residual
+
+        if self.return_all:
+            out = outputs
+        else:
+            out = self.out(output)
+
+        return out
+
+
+class Wavesplit(nn.Module):
+    def __init__(self, n_src, spk_stack_kwargs={}, sep_stack_kwargs={}):
+        super().__init__()
+
+        self.n_src = n_src
+        self.spk_stack = SpeakerStack(n_src, **spk_stack_kwargs)
+        self.sep_stack = SeparationStack(n_src, **sep_stack_kwargs)
+
+    def _check_input_shape(self, x):
+        if x.ndim < 3:
+            x = x.unsqueeze(1)
+        return x
+
+    def get_speaker_vectors(self, x):
+        x = self._check_input_shape(x)
+        spk_embeddings = self.spk_stack(x)
+        return spk_embeddings
+
+    def split_waves(self, x, reordered_spk_vectors):
+        x = self._check_input_shape(x)
+        # Unpack to locals; assigning to self.n_src here would silently
+        # overwrite the module attribute.
+        batch_sz, n_src, spk_vec_size = reordered_spk_vectors.size()
+        return self.sep_stack(x, reordered_spk_vectors.reshape(batch_sz, n_src * spk_vec_size))
+
+    def forward(self, x):
+        # Used only at inference: speaker identities are recovered by k-means
+        # clustering of the frame-level speaker vectors.
+        x = self._check_input_shape(x)
+        spk_embeddings = self.spk_stack(x)
+        batch_sz, n_src, spk_vec_size, samples = spk_embeddings.size()
+        reordered = []
+        for b in range(spk_embeddings.shape[0]):
+            cluster_ids, cluster_centers = kmeans(
+                spk_embeddings[b].transpose(1, 2).reshape(n_src * samples, spk_vec_size),
+                n_src,
+                device=spk_embeddings.device,
+            )
+            reordered.append(cluster_centers)
+
+        reordered = torch.stack(reordered)
+        return self.split_waves(x, reordered)
+
+
+if __name__ == "__main__":
+
+    sep = SeparationStack(2)
+    wave = torch.rand((2, 1600))
+
+    wavesplit = Wavesplit(2)
+    wavesplit(wave)
diff --git a/egs/wham/wavesplit/run.sh b/egs/wham/wavesplit/run.sh
new file mode 100755
index 000000000..76b2282d6
--- /dev/null
+++ b/egs/wham/wavesplit/run.sh
@@ -0,0 +1,116 @@
+#!/bin/bash
+
+# Exit on error
+set -e
+set -o pipefail
+
+# Main storage directory. You'll need disk space to dump the WHAM mixtures and the wsj0 wav
+# files if you start from sphere files.
+storage_dir=
+
+# If you start from the sphere files, specify the path to the directory and start from stage 0
+sphere_dir=  # Directory containing sphere files
+# If you already have wsj0 wav files, specify the path to the directory here and start from stage 1
+wsj0_wav_dir=
+# If you already have the WHAM mixtures, specify the path to the directory here and start from stage 2
+wham_wav_dir=/media/sam/cb915f0e-e440-414c-bb74-df66b311d09d/2speakers_wham/
+# After running the recipe a first time, you can run it from stage 3 directly to train new models.
+
+# Path to the python you'll use for the experiment. Defaults to the current python.
+# You can run ./utils/prepare_python_env.sh to create a suitable python environment and paste the output here.
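+# Example (placeholder path): python_path=/path/to/conda/envs/asteroid/bin/python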
+python_path=python + +# Example usage +# ./run.sh --stage 3 --tag my_tag --task sep_noisy --id 0,1 + +# General +stage=2 # Controls from which stage to start +tag="test" # Controls the directory name associated to the experiment +# You can ask for several GPUs using id (passed to CUDA_VISIBLE_DEVICES) +id=0 + +# Data +task=sep_clean # Specify the task here (sep_clean, sep_noisy, enh_single, enh_both) +sample_rate=8000 +mode=min +nondefault_src= # If you want to train a network with 3 output streams for example. + +# Training +batch_size=1 +num_workers=6 +optimizer=adam +lr=0.001 +epochs=200 + +# Evaluation +eval_use_gpu=1 + +. utils/parse_options.sh + +sr_string=$(($sample_rate/1000)) +suffix=wav${sr_string}k/$mode +dumpdir=data/$suffix # directory to put generated json file + +train_dir=$dumpdir/tr +valid_dir=$dumpdir/cv +test_dir=$dumpdir/tt + +if [[ $stage -le 0 ]]; then + echo "Stage 0: Converting sphere files to wav files" + . local/convert_sphere2wav.sh --sphere_dir $sphere_dir --wav_dir $wsj0_wav_dir +fi + +if [[ $stage -le 1 ]]; then + echo "Stage 1: Generating 8k and 16k WHAM dataset" + . local/prepare_data.sh --wav_dir $wsj0_wav_dir --out_dir $wham_wav_dir --python_path $python_path +fi + +if [[ $stage -le 2 ]]; then + # Make json directories with min/max modes and sampling rates + echo "Stage 2: Generating json files including wav path and duration" + for sr_string in 8 16; do + for mode_option in min max; do + tmp_dumpdir=data/wav${sr_string}k/$mode_option + echo "Generating json files in $tmp_dumpdir" + [[ ! -d $tmp_dumpdir ]] && mkdir -p $tmp_dumpdir + local_wham_dir=$wham_wav_dir/wav${sr_string}k/$mode_option/ + $python_path local/preprocess_wham.py --in_dir $local_wham_dir --out_dir $tmp_dumpdir + done + done +fi + +# Generate a random ID for the run if no tag is specified +uuid=$($python_path -c 'import uuid, sys; print(str(uuid.uuid4())[:8])') +if [[ -z ${tag} ]]; then + tag=${task}_${sr_string}k${mode}_${uuid} +fi +expdir=exp/train_wavesplit_${tag} +mkdir -p $expdir && echo $uuid >> $expdir/run_uuid.txt +echo "Results from the following experiment will be stored in $expdir" + +if [[ $stage -le 3 ]]; then + echo "Stage 3: Training" + mkdir -p logs + CUDA_VISIBLE_DEVICES=$id $python_path train.py \ + --train_dir $train_dir \ + --valid_dir $valid_dir \ + --task $task \ + --sample_rate $sample_rate \ + --num_workers $num_workers \ + --exp_dir ${expdir}/ | tee logs/train_${tag}.log + cp logs/train_${tag}.log $expdir/train.log + + # Get ready to publish + mkdir -p $expdir/publish_dir + echo "wham/Wavesplit" > $expdir/publish_dir/recipe_name.txt +fi + +if [[ $stage -le 4 ]]; then + echo "Stage 4 : Evaluation" + CUDA_VISIBLE_DEVICES=$id $python_path eval.py \ + --task $task \ + --test_dir $test_dir \ + --use_gpu $eval_use_gpu \ + --exp_dir ${expdir} | tee logs/eval_${tag}.log + cp logs/eval_${tag}.log $expdir/eval.log +fi diff --git a/egs/wham/wavesplit/train.py b/egs/wham/wavesplit/train.py new file mode 100644 index 000000000..225fe7a01 --- /dev/null +++ b/egs/wham/wavesplit/train.py @@ -0,0 +1,218 @@ +import os +import argparse +import json + +import torch +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from model import Wavesplit +from dataloading import WHAMID + +from asteroid.engine.optimizers import make_optimizer +from asteroid.engine.system import System +from losses import ClippedSDR, SpeakerVectorLoss + +# Keys 
which are not in the conf.yml file can be added here.
+# In the hierarchical dictionary created when parsing, the key `key` can be
+# found at dic['main_args'][key]
+
+# By default train.py will use all available GPUs. The `id` option in run.sh
+# will limit the number of available GPUs for train.py.
+parser = argparse.ArgumentParser()
+parser.add_argument("--exp_dir", default="exp/tmp", help="Full path to save best validation model")
+
+
+class WavesplitTrainer(System):
+    def on_train_start(self) -> None:
+        self.loss_func["spk_loss"] = self.loss_func["spk_loss"].to(self.device)
+
+    def on_validation_epoch_start(self) -> None:
+        self.loss_func["spk_loss"] = self.loss_func["spk_loss"].to(self.device)
+
+    def training_step(self, batch, batch_nb):
+        mixtures, oracle_s, oracle_ids = batch
+        b, n_spk, frames = oracle_s.size()
+
+        # NOTE: the speaker stack is currently bypassed and the oracle speaker
+        # embeddings are used instead (spk_loss is kept at 0):
+        # spk_vectors = self.model.get_speaker_vectors(mixtures)
+        # b, n_spk, embed_dim, frames = spk_vectors.size()
+        # spk_activity_mask = torch.ones((b, n_spk, frames)).to(mixtures)
+        # spk_loss, reordered = self.loss_func["spk_loss"](spk_vectors, spk_activity_mask, oracle_ids)
+        spk_loss = 0
+        # reordered = reordered.mean(-1)  # take centroid
+        reordered = self.loss_func["spk_loss"].spk_embeddings[oracle_ids]
+
+        separated = self.model.split_waves(mixtures, reordered)
+
+        if self.model.sep_stack.return_all:
+            # Deep supervision: fold the per-layer estimates into the batch
+            # dimension so the separation loss is applied to every layer.
+            n_layers = len(separated)
+            separated = torch.stack(separated).transpose(0, 1)
+            separated = separated.reshape(b * n_layers, n_spk, frames)
+            oracle_s = (
+                oracle_s.unsqueeze(1).repeat(1, n_layers, 1, 1).reshape(b * n_layers, n_spk, frames)
+            )
+
+        sep_loss = self.loss_func["sep_loss"](separated, oracle_s).mean()
+        tot_loss = sep_loss + spk_loss
+
+        tqdm_log = {"spk_loss": spk_loss, "sep_loss": sep_loss}
+        tensorboard_logs = {"spk_loss/train": spk_loss, "sep_loss/train": sep_loss}
+
+        return {"loss": tot_loss, "log": tensorboard_logs, "progress_bar": tqdm_log}
+
+    def validation_step(self, batch, batch_nb):
+        mixtures, oracle_s, oracle_ids = batch
+        b, n_spk, frames = oracle_s.size()
+        # As in training, the speaker stack is bypassed for now:
+        # spk_vectors = self.model.get_speaker_vectors(mixtures)
+        # b, n_spk, embed_dim, frames = spk_vectors.size()
+        # spk_activity_mask = torch.ones((b, n_spk, frames)).to(mixtures)
+        # spk_loss, reordered = self.loss_func["spk_loss"](spk_vectors,
+        #                                                  spk_activity_mask,
+        #                                                  oracle_ids)
+        # reordered = reordered.mean(-1)  # take centroid
+        reordered = self.loss_func["spk_loss"].spk_embeddings[oracle_ids]
+        spk_loss = 0
+
+        separated = self.model.split_waves(mixtures, reordered)
+
+        if self.model.sep_stack.return_all:
+            # In validation, only the last layer's estimate is scored.
+            separated = separated[-1]
+
+        sep_loss = self.loss_func["sep_loss"](separated, oracle_s).mean()
+        tot_loss = sep_loss + spk_loss
+
+        tensorboard_logs = {"spk_loss/val": spk_loss, "sep_loss/val": sep_loss}
+
+        return {"val_loss": tot_loss.item(), "log": tensorboard_logs}
+
+
+def main(conf):
+    train_set = WHAMID(
+        conf["data"]["train_dir"],
+        conf["data"]["task"],
+        sample_rate=conf["data"]["sample_rate"],
+        segment=conf["data"]["segment"],
+        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
+    )
+    val_set = WHAMID(
+        conf["data"]["valid_dir"],
+        conf["data"]["task"],
+        sample_rate=conf["data"]["sample_rate"],
+        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
+        segment=conf["data"]["segment"] * 2,
+    )
+
+    train_loader = DataLoader(
+        train_set,
+        shuffle=True,
+        batch_size=conf["training"]["batch_size"],
+        num_workers=conf["training"]["num_workers"],
+        drop_last=True,
+    )
+    val_loader = DataLoader(
+        val_set,
+        shuffle=False,
batch_size=conf["training"]["batch_size"], + num_workers=conf["training"]["num_workers"], + drop_last=True, + ) + # Update number of source values (It depends on the task) + conf["masknet"].update({"n_src": train_set.n_src}) + + model = Wavesplit( + conf["masknet"]["n_src"], + {"embed_dim": 512}, + {"embed_dim": 512, "spk_vec_dim": 512, "n_repeats": 4, "return_all_layers": False}, + ) + + # Just after instantiating, save the args. Easy loading in the future. + exp_dir = conf["main_args"]["exp_dir"] + os.makedirs(exp_dir, exist_ok=True) + conf_path = os.path.join(exp_dir, "conf.yml") + with open(conf_path, "w") as outfile: + yaml.safe_dump(conf, outfile) + + # Define Loss function. + loss_spk = SpeakerVectorLoss( + len(train_set.spk2indx), embed_dim=512, loss_type="distance", gaussian_reg=0, distance_reg=0 + ) + loss_sep = ClippedSDR() + + # optimizer takes also loss speaker as spk oracle embeddings are trainable + optimizer = make_optimizer( + list(model.parameters()) + list(loss_spk.parameters()), **conf["optim"] + ) + # Define scheduler + scheduler = None + if conf["training"]["half_lr"]: + scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5) + + system = WavesplitTrainer( + model=model, + loss_func={"spk_loss": loss_spk, "sep_loss": loss_sep}, + optimizer=optimizer, + train_loader=train_loader, + val_loader=val_loader, + scheduler=scheduler, + config=conf, + ) + + # Define callbacks + callbacks = [] + checkpoint_dir = os.path.join(exp_dir, "checkpoints/") + checkpoint = ModelCheckpoint( + checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True + ) + callbacks.append(checkpoint) + if conf["training"]["early_stop"]: + callbacks.append(EarlyStopping(monitor="val_loss", mode="min", patience=30, verbose=True)) + + # Don't ask GPU if they are not available. + gpus = -1 if torch.cuda.is_available() else None + distributed_backend = "ddp" if torch.cuda.is_available() else None + trainer = pl.Trainer( + max_epochs=conf["training"]["epochs"], + callbacks=callbacks, + default_root_dir=exp_dir, + gpus=gpus, + distributed_backend=distributed_backend, + gradient_clip_val=conf["training"]["gradient_clipping"], + ) + trainer.fit(system) + + best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} + with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f: + json.dump(best_k, f, indent=0) + + state_dict = torch.load(checkpoint.best_model_path) + system.load_state_dict(state_dict=state_dict["state_dict"]) + system.cpu() + + to_save = system.model.serialize() + to_save.update(train_set.get_infos()) + torch.save(to_save, os.path.join(exp_dir, "best_model.pth")) + + +if __name__ == "__main__": + import yaml + from pprint import pprint + from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict + + # We start with opening the config file conf.yml as a dictionary from + # which we can create parsers. Each top level key in the dictionary defined + # by the YAML file creates a group in the parser. + with open("local/conf.yml") as f: + def_conf = yaml.safe_load(f) + parser = prepare_parser_from_dict(def_conf, parser=parser) + # Arguments are then parsed into a hierarchical dictionary (instead of + # flat, as returned by argparse) to facilitate calls to the different + # asteroid methods (see in main). + # plain_args is the direct output of parser.parse_args() and contains all + # the attributes in an non-hierarchical structure. It can be useful to also + # have it so we included it here but it is not used. 
+ arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True) + pprint(arg_dic) + main(arg_dic) diff --git a/egs/wham/wavesplit/utils b/egs/wham/wavesplit/utils new file mode 120000 index 000000000..bcee78945 --- /dev/null +++ b/egs/wham/wavesplit/utils @@ -0,0 +1 @@ +../ConvTasNet/utils/ \ No newline at end of file