class CHiME4Dataset(Dataset):
    """Dataset class for CHiME-4 source separation tasks ('real' data only).

    Each item is a noisy mixture read from the paths listed in the metadata
    csv produced by ``local/create_metadata.py``. CHiME-4 'real' recordings
    have no clean reference, so only the mixture (and optionally the WSJ
    utterance id) is returned.

    Args:
        csv_dir (str): The path to the directory holding the metadata files.
        sample_rate (int): The sample rate of the mixtures, in Hz.
        segment (float or None): The desired mixture length in seconds;
            ``None`` returns full utterances.
        return_id (bool): Whether to also return the WSJ utterance id.

    References
        Emmanuel Vincent, Shinji Watanabe, Aditya Arie Nugraha, Jon Barker,
        and Ricard Marxer, "An analysis of environment, microphone and data
        simulation mismatches in robust speech recognition",
        Computer Speech and Language, 2017.
    """

    dataset_name = "CHiME4"

    def __init__(self, csv_dir, sample_rate=16000, segment=3, return_id=False):
        self.csv_dir = csv_dir
        self.segment = segment
        self.sample_rate = sample_rate
        self.return_id = return_id
        # The metadata dir holds one mixture csv plus "*_annotations.csv"
        # transcription files; keep the mixture one.
        self.csv_path = [f for f in os.listdir(csv_dir) if "annotations" not in f][0]
        self.df = pd.read_csv(os.path.join(csv_dir, self.csv_path))
        # Get rid of the utterances shorter than the requested segment.
        if self.segment is not None:
            max_len = len(self.df)
            self.seg_len = int(self.segment * self.sample_rate)
            # Bug fix: the "duration" column is stored in *seconds* by the
            # metadata script, so the comparison must use seconds
            # (self.segment), not samples (self.seg_len) which dropped
            # every utterance.
            self.df = self.df[self.df["duration"] >= self.segment]
            print(
                f"Drop {max_len - len(self.df)} utterances from {max_len} "
                f"(shorter than {segment} seconds)"
            )
        else:
            self.seg_len = None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get the row in the dataframe.
        row = self.df.iloc[idx]
        self.mixture_path = row["mixture_path"]
        if self.seg_len is not None:
            # Random crop: convert the duration (seconds) to samples to pick
            # a valid start index. Bug fix: the original indexed a
            # nonexistent "length" column (the metadata writes "duration").
            n_samples = int(row["duration"] * self.sample_rate)
            start = random.randint(0, n_samples - self.seg_len)
            stop = start + self.seg_len
        else:
            start = 0
            stop = None
        # Read the mixture (sf.read start/stop are in frames).
        mixture, _ = sf.read(self.mixture_path, dtype="float32", start=start, stop=stop)
        mixture = torch.from_numpy(mixture)
        if self.return_id:
            return mixture, [row.wsj_id]
        return mixture
hypothesis=txt, transformation=self.transformation + if clean is not None: + # Average WER for the clean pair + for i, (wav, tmp_id) in enumerate(zip(clean, wav_id)): + txt = self.predict_hypothesis(wav) + out_count = Counter( + self.hsdi( + truth=self.trans_dic[tmp_id], + hypothesis=txt, + transformation=self.transformation, + ) ) - ) - self.clean_counter += out_count - local_clean_counter += out_count - self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt)) - trans_dict["clean"][f"utt_id_{i}"] = tmp_id - trans_dict["clean"][f"txt_{i}"] = txt + self.clean_counter += out_count + local_clean_counter += out_count + self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt)) + trans_dict["clean"][f"utt_id_{i}"] = tmp_id + trans_dict["clean"][f"txt_{i}"] = txt + else: + self.clean_counter = None # Average WER for the estimate pair for i, (est, tmp_id) in enumerate(zip(estimate, wav_id)): txt = self.predict_hypothesis(est) @@ -316,12 +321,16 @@ def __call__( self.output_txt_list.append(dict(utt_id=tmp_id, text=txt)) trans_dict["estimates"][f"utt_id_{i}"] = tmp_id trans_dict["estimates"][f"txt_{i}"] = txt + self.transcriptions.append(trans_dict) - return dict( + wer_dict = dict( input_wer=self.wer_from_hsdi(**dict(local_mix_counter)), - clean_wer=self.wer_from_hsdi(**dict(local_clean_counter)), wer=self.wer_from_hsdi(**dict(local_est_counter)), ) + if clean is not None: + wer_dict["clean_wer"] = self.wer_from_hsdi(**dict(local_clean_counter)) + + return wer_dict @staticmethod def wer_from_hsdi(hits=0, substitutions=0, deletions=0, insertions=0): @@ -359,31 +368,35 @@ def _df_to_dict(df): def final_df(self): """Generate a MarkDown table, as done by ESPNet.""" mix_n_word = sum(self.mix_counter[k] for k in ["hits", "substitutions", "deletions"]) - clean_n_word = sum(self.clean_counter[k] for k in ["hits", "substitutions", "deletions"]) est_n_word = sum(self.est_counter[k] for k in ["hits", "substitutions", "deletions"]) mix_wer = 
self.wer_from_hsdi(**dict(self.mix_counter)) - clean_wer = self.wer_from_hsdi(**dict(self.clean_counter)) est_wer = self.wer_from_hsdi(**dict(self.est_counter)) mix_hsdi = [ self.mix_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"] ] - clean_hsdi = [ - self.clean_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"] - ] est_hsdi = [ self.est_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"] ] # Snt Wrd HSDI Err S.Err for_mix = [len(self.mix_counter), mix_n_word] + mix_hsdi + [mix_wer, "-"] - for_clean = [len(self.clean_counter), clean_n_word] + clean_hsdi + [clean_wer, "-"] for_est = [len(self.est_counter), est_n_word] + est_hsdi + [est_wer, "-"] - table = [ - ["test_clean / mixture"] + for_mix, - ["test_clean / clean"] + for_clean, - ["test_clean / separated"] + for_est, + ["ground_truth / mixture"] + for_mix, + ["ground_truth / separated"] + for_est, ] + + if self.clean_counter is not None: + clean_n_word = sum( + self.clean_counter[k] for k in ["hits", "substitutions", "deletions"] + ) + clean_wer = self.wer_from_hsdi(**dict(self.clean_counter)) + clean_hsdi = [ + self.clean_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"] + ] + for_clean = [len(self.clean_counter), clean_n_word] + clean_hsdi + [clean_wer, "-"] + table.insert(1, ["ground_truth / clean"] + for_mix) + df = pd.DataFrame( table, columns=["dataset", "Snt", "Wrd", "Corr", "Sub", "Del", "Ins", "Err", "S.Err"] ) diff --git a/egs/chime4/ConvTasNet/eval.py b/egs/chime4/ConvTasNet/eval.py new file mode 100644 index 000000000..86a966a4d --- /dev/null +++ b/egs/chime4/ConvTasNet/eval.py @@ -0,0 +1,182 @@ +import os +import random +import soundfile as sf +import torch +import yaml +import json +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from pprint import pprint + +from asteroid.data.chime4_dataset import CHiME4Dataset +from asteroid import ConvTasNet +from asteroid.models import 
parser = argparse.ArgumentParser()
parser.add_argument(
    "--test_dir", type=str, required=True, help="Test directory including the csv files"
)
parser.add_argument(
    "--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution"
)
parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root")
parser.add_argument(
    "--n_save_ex", type=int, default=1, help="Number of audio examples to save, -1 means all"
)
parser.add_argument(
    "--compute_wer", type=int, default=1, help="Compute WER using ESPNet's pretrained model"
)
parser.add_argument(
    "--asr_type",
    default="noisy",
    help="Choice for the ASR model whether trained on clean or noisy data. One of clean or noisy",
)


# In CHiME-4 only real noisy data are available (no clean references),
# hence no signal-based separation metrics can be computed.
COMPUTE_METRICS = []


def update_compute_metrics(compute_wer, metric_list):
    """Append "wer" to ``metric_list`` when WER computation is requested and
    an ESPnet installation is available; otherwise return it unchanged."""
    if not compute_wer:
        return metric_list
    try:
        from espnet2.bin.asr_inference import Speech2Text
        from espnet_model_zoo.downloader import ModelDownloader
    except ModuleNotFoundError:
        import warnings

        warnings.warn("Couldn't find espnet installation. Continuing without.")
        return metric_list
    return metric_list + ["wer"]


def main(conf):
    """Evaluate a pretrained ConvTasNet on the CHiME-4 'real' test set.

    Runs the model on every test utterance, optionally computes WER with a
    pretrained ESPnet ASR model, saves a few enhanced examples and writes
    all metrics/reports under ``<exp_dir>/chime4/<asr_type>``.

    Args:
        conf (dict): CLI arguments merged with the training configuration
            (must contain "sample_rate" and "train_conf").
    """
    # Pick the ESPnet ASR model: trained on noisy (CHiME-4) or clean (WSJ) speech.
    if conf["asr_type"] == "noisy":
        asr_model_path = (
            "kamo-naoyuki/chime4_asr_train_asr_transformer3_raw_en_char_sp_valid.acc.ave"
        )
    else:
        asr_model_path = "kamo-naoyuki/wsj_transformer2"

    compute_metrics = update_compute_metrics(conf["compute_wer"], COMPUTE_METRICS)
    # The annotation csv holds the ground-truth transcriptions for WER.
    annot_path = [f for f in os.listdir(conf["test_dir"]) if "annotations" in f][0]
    anno_df = pd.read_csv(os.path.join(conf["test_dir"], annot_path))
    wer_tracker = (
        MockWERTracker() if not conf["compute_wer"] else WERTracker(asr_model_path, anno_df)
    )
    model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    model = ConvTasNet.from_pretrained(model_path)
    # Handle device placement.
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    # Use full utterances (segment=None) at evaluation time.
    test_set = CHiME4Dataset(
        csv_dir=conf["test_dir"],
        sample_rate=conf["sample_rate"],
        segment=None,
        return_id=True,
    )

    # Randomly choose the indexes of sentences to save.
    eval_save_dir = os.path.join(conf["exp_dir"], "chime4", conf["asr_type"])
    ex_save_dir = os.path.join(eval_save_dir, "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
    series_list = []
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, ids = test_set[idx]
            mix = tensors_to_device(mix, device=model_device)
            est_sources = model(mix.unsqueeze(0))
            mix_np = mix.cpu().data.numpy()
            est_sources_np = est_sources.squeeze(0).cpu().data.numpy()
            # Rescale the estimates to the mixture's peak level.
            est_sources_np *= np.max(np.abs(mix_np)) / np.max(np.abs(est_sources_np))
            # For each utterance, a dict with the mixture path and metrics.
            utt_metrics = {"mix_path": test_set.mixture_path}
            utt_metrics.update(
                **wer_tracker(
                    mix=mix_np,
                    clean=None,  # no clean reference in CHiME-4 'real'
                    estimate=est_sources_np,
                    wav_id=ids,
                    sample_rate=conf["sample_rate"],
                )
            )
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder: wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np, conf["sample_rate"])
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv"))

    # Print and save summary metrics (empty loop when COMPUTE_METRICS is []).
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()

    print("Overall metrics :")
    pprint(final_results)
    if conf["compute_wer"]:
        print("\nWER report")
        wer_card = wer_tracker.final_report_as_markdown()
        print(wer_card)
        # Save the report.
        with open(os.path.join(eval_save_dir, "final_wer.md"), "w") as f:
            f.write(wer_card)
        all_transcriptions = wer_tracker.transcriptions
        with open(os.path.join(eval_save_dir, "all_transcriptions.json"), "w") as f:
            json.dump(all_transcriptions, f, indent=4)

    with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)

    model_dict = torch.load(model_path, map_location="cpu")
    os.makedirs(os.path.join(conf["exp_dir"], "publish_dir"), exist_ok=True)
    save_publishable(
        os.path.join(conf["exp_dir"], "publish_dir"),
        model_dict,
        metrics=final_results,
        # Bug fix: the original referenced the module-global `train_conf`,
        # which only exists when run as a script (NameError if `main` is
        # imported); the conf dict already carries it.
        train_conf=conf["train_conf"],
    )


if __name__ == "__main__":
    args = parser.parse_args()
    arg_dic = dict(vars(args))
    # Load training config.
    conf_path = os.path.join(args.exp_dir, "conf.yml")
    with open(conf_path) as f:
        train_conf = yaml.safe_load(f)
    arg_dic["sample_rate"] = train_conf["data"]["sample_rate"]
    arg_dic["train_conf"] = train_conf
    main(arg_dic)
# Command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--chime3_dir", type=str, default=None, help="Path to CHiME3 root directory")

# Seed for the (random) training-set channel selection, for reproducibility.
SEED = 4
np.random.seed(SEED)


def main(args):
    create_local_metadata(args.chime3_dir)


def create_local_metadata(chime3_dir):
    """Create mixture-path and transcription csv files for every 'real'
    CHiME-3/4 annotation file found under ``<chime3_dir>/data/annotations``."""
    # CHiME-3 annotation files (json) and CHiME-4 1ch-track lists.
    c3_annot_files = glob(os.path.join(chime3_dir, "data", "annotations", "*real*.json"))
    c4_annot_files = glob(os.path.join(chime3_dir, "data", "annotations", "*real*.list"))
    for c3_annot_file_path in c3_annot_files:
        c3_annot_file, c4_annot_file, subset, origin = match_annotation_files(
            c3_annot_file_path, c4_annot_files
        )
        df_audio_path, df_annot = create_dataframe(
            chime3_dir, c3_annot_file, c4_annot_file, subset, origin
        )
        write_dataframe(df_audio_path, df_annot, subset, origin)


def match_annotation_files(c3_anno_path, c4_anno):
    """Read a CHiME-3 json annotation file and find the matching CHiME-4
    1ch-track list file, if any.

    Returns:
        (c3_dataframe, c4_dataframe_or_None, subset, origin)
    """
    c3_annot_file = pd.read_json(c3_anno_path)
    # File names look like <subset>_<origin>.json (e.g. et05_real.json).
    subset, origin = os.path.split(c3_anno_path)[1].replace(".json", "").split("_")
    # Look for the associated CHiME-4 list (dt/et subsets only).
    c4_list_path = c3_anno_path.replace(".json", "_1ch_track.list")
    if c4_list_path in c4_anno:
        c4_annot_file = pd.read_csv(c4_list_path, header=None, names=["path"])
    else:
        c4_annot_file = None
    return c3_annot_file, c4_annot_file, subset, origin


def create_dataframe(chime3_dir, c3_anno, c4_anno, subset, origin):
    """Build the audio-path and transcription DataFrames for one subset."""
    row_path_list = []
    row_annot_list = []
    for row in c3_anno.itertuples():
        speaker = row.speaker
        wsj_id = row.wsj_name
        env = row.environment
        if c4_anno is not None:
            # dt/et subsets: the channel is fixed by the CHiME-4 1ch-track
            # list; find the entry matching this utterance.
            # Paths are stored like <subset>_<env>_real/<spk>_<id>.<...>.wav
            mixture_path = c4_anno[c4_anno["path"].str.contains(wsj_id + "_" + env)].values[0][0]
            mixture_path = os.path.join(chime3_dir, "data/audio/16kHz/isolated/", mixture_path)
        else:
            # tr subset: the challenge rules leave the channel free, so pick
            # one of the 6 microphones at random (seeded above).
            channel = np.random.randint(1, 7)
            # Bug fix: isolated files are named <spk>_<id>_<ENV>.CH<n>.wav;
            # the original built "<spk>_<id>_.CH<n>.wav" (stray underscore,
            # missing environment), which does not exist on disk.
            # NOTE(review): assumes `env` is stored uppercase in the json,
            # as the dir name's env.lower() suggests — confirm on real data.
            mixture_path = os.path.join(
                chime3_dir,
                "data/audio/16kHz/isolated/",
                subset + "_" + env.lower() + "_" + origin,
                f"{speaker}_{wsj_id}_{env}.CH{channel}.wav",
            )
        duration = row.end - row.start  # seconds
        row_path_list.append(
            {
                "wsj_id": wsj_id,
                "subset": subset,
                "origin": origin,
                "env": env,
                "mixture_path": mixture_path,
                "duration": duration,
            }
        )
        row_annot_list.append({"utt_id": wsj_id, "text": row.dot})
    return pd.DataFrame(row_path_list), pd.DataFrame(row_annot_list)


def write_dataframe(df, df2, subset, origin):
    """Save the mixture csv and its transcription csv under data/{train,val,test}."""
    if "et" in subset:
        subdir = "test"
    elif "dt" in subset:
        subdir = "val"
    else:
        subdir = "train"
    save_dir = os.path.join("data", subdir)
    os.makedirs(save_dir, exist_ok=True)
    df.to_csv(os.path.join(save_dir, origin + "_1_ch_track.csv"), index=False)
    df2.to_csv(os.path.join(save_dir, origin + "_1_ch_track_annotations.csv"), index=False)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
Defaults to the current python +# You can run ./utils/prepare_python_env.sh to create a suitable python environment, paste the output here. +python_path=python + +# Example usage +# ./run.sh --stage 3 --tag my_tag --task sep_noisy --id 0,1 + +# General +stage=0 # Controls from which stage to start +tag="" # Controls the directory name associated to the experiment +# You can ask for several GPUs using id (passed to CUDA_VISIBLE_DEVICES) + +eval_use_gpu=0 +# Need to --compute_wer 1 --eval_mode max to be sure the user knows all the metrics +# are for the all mode. +compute_wer=1 + +# Choice for the ASR model whether trained on clean or noisy data. One of clean or noisy +asr_type=noisy + +. utils/parse_options.sh + +test_dir=data/test + +if [[ $stage -le 0 ]]; then + echo "Stage 0: Generating CHiME-4 dataset" + $python_path local/create_metadata.py --chime3_dir $storage_dir/CHiME3/ +fi + +if [[ $stage -le 1 ]]; then + echo "Stage 2 : Evaluation" + echo "Results from the following experiment will be stored in $exp_dir/chime4/$asr_type" + + if [[ $compute_wer -eq 1 ]]; then + + # Install espnet if not instaled + if ! python -c "import espnet" &> /dev/null; then + echo 'This recipe requires espnet. Installing requirements.' 
+ $python_path -m pip install espnet_model_zoo + $python_path -m pip install jiwer + $python_path -m pip install tabulate + fi + fi + + $python_path eval.py \ + --exp_dir $exp_dir \ + --test_dir $test_dir \ + --use_gpu $eval_use_gpu \ + --compute_wer $compute_wer \ + --asr_type $asr_type +fi diff --git a/egs/chime4/ConvTasNet/utils b/egs/chime4/ConvTasNet/utils new file mode 120000 index 000000000..00cd3a3b9 --- /dev/null +++ b/egs/chime4/ConvTasNet/utils @@ -0,0 +1 @@ +../../wham/ConvTasNet/utils/ \ No newline at end of file diff --git a/egs/chime4/README.md b/egs/chime4/README.md new file mode 100644 index 000000000..7f501bff7 --- /dev/null +++ b/egs/chime4/README.md @@ -0,0 +1,38 @@ +### The CHiME-4 dataset + +The CHiME-4 dataset is part of the 4th CHiME speech separation and recognition challenge. + +It was released in 2016 and revisits the datasets originally recorded for CHiME-3. + +All data and information are available [here](http://spandh.dcs.shef.ac.uk/chime_challenge/CHiME4/index.html). + +For now, this recipe only deals with the `real_1_ch_track` part of the dataset. +As the channel to use for the training set wasn't defined by +the challenge's rules, we will set it randomly. + +**Note :** +This dataset uses real noisy data. This means the clean speech from the noisy +utterances is not available. This makes it unsuitable for the usual training +procedure. 
+ + + +**References** +~~~BibTeX +@article{vincent:hal-01399180, + TITLE = {{An analysis of environment, microphone and data simulation mismatches in robust speech recognition}}, + AUTHOR = {Vincent, Emmanuel and Watanabe, Shinji and Nugraha, Aditya Arie and Barker, Jon and Marxer, Ricard}, + URL = {https://hal.inria.fr/hal-01399180}, + JOURNAL = {{Computer Speech and Language}}, + PUBLISHER = {{Elsevier}}, + VOLUME = {46}, + PAGES = {535-557}, + YEAR = {2017}, + MONTH = Jul, + DOI = {10.1016/j.csl.2016.11.005}, + KEYWORDS = {speech enhancement ; Robust ASR ; train/test mismatch ; microphone array}, + PDF = {https://hal.inria.fr/hal-01399180/file/vincent_CSL16.pdf}, + HAL_ID = {hal-01399180}, + HAL_VERSION = {v1}, +} +~~~ \ No newline at end of file