compute_metrics = ['si_sdr', 'sdr', 'sir', 'sar', 'stoi']


def main(conf):
    """Evaluate the best checkpoint on the test set and dump the metrics.

    Every test mixture is forwarded through the model, the estimates are
    aligned to the references with a permutation-invariant SI-SDR loss and
    the metrics listed in ``compute_metrics`` are computed. A random subset of
    examples (audio + metrics) is written under ``<exp_dir>/examples/``,
    per-utterance metrics go to ``<exp_dir>/all_metrics.csv`` and their means
    (plus the improvement over the input mixture) to
    ``<exp_dir>/final_metrics.json``.

    Args:
        conf (dict): Command-line arguments extended with ``sample_rate`` and
            ``train_conf``. Keys read here: ``train_conf``, ``exp_dir``,
            ``use_gpu``, ``test_dir``, ``task``, ``sample_rate`` and
            ``n_save_ex`` (``-1`` means "save every example").
    """
    model = load_best_model(conf['train_conf'], conf['exp_dir'])
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = WhamDataset(conf['test_dir'], conf['task'],
                           sample_rate=conf['sample_rate'],
                           nondefault_nsrc=model.masker.n_src,
                           segment=None)  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')

    # Randomly choose the indexes of sentences to save. The request is capped
    # at the dataset size because random.sample refuses to draw more items
    # than the population holds.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples/')
    n_examples = len(test_set)
    n_save_ex = conf['n_save_ex']
    if n_save_ex == -1 or n_save_ex > n_examples:
        n_save_ex = n_examples
    # A set gives O(1) membership tests inside the evaluation loop.
    save_idx = set(random.sample(range(n_examples), n_save_ex))
    series_list = []
    # Scope gradient tracking off for the evaluation loop only; the context
    # manager restores the previous mode even if an exception is raised.
    with torch.no_grad():
        for idx in tqdm(range(n_examples)):
            # Forward the network on the mixture.
            mix, sources = tensors_to_device(test_set[idx],
                                             device=model_device)
            est_sources = model(mix[None, None])
            loss, reordered_sources = loss_func(est_sources, sources[None],
                                                return_est=True)
            mix_np = mix[None].cpu().data.numpy()
            sources_np = sources.squeeze().cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
            utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
                                      sample_rate=conf['sample_rate'])
            utt_metrics['mix_path'] = test_set.mix[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx not in save_idx:
                continue
            local_save_dir = os.path.join(ex_save_dir, 'ex_{}/'.format(idx))
            os.makedirs(local_save_dir, exist_ok=True)
            sf.write(local_save_dir + "mixture.wav", mix_np[0],
                     conf['sample_rate'])
            # Loop over the sources and estimates
            for src_idx, src in enumerate(sources_np):
                sf.write(local_save_dir + "s{}.wav".format(src_idx+1), src,
                         conf['sample_rate'])
            for src_idx, est_src in enumerate(est_sources_np):
                sf.write(local_save_dir + "s{}_estimate.wav".format(src_idx+1),
                         est_src, conf['sample_rate'])
            # Write local metrics to the example folder.
            with open(local_save_dir + 'metrics.json', 'w') as f:
                json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf['exp_dir'], 'all_metrics.csv'))

    # Print and save summary metrics
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = 'input_' + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + '_imp'] = ldf.mean()
    print('Overall metrics :')
    pprint(final_results)
    with open(os.path.join(conf['exp_dir'], 'final_metrics.json'), 'w') as f:
        json.dump(final_results, f, indent=0)
#!/bin/bash
# MIT Copyright (c) 2018 Kaituo XU
#
# Convert the WSJ0 sphere files (*.wv1 / *.wv2) found under --sphere_dir into
# wav files under --wav_dir, keeping the last three path components of every
# file. sph2pipe is downloaded into egs/tools and compiled on first use.

sphere_dir=tmp
wav_dir=tmp

. utils/parse_options.sh || exit 1;

tools_dir=../../tools
sph2pipe=$tools_dir/sph2pipe_v2.5/sph2pipe

# Only fetch and build sph2pipe when the binary is missing, so that re-running
# the stage neither downloads the archive again nor recompiles it.
if [ ! -x "$sph2pipe" ]; then
  echo "Download sph2pipe_v2.5 into egs/tools"
  mkdir -p "$tools_dir"
  wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P "$tools_dir"
  # Subshell: the working directory of the (sourcing) caller is left untouched.
  (cd "$tools_dir" && tar -xzvf sph2pipe_v2.5.tar.gz && \
    gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm)
fi

echo "Convert sphere format to wav format"

if [ ! -x "$sph2pipe" ]; then
  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
  exit 1;
fi

tmp=data/local/
mkdir -p "$tmp"

# The list of sphere files is cached; delete $tmp/sph.list to rebuild it.
[ ! -f "$tmp/sph.list" ] && find "$sphere_dir" -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > "$tmp/sph.list"

if [ ! -d "$wav_dir" ]; then
  # -r keeps backslashes in paths literal; IFS= keeps surrounding whitespace.
  while IFS= read -r line; do
    wav=$(echo "$line" | sed "s:wv[12]:wav:g" | awk -v dir="$wav_dir" -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}')
    # The echoed path is what ends up in wav.list (stdout of the whole loop).
    echo "$wav"
    mkdir -p "$(dirname "$wav")"
    "$sph2pipe" -f wav "$line" > "$wav"
  done < "$tmp/sph.list" > "$tmp/wav.list"
else
  echo "Do you already get wav files? if not, please remove $wav_dir"
fi
# For each task: the indices of the "_"-separated file-name fields that hold
# the speaker ids (one per source) and whether a noise file is listed.
_TASKS = {
    "mix_both": ((0, 2), True),
    "mix_clean": ((0, 2), False),
    "mix_single": ((0,), False),
}


def _describe_example(mix, in_dir, spk_fields, with_noise):
    """Build the json entry describing one mixture file."""
    filename = os.path.basename(mix)
    name_fields = filename.split("_")
    # The context manager releases the file handle as soon as the length is
    # read; thousands of mixtures are scanned per split.
    with sf.SoundFile(mix) as wav_file:
        length = len(wav_file)

    sources = [os.path.join(in_dir, "s{}".format(src + 1), filename)
               for src in range(len(spk_fields))]
    # Key order is kept as in the json files written so far.
    ex = {"mix": mix, "sources": sources}
    if with_noise:
        ex["noise"] = os.path.join(in_dir, "noise", filename)
    # A speaker id is the first three characters of the utterance name.
    ex["spk_id"] = [name_fields[field][:3] for field in spk_fields]
    ex["length"] = length
    return ex


def preprocess_task(task, in_dir, out_dir):
    """Write ``<out_dir>/<task>.json`` listing every mixture of ``task``.

    Each entry holds the mixture path, the paths of its sources (``s1``,
    ``s2``), the noise path (``mix_both`` only), the speaker ids parsed from
    the file name and the length of the mixture in samples.

    Args:
        task (str): One of ``mix_both``, ``mix_clean`` or ``mix_single``.
        in_dir (str): Directory of one split, containing the ``<task>``,
            ``s1``, ``s2`` and ``noise`` sub-directories.
        out_dir (str): Directory receiving the json file (created if needed).

    Raises:
        EnvironmentError: If ``task`` is not a known task.
    """
    if task not in _TASKS:
        raise EnvironmentError("Unknown task {!r}, expected one of {}".format(
            task, sorted(_TASKS)))
    spk_fields, with_noise = _TASKS[task]

    os.makedirs(out_dir, exist_ok=True)

    # Sorted so that the json content does not depend on directory order.
    mixtures = sorted(glob.glob(os.path.join(in_dir, task, "*.wav")))
    examples = [_describe_example(mix, in_dir, spk_fields, with_noise)
                for mix in mixtures]

    with open(os.path.join(out_dir, task + '.json'), 'w') as f:
        json.dump(examples, f, indent=4)


def preprocess(inp_args):
    """Run :func:`preprocess_task` for every task of the tr, cv and tt splits.

    Args:
        inp_args: Namespace with ``in_dir`` (directory containing ``tr``,
            ``cv`` and ``tt``) and ``out_dir`` (where the json files go).
    """
    tasks = ['mix_both', 'mix_clean', 'mix_single']
    for split in ["tr", "cv", "tt"]:
        for task in tasks:
            preprocess_task(task, os.path.join(inp_args.in_dir, split),
                            os.path.join(inp_args.out_dir, split))
# --- resample_dataset.py -----------------------------------------------------

def main(out_dir, original_sr, target_sr, extension):
    """Resample in place every ``*.<extension>`` file found under ``out_dir``.

    Args:
        out_dir (str): Root of the (already copied) dataset to resample.
        original_sr (int): Sampling rate every file is expected to have.
        target_sr (int): Sampling rate to convert to, at most ``original_sr``.
        extension (str): File extension of the audio files, without the dot.

    Raises:
        ValueError: If upsampling is requested or a file does not have the
            expected sampling rate.
    """
    if original_sr < target_sr:
        raise ValueError("Upsampling not supported")
    # The dataset is copied as a whole tree, so descend into sub-directories
    # too ("**" also matches files sitting directly in out_dir).
    pattern = os.path.join(out_dir, "**", "*.{}".format(extension))
    for wav in glob(pattern, recursive=True):
        data, fs = sf.read(wav)
        if fs != original_sr:
            raise ValueError("{} is sampled at {} Hz, expected {} Hz".format(
                wav, fs, original_sr))
        data = resample_poly(data, target_sr, fs)
        # soundfile needs the sampling rate of the data it writes.
        sf.write(wav, data, target_sr)


# --- losses.py ---------------------------------------------------------------

class ClippedSDR(nn.Module):
    """Negative SNR loss clamped from below at ``clip_value``.

    Args:
        clip_value (float): Smallest loss value returned; lower values are
            replaced by it.
    """

    def __init__(self, clip_value=-30):
        super().__init__()
        self.snr = MultiSrcNegSDR("snr")
        self.clip_value = float(clip_value)

    def forward(self, est_targets, targets):
        """Return the clamped loss of ``est_targets`` against ``targets``."""
        return torch.clamp(self.snr(est_targets, targets), min=self.clip_value)


class SpeakerVectorLoss(nn.Module):
    """Permutation-invariant loss between frame-wise speaker vectors and a
    table of per-speaker embeddings.

    For every frame, each ordering of the predicted speaker vectors is scored
    against the embeddings of the speakers in the utterance; the best ordering
    is kept and used to reorder the vectors.

    Args:
        n_speakers (int): Number of rows of the embedding table.
        embed_dim (int): Dimension of the embeddings and speaker vectors.
        learnable_emb (bool): Whether the embedding table is trained.
        loss_type (str): ``"distance"``, ``"global"`` or ``"local"``.
        weight (float): Multiplier applied to the speaker loss.
        distance_reg (float): Weight of the regulariser pushing every embedding
            away from its nearest neighbour (0 disables it).
        gaussian_reg (float): Variance of the gaussian noise added to the
            embeddings at each call (0 disables it).
        return_oracle (bool): Also return the embeddings of the labelled
            speakers from :meth:`forward`.
    """

    # Lower bound guarding the logarithm of pairwise embedding distances.
    _EPS = 1e-8

    def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="global",
                 weight=2, distance_reg=0.3, gaussian_reg=0.2, return_oracle=False):
        super().__init__()

        self.learnable_emb = learnable_emb
        self.loss_type = loss_type
        self.weight = float(weight)
        self.distance_reg = float(distance_reg)
        self.gaussian_reg = float(gaussian_reg)
        self.return_oracle = return_oracle

        assert loss_type in ["distance", "global", "local"]

        # One-hot initialisation. When n_speakers > embed_dim the rows past
        # embed_dim are all zeros.
        spk_emb = torch.eye(max(n_speakers, embed_dim))
        spk_emb = spk_emb[:n_speakers, :embed_dim]

        if learnable_emb:
            self.spk_embeddings = nn.Parameter(spk_emb)
        else:
            self.register_buffer("spk_embeddings", spk_emb)

        if loss_type != "distance":
            # Learnt scale and bias of the squared distances.
            self.alpha = nn.Parameter(torch.Tensor([1.]))
            self.beta = nn.Parameter(torch.Tensor([0.]))

    @staticmethod
    def _l_dist_speaker(c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
        """Distance to the target embeddings plus a hinge keeping the vectors
        of the other sources at least 1 away from the first one."""
        utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
        c_spk = c_spk_vec_perm[:, 0]
        pair_dist = ((c_spk.unsqueeze(1) - c_spk_vec_perm) ** 2).mean(2)
        pair_dist = pair_dist[:, 1:]  # drop the distance of source 0 to itself
        distance = ((c_spk_vec_perm - utt_embeddings) ** 2).mean(dim=(1, 2))
        return distance + F.relu(1. - pair_dist).mean(dim=1)

    def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
        """Contrastive loss against the speakers of the same utterance only."""
        utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
        alpha = torch.clamp(self.alpha, 1e-8)

        distance = alpha * ((c_spk_vec_perm - utt_embeddings) ** 2).mean(2) + self.beta
        distances = alpha * ((c_spk_vec_perm.unsqueeze(1) - utt_embeddings.unsqueeze(2)) ** 2).mean(3) + self.beta
        distances = torch.exp(-distances).mean(1)

        return (distance + torch.log(distances)).mean(1)

    def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
        """Contrastive loss against every speaker of the embedding table."""
        utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
        alpha = torch.clamp(self.alpha, 1e-8)

        distance_utt = alpha * ((c_spk_vec_perm - utt_embeddings) ** 2).sum(2) + self.beta

        B, src, embed_dim, frames = c_spk_vec_perm.size()
        spk_embeddings = spk_embeddings.reshape(1, spk_embeddings.shape[0], embed_dim, 1).expand(B, -1, -1, frames)
        distances = alpha * ((c_spk_vec_perm.unsqueeze(1) - spk_embeddings.unsqueeze(2)) ** 2).sum(3) + self.beta
        distances = torch.exp(-distances).sum(1)

        return (distance_utt + torch.log(distances)).sum(1)

    def forward(self, speaker_vectors, spk_mask, spk_labels):
        """Compute the speaker loss and reorder the speaker vectors.

        Args:
            speaker_vectors (torch.Tensor): Shape (B, n_src, embed_dim, frames).
            spk_mask (torch.Tensor): Shape (B, n_src, frames); multiplies the
                target embeddings frame by frame.
            spk_labels (torch.Tensor): Integer speaker indices, shape
                (B, n_src), used to index the embedding table.

        Returns:
            tuple: ``(spk_loss, reordered_sources)``, followed by the target
            embeddings when ``return_oracle`` is set. ``reordered_sources``
            has the shape of ``speaker_vectors``.
        """
        spk_embeddings = self.spk_embeddings
        if self.gaussian_reg:
            noise = torch.randn(spk_embeddings.size(), device=speaker_vectors.device)
            spk_embeddings = spk_embeddings + noise * math.sqrt(self.gaussian_reg)

        if self.learnable_emb or self.gaussian_reg:
            # Re-project on the unit sphere. F.normalize bounds the norm away
            # from zero, so all-zero rows stay zero instead of becoming NaN.
            spk_embeddings = F.normalize(spk_embeddings, p=2, dim=-1)

        distance_reg = None
        if self.distance_reg:
            pairwise_dist = torch.abs(spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1)).mean(-1)
            # Coinciding embeddings would give log(0); the diagonal is set to
            # +inf so that a speaker is never its own nearest neighbour.
            pairwise_dist = pairwise_dist.clamp_min(self._EPS).fill_diagonal_(float("inf"))
            distance_reg = -torch.sum(torch.min(torch.log(pairwise_dist), dim=-1)[0])

        B, n_src, embed_dim, frames = speaker_vectors.size()

        perm_loss = {
            "distance": self._l_dist_speaker,
            "local": self._l_local_speaker,
            "global": self._l_global_speaker,
        }[self.loss_type]
        perms = list(permutations(range(n_src)))
        # One loss per permutation and frame: (B, n_perms, frames).
        loss_set = torch.stack(
            [perm_loss(speaker_vectors[:, perm], spk_embeddings, spk_labels, spk_mask) for perm in perms],
            dim=1)

        # Value and index of the best permutation for each batch element and frame.
        min_loss, min_loss_idx = torch.min(loss_set, dim=1)

        # Turn the index of the best permutation into, for every frame, the
        # source index to pick at each output position.
        perms = min_loss.new_tensor(perms, dtype=torch.long)
        perms = perms[..., None, None].expand(-1, -1, B, frames)
        min_loss_idx = min_loss_idx[None, None, ...].expand(1, n_src, -1, -1)
        min_loss_perm = torch.gather(perms, dim=0, index=min_loss_idx)[0]
        min_loss_perm = min_loss_perm.transpose(0, 1).reshape(B, n_src, 1, frames).expand(-1, -1, embed_dim, -1)

        spk_loss = self.weight * min_loss.mean()
        if distance_reg is not None:
            spk_loss += self.distance_reg * distance_reg
        reordered_sources = torch.gather(speaker_vectors, dim=1, index=min_loss_perm)

        if self.return_oracle:
            utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
            return spk_loss, reordered_sources, utt_embeddings

        return spk_loss, reordered_sources
storage_dir=/media/sam/Data/WSJ/wham_scripts/
#storage_dir=/srv/storage/talc3@talc-data.nancy/multispeech/calcul/users/mpariente/DATA/wsj0_wav


# If you start from the sphere files, specify the path to the directory and start from stage 0
sphere_dir=  # Directory containing sphere files
# If you already have wsj0 wav files, specify the path to the directory here and start from stage 1
wsj0_wav_dir=${storage_dir}/wsj0_wav/
# If you already have the WHAM mixtures, specify the path to the directory here and start from stage 2
wham_wav_dir=${storage_dir}/2speakers_wham/
# After running the recipe a first time, you can run it from stage 3 directly to train new models.

# We use wsj0 directly for data augmentation. Because the original WSJ0 is 16k, the data is copied and
# downsampled offline to be used for 8k separation training. If only 16k separation is desired this
# copy is not needed.
wsj0_wav_dir_8k=data/wsj0_8k

# Path to the python you'll use for the experiment. Defaults to the current python
# You can run ./utils/prepare_python_env.sh to create a suitable python environment, paste the output here.
#python_path=${storage_dir}/asteroid_conda/miniconda3/bin/python
python_path=python

# Example usage
# ./run.sh --stage 3 --tag my_tag --task sep_noisy --id 0,1

# General
stage=3  # Controls from which stage to start
tag="test"  # Controls the directory name associated to the experiment
# You can ask for several GPUs using id (passed to CUDA_VISIBLE_DEVICES)
id=0

# Data
data_dir=data  # Local data directory (No disk space needed)
task=sep_clean  # Specify the task here (sep_clean, sep_noisy, enh_single, enh_both)
sample_rate=8000
mode=min
nondefault_src=  # If you want to train a network with 3 output streams for example.

# Training
batch_size=4
num_workers=4
#optimizer=adam
lr=0.001
epochs=400

# Evaluation
eval_use_gpu=1


. utils/parse_options.sh

# "!" is the shell's negation; the interpreter that runs the recipe is the
# one that must provide pysndfx.
if ! $python_path -c "import pysndfx" &> /dev/null; then
  echo 'This recipe requires pysndfx. Please install with pip install pysndfx. Exiting.'
  exit 1
fi

if [[ $stage -le -1 ]]; then
  echo "Stage -1: Creating python environment to run this"
  if [[ -x "${python_path}" ]]
  then
    echo "The provided python path is executable, don't proceed to installation."
  else
    . utils/prepare_python_env.sh --install_dir $python_path --asteroid_root ../../..
    echo "Miniconda3 install can be found at $python_path"
    python_path=${python_path}/miniconda3/bin/python
    echo -e "\n To use this python version for the next experiments, change"
    echo -e "python_path=$python_path at the beginning of the file \n"
  fi
fi


if [[ $stage -le 0 ]]; then
  echo "Stage 0: Converting sphere files to wav files"
  . local/convert_sphere2wav.sh --sphere_dir $sphere_dir --wav_dir $wsj0_wav_dir
fi


if [[ $stage -le 1 ]]; then
  echo "Stage 1: Generating 8k and 16k WHAM dataset"
  . local/prepare_data.sh --wav_dir $wsj0_wav_dir --out_dir $wham_wav_dir --python_path $python_path
fi


if [[ $stage -le 2 ]]; then
  # Make json directories with min/max modes and sampling rates
  echo "Stage 2: Generating json files including wav path and duration"
  for sr_string in 8; do
    for mode in min; do
      tmp_dumpdir=data/wav${sr_string}k/$mode
      echo "Generating json files in $tmp_dumpdir"
      mkdir -p $tmp_dumpdir
      local_wham_dir=$wham_wav_dir/wav${sr_string}k/$mode/
      $python_path local/preprocess_wham.py --in_dir $local_wham_dir --out_dir $tmp_dumpdir
    done
  done
fi


sr_string=$(($sample_rate/1000))
suffix=wav${sr_string}k/$mode
dumpdir=data/$suffix  # directory to put generated json file

train_dir=$dumpdir/tr
valid_dir=$dumpdir/cv
test_dir=$dumpdir/tt

# Generate a random ID for the run if no tag is specified
uuid=$($python_path -c 'import uuid, sys; print(str(uuid.uuid4())[:8])')
if [[ -z ${tag} ]]; then
  tag=${task}_${sr_string}k${mode}_${uuid}
fi
expdir=exp/train_dprnn_${tag}
mkdir -p $expdir && echo $uuid >> $expdir/run_uuid.txt
echo "Results from the following experiment will be stored in $expdir"


if [[ $stage -le 3 ]]; then
  echo "Stage 3: Training"
  mkdir -p logs
  CUDA_VISIBLE_DEVICES=$id $python_path train.py \
    --train_dir $train_dir \
    --valid_dir $valid_dir \
    --task $task \
    --sample_rate $sample_rate \
    --lr $lr \
    --epochs $epochs \
    --batch_size $batch_size \
    --num_workers $num_workers \
    --exp_dir ${expdir}/ | tee logs/train_${tag}.log
fi


if [[ $stage -le 4 ]]; then
  echo "Stage 4: Evaluation"
  CUDA_VISIBLE_DEVICES=$id $python_path eval.py \
    --task $task \
    --test_dir $test_dir \
    --use_gpu $eval_use_gpu \
    --exp_dir ${expdir}
fi
pairwise_neg_sisdr + + +from losses import SpeakerVectorLoss, ClippedSDR +from wavesplit import SpeakerStack, SeparationStack +from asteroid.filterbanks import make_enc_dec +from wavesplitwham import WaveSplitWhamDataset +from argparse import Namespace +from asteroid.utils import flatten_dict +from copy import deepcopy + +# Keys which are not in the conf.yml file can be added here. +# In the hierarchical dictionary created when parsing, the key `key` can be +# found at dic['main_args'][key] + +# By default train.py will use all available GPUs. The `id` option in run.sh +# will limit the number of available GPUs for train.py . +# This can be changed: `python train.py --gpus 0,1` will only train on 2 GPUs. +parser = argparse.ArgumentParser() +parser.add_argument('--gpus', type=str, help='list of GPUs', default='-1') +parser.add_argument('--exp_dir', default='exp/tmp', + help='Full path to save best validation model') + +warnings.simplefilter("ignore", UserWarning) + + +class Wavesplit(pl.LightningModule): # redefinition + + def __init__(self, spk_stack, sep_stack, optimizer, spk_loss, sep_loss, train_loader, + val_loader=None, scheduler=None, config=None): + super().__init__() + + #self.spk_stack = SpeakerStack(256, 2, 1, 1) + + + #self.spk_loss = SpeakerVectorLoss(101, 256, False, "distance", 10) + self.spk_stack = spk_stack + self.sep_stack = sep_stack + self.optimizer = optimizer + self.sep_loss = sep_loss + self.spk_loss = spk_loss + + self.train_loader = train_loader + self.val_loader = val_loader + self.scheduler = scheduler + config = {} if config is None else config + self.config = config + # hparams will be logged to Tensorboard as text variables. + # torch doesn't support None in the summary writer for now, convert + # None to strings temporarily. 
+ # See https://github.com/pytorch/pytorch/issues/33140 + self.hparams = Namespace(**self.none_to_string(flatten_dict(config))) + + #n_speakers = 100 + #embed_dim = 128 + #spk_emb = torch.eye(max(n_speakers, embed_dim)) # one-hot init works better according to Neil + #spk_emb = spk_emb[:n_speakers, :embed_dim] + + #self.oracle = spk_emb.cuda() + + def forward(self, *args, **kwargs): + """ Applies forward pass of the model. + + Returns: + :class:`torch.Tensor` + """ + + return self.model(*args, **kwargs) + + def common_step(self, batch, batch_nb): + """ Common forward step between training and validation. + + The function of this method is to unpack the data given by the loader, + forward the batch through the model and compute the loss. + + Args: + batch: the object returned by the loader (a list of torch.Tensor + in most cases) but can be something else. + batch_nb (int): The number of the batch in the epoch. + + Returns: + :class:`torch.Tensor` : The loss value on this batch. + + .. note:: This is typically the method to overwrite when subclassing + `System`. If the training and validation steps are different + (except for loss.backward() and optimzer.step()), then overwrite + `training_step` and `validation_step` instead. 
+ """ + inputs, targets, spk_ids = batch + spk_embed = self.spk_stack(inputs) + + spk_loss, reordered_embed = self.spk_loss(spk_embed, torch.ones((spk_embed.shape[0], + spk_embed.shape[1],spk_embed.shape[-1])).to(spk_embed.device), spk_ids) + reordered_embed = reordered_embed.mean(-1) + + #reordered_embed = self.oracle[spk_ids] + #b, n_spk, spk_vec_size = reordered_embed.size() + + separated = self.sep_stack(inputs, torch.cat((reordered_embed[:, 0], reordered_embed[:, 1]), 1)) + + sep_loss = 0 + for i, o in enumerate(separated): + o = self.pad_output_to_inp(o, inputs) + last = self.sep_loss(o, targets).mean() + sep_loss += last + sep_loss = sep_loss / (i+1) + loss = sep_loss + spk_loss + + return loss, spk_loss, last.mean() + + @staticmethod + def pad_output_to_inp(output, inp): + """ Pad first argument to have same size as second argument""" + inp_len = inp.size(-1) + output_len = output.size(-1) + return torch.nn.functional.pad(output, [0, inp_len - output_len]) + + def training_step(self, batch, batch_nb): + """ Pass data through the model and compute the loss. + + Backprop is **not** performed (meaning PL will do it for you). + + Args: + batch: the object returned by the loader (a list of torch.Tensor + in most cases) but can be something else. + batch_nb (int): The number of the batch in the epoch. + + Returns: + dict: + + ``'loss'``: loss + + ``'log'``: dict with tensorboard logs + + """ + loss, spk_loss, sep_loss = self.common_step(batch, batch_nb) + tqdm_dict = {'train_loss': loss, "spk_loss": spk_loss, "sep_loss": sep_loss} + tensorboard_logs = {'train_loss': loss, "spk_loss": spk_loss, "sep_loss": sep_loss} + + output = OrderedDict({ + 'loss': loss, + 'progress_bar': tqdm_dict, + 'log': tensorboard_logs + }) + return output + + def validation_step(self, batch, batch_nb): + """ Need to overwrite PL validation_step to do validation. + + Args: + batch: the object returned by the loader (a list of torch.Tensor + in most cases) but can be something else. 
+ batch_nb (int): The number of the batch in the epoch. + + Returns: + dict: + + ``'val_loss'``: loss + """ + loss, spk_loss, sep_loss = self.common_step(batch, batch_nb) + return {'val_loss': loss, 'val_spk_loss': spk_loss, 'val_sep_loss': sep_loss} + + def validation_end(self, outputs): + """ How to aggregate outputs of `validation_step` for logging. + + Args: + outputs (list[dict]): List of validation losses, each with a + ``'val_loss'`` key + + Returns: + dict: Average loss + + ``'val_loss'``: Average loss on `outputs` + + ``'log'``: Tensorboard logs + + ``'progress_bar'``: Tensorboard logs + """ + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + avg_spk_loss = torch.stack([x['val_spk_loss'] for x in outputs]).mean() + avg_sep_loss = torch.stack([x['val_sep_loss'] for x in outputs]).mean() + tensorboard_logs = {'val_loss': avg_loss, "val_spk_loss": avg_spk_loss, "val_sep_loss": avg_sep_loss} + return {'val_loss': avg_loss, 'log': tensorboard_logs, + 'progress_bar': tensorboard_logs} + + def unsqueeze_if_dp_or_ddp(self, *values): + """ Apply unsqueeze(0) to all values if training is done with dp + or ddp. Unused now.""" + if self.trainer.use_dp or self.trainer.use_ddp2: + values = [v.unsqueeze(0) for v in values] + if len(values) == 1: + return values[0] + return values + + def configure_optimizers(self): + """ Required by pytorch-lightning. """ + if self.scheduler is not None: + return [self.optimizer], [self.scheduler] + return self.optimizer + + @pl.data_loader + def train_dataloader(self): + return self.train_loader + + @pl.data_loader + def val_dataloader(self): + return self.val_loader + + @pl.data_loader + def tng_dataloader(self): # pragma: no cover + """ Deprecated.""" + pass + + def on_save_checkpoint(self, checkpoint): + """ Overwrite if you want to save more things in the checkpoint.""" + checkpoint['training_config'] = self.config + return checkpoint + + def on_batch_start(self, batch): + """ Overwrite if needed. 
@staticmethod
def none_to_string(dic):
    """Replace every `None` value of `dic` by the string ``'None'``.

    The dictionary is modified in place and also returned, so that it can
    be handled by the torch summary writer.

    Args:
        dic (dict): Dictionary to be transformed.

    Returns:
        dict: The same dictionary object, without `None` values.
    """
    none_keys = [key for key, value in dic.items() if value is None]
    for key in none_keys:
        dic[key] = str(None)
    return dic
if __name__ == '__main__':
    import yaml
    # Import pprint under its own name. Binding it to `print` here would
    # rebind the module-level name, so every print() call in this module
    # (including the ones inside main) would silently go through pprint.
    from pprint import pprint
    from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict

    # We start with opening the config file conf.yml as a dictionary from
    # which we can create parsers. Each top level key in the dictionary defined
    # by the YAML file creates a group in the parser.
    with open('local/conf.yml') as f:
        def_conf = yaml.safe_load(f)
    parser = prepare_parser_from_dict(def_conf, parser=parser)
    # Arguments are then parsed into a hierarchical dictionary (instead of
    # flat, as returned by argparse) to facilitate calls to the different
    # asteroid methods (see in main).
    # plain_args is the direct output of parser.parse_args() and contains all
    # the attributes in a non-hierarchical structure. It can be useful to also
    # have it so we included it here but it is not used.
    arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
    pprint(arg_dic)
    main(arg_dic)
+ + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. 
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. \ No newline at end of file diff --git a/egs/wham/WaveSplit/utils/prepare_python_env.sh b/egs/wham/WaveSplit/utils/prepare_python_env.sh new file mode 100755 index 000000000..3dc223334 --- /dev/null +++ b/egs/wham/WaveSplit/utils/prepare_python_env.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Usage ./utils/install_env.sh --install_dir A --asteroid_root B --pip_requires C +install_dir=~ +asteroid_root=../../../../ +pip_requires=../../../requirements.txt # Expects a requirement.txt + +. utils/parse_options.sh || exit 1 + +mkdir -p $install_dir +cd $install_dir +echo "Download and install latest version of miniconda3 into ${install_dir}" +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh + +bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 +pip_path=$PWD/miniconda3/bin/pip + +rm Miniconda3-latest-Linux-x86_64.sh +cd - + +if [[ ! -z ${pip_requires} ]]; then + $pip_path install -r $pip_requires +fi +$pip_path install soundfile +$pip_path install -e $asteroid_root +#$pip_path install ${asteroid_root}/\[""evaluate""\] +echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes !" 
class SepConv1DBlock(nn.Module):
    """ Dilated 1-D convolution conditioned on a speaker vector.

    The convolution output is modulated by the speaker vector, then passed
    through a PReLU and a normalization layer.

    Args:
        in_chan (int): Number of input channels of the convolution.
        hid_chan (int): Number of output channels of the convolution.
        spk_vec_chan (int): Size of the speaker vector given to `forward`.
        kernel_size (int): Kernel size of the convolution.
        padding (int): Padding of the convolution.
        dilation (int): Dilation of the convolution.
        norm_type (str, optional): Name passed to ``norms.get`` to pick the
            normalization layer.
        use_FiLM (bool, optional): If True, apply FiLM conditioning (a scale
            and a bias, both predicted from the speaker vector). If False,
            only the predicted bias is added.
    """

    def __init__(self, in_chan, hid_chan, spk_vec_chan, kernel_size, padding,
                 dilation, norm_type="gLN", use_FiLM=True):
        super(SepConv1DBlock, self).__init__()

        self.use_FiLM = use_FiLM
        conv_norm = norms.get(norm_type)
        self.depth_conv1d = nn.Conv1d(in_chan, hid_chan, kernel_size,
                                      padding=padding, dilation=dilation)
        self.out = nn.Sequential(nn.PReLU(),
                                 conv_norm(hid_chan))

        # Speaker conditioning. The multiplicative branch only exists for
        # FiLM; the additive branch is used by both conditioning modes, so it
        # must always be created (it used to be missing when use_FiLM=False,
        # which made apply_conditioning fail with an AttributeError).
        if self.use_FiLM:
            self.mul_lin = nn.Linear(spk_vec_chan, hid_chan)
        self.add_lin = nn.Linear(spk_vec_chan, hid_chan)

    def apply_conditioning(self, spk_vec, squeezed):
        """ Modulate convolution features with the speaker vector.

        Args:
            spk_vec (torch.Tensor): Speaker vector, last axis of size
                `spk_vec_chan`.
            squeezed (torch.Tensor): Convolution output [batch, feats, seq].

        Returns:
            torch.Tensor: ``scale * squeezed + bias`` with FiLM, otherwise
            ``squeezed + bias``; scale and bias are broadcast over `seq`.
        """
        # unsqueeze(-1) adds the time axis so that the per-channel terms
        # broadcast over the whole sequence.
        bias = self.add_lin(spk_vec).unsqueeze(-1)
        if self.use_FiLM:
            scale = self.mul_lin(spk_vec).unsqueeze(-1)
            return scale * squeezed + bias
        return squeezed + bias

    def forward(self, x, spk_vec):
        """ Input shape [batch, feats, seq]"""

        conditioned = self.apply_conditioning(spk_vec, self.depth_conv1d(x))

        return self.out(conditioned)
class SeparationStack(nn.Module):
    """ Stack of speaker-conditioned dilated convolution blocks.

    Args:
        src (int): Number of output channels of the 1x1 output convolutions.
        embed_dim (int, optional): Number of channels inside the stack.
        spk_vec_dim (int, optional): Size of the speaker vector given to
            every block.
        n_blocks (int, optional): Number of blocks per repeat; block ``x``
            uses dilation ``2 ** x``.
        n_repeats (int, optional): Number of repeats of the block sequence.
        kernel_size (int, optional): Kernel size of every convolution block.
        norm_type (str, optional): Normalization name forwarded to the blocks.
        return_all_layers (bool, optional): If True, every block gets its own
            1x1 output convolution and `forward` returns one output per block.
    """
    # basically this is plain conv-tasnet remove this in future releases

    def __init__(self, src, embed_dim=256, spk_vec_dim=512, n_blocks=10, n_repeats=4,
                 kernel_size=3,
                 norm_type="gLN", return_all_layers=True):

        super(SeparationStack, self).__init__()
        self.src = src
        self.embed_dim = embed_dim
        self.n_blocks = n_blocks
        self.n_repeats = n_repeats
        self.kernel_size = kernel_size
        self.norm_type = norm_type
        self.return_all = return_all_layers

        # Succession of SepConv1DBlock with exponentially increasing dilation.
        self.TCN = nn.ModuleList()
        for repeat in range(n_repeats):
            for exponent in range(n_blocks):
                dilation = 2 ** exponent
                # Only the very first block sees the single-channel waveform.
                first_block = repeat == 0 and exponent == 0
                in_chan = 1 if first_block else embed_dim
                layer = SepConv1DBlock(in_chan, embed_dim, spk_vec_dim,
                                       kernel_size,
                                       padding=(kernel_size - 1) * dilation // 2,
                                       dilation=dilation, norm_type=norm_type)
                if self.return_all:
                    # Pair the block with its own per-layer output projection.
                    layer = nn.ModuleList([layer,
                                           nn.Conv1d(embed_dim, self.src, 1)])
                self.TCN.append(layer)

        # NOTE(review): this projection is created but never applied in
        # forward -- confirm whether the return_all_layers=False path was
        # meant to go through it.
        self.out = nn.Conv1d(embed_dim, self.src, 1)

    def forward(self, mixture_w, spk_vectors):
        """
        Args:
            mixture_w (:class:`torch.Tensor`): Mixture; a channel axis of
                size one is inserted at position 1 before the first block.
            spk_vectors (:class:`torch.Tensor`): Speaker vectors handed
                unchanged to every block.

        Returns:
            list of :class:`torch.Tensor` (one projected output per block)
            if `return_all_layers` is True, otherwise the
            :class:`torch.Tensor` features after the last block.
        """
        features = mixture_w.unsqueeze(1)
        per_layer = []
        for idx, layer in enumerate(self.TCN):
            if self.return_all:
                block, head = layer
            else:
                block, head = layer, None
            processed = block(features, spk_vectors)
            # The first block changes the channel count, so it has no
            # residual connection; all later blocks are residual.
            features = processed if idx == 0 else features + processed
            if head is not None:
                per_layer.append(head(features))

        if self.return_all:
            return per_layer
        return features
DATASET = 'WHAM'
# WHAM tasks
enh_single = {'mixture': 'mix_single',
              'sources': ['s1'],
              'infos': ['noise'],
              'default_nsrc': 1}
enh_both = {'mixture': 'mix_both',
            'sources': ['mix_clean'],
            'infos': ['noise'],
            'default_nsrc': 1}
sep_clean = {'mixture': 'mix_clean',
             'sources': ['s1', 's2'],
             'infos': [],
             'default_nsrc': 2}
sep_noisy = {'mixture': 'mix_both',
             'sources': ['s1', 's2'],
             'infos': ['noise'],
             'default_nsrc': 2}

WHAM_TASKS = {'enhance_single': enh_single,
              'enhance_both': enh_both,
              'sep_clean': sep_clean,
              'sep_noisy': sep_noisy}
# Aliases.
WHAM_TASKS['enh_single'] = WHAM_TASKS['enhance_single']
WHAM_TASKS['enh_both'] = WHAM_TASKS['enhance_both']


class WaveSplitWhamDataset(data.Dataset):
    """ Dataset class for WHAM source separation and speech enhancement tasks.

    Besides the mixture and the sources, every item also carries the integer
    speaker indices of the example.

    Args:
        json_dir (str): The path to the directory containing the json files.
        task (str): One of ``'enh_single'``, ``'enh_both'``, ``'sep_clean'`` or
            ``'sep_noisy'``.

            * ``'enh_single'`` for single speaker speech enhancement.
            * ``'enh_both'`` for multi speaker speech enhancement.
            * ``'sep_clean'`` for two-speaker clean source separation.
            * ``'sep_noisy'`` for two-speaker noisy source separation.

        sample_rate (int, optional): The sampling rate of the wav files.
        segment (float, optional): Length of the segments used for training,
            in seconds. If None, use full utterances (e.g. for test).
        nondefault_nsrc (int, optional): Number of sources in the training
            targets.
            If None, defaults to one for enhancement tasks and two for
            separation tasks.

    Raises:
        ValueError: If `task` is not a key of ``WHAM_TASKS``.
    """
    def __init__(self, json_dir, task, sample_rate=8000, segment=4.0,
                 nondefault_nsrc=None):
        super(WaveSplitWhamDataset, self).__init__()
        if task not in WHAM_TASKS.keys():
            raise ValueError('Unexpected task {}, expected one of '
                             '{}'.format(task, WHAM_TASKS.keys()))
        # Task setting
        self.json_dir = json_dir
        self.task = task
        self.task_dict = WHAM_TASKS[task]
        self.sample_rate = sample_rate
        self.seg_len = None if segment is None else int(segment * sample_rate)
        if not nondefault_nsrc:
            self.n_src = self.task_dict['default_nsrc']
        else:
            assert nondefault_nsrc >= self.task_dict['default_nsrc']
            self.n_src = nondefault_nsrc
        self.like_test = self.seg_len is None
        # Load json examples
        ex_json = os.path.join(json_dir, self.task_dict['mixture'] + '.json')

        with open(ex_json, 'r') as f:
            examples = json.load(f)

        if self.like_test:
            # Full-utterance mode: nothing is filtered out. (The examples
            # used to be kept only in segment mode, which left the dataset
            # empty when segment=None.)
            self.examples = list(examples)
        else:
            # Filter out utterances shorter than one training segment.
            self.examples = [ex for ex in examples
                             if ex["length"] >= self.seg_len]
            dropped = [ex["length"] for ex in examples
                       if ex["length"] < self.seg_len]
            # samples / sample_rate gives seconds, / 3600 gives hours.
            print("Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
                len(dropped), sum(dropped) / sample_rate / 3600,
                len(examples), self.seg_len))

        # Speakers are identified by the first three characters of each
        # "spk_id" entry. Sorting makes the speaker -> index mapping
        # reproducible across runs (set iteration order of strings is not).
        speakers = sorted({spk[:3]
                           for ex in self.examples for spk in ex["spk_id"]})

        print("Total number of speakers {}".format(len(speakers)))

        # convert speakers id into integers
        spk2indx = {spk: indx for indx, spk in enumerate(speakers)}
        for ex in self.examples:
            ex["spk_id"] = [spk2indx[spk[:3]] for spk in ex["spk_id"]]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """ Gets a mixture/sources/speaker-indices triplet.

        Returns:
            mixture, vstack([source_arrays]), speaker indices (long tensor)
        """
        c_ex = self.examples[idx]
        # Random start
        if c_ex["length"] == self.seg_len or self.like_test:
            rand_start = 0
        else:
            # The upper bound of randint is exclusive: + 1 makes the last
            # valid start (length - seg_len) reachable.
            rand_start = np.random.randint(
                0, c_ex["length"] - self.seg_len + 1)
        if self.like_test:
            stop = None
        else:
            stop = rand_start + self.seg_len
        # Load mixture
        x, _ = sf.read(c_ex["mix"], start=rand_start,
                       stop=stop, dtype='float32')
        # Load sources, cut on the same window as the mixture
        source_arrays = []
        for src in c_ex["sources"]:
            s, _ = sf.read(src, start=rand_start,
                           stop=stop, dtype='float32')
            source_arrays.append(s)
        sources = torch.from_numpy(np.vstack(source_arrays))

        return torch.from_numpy(x), sources, torch.Tensor(c_ex["spk_id"]).long()