-
Notifications
You must be signed in to change notification settings - Fork 447
Chime4 #423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
JorisCos
wants to merge
10
commits into
asteroid-team:master
Choose a base branch
from
JorisCos:chime4
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Chime4 #423
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
96183f4
Initial commit
JorisCos 13704af
fix
JorisCos 72039d2
add all transcriptions
JorisCos 6f2149d
refactor fake_source
JorisCos 4723e2a
add all possible models
JorisCos dc2da64
add asr_type
JorisCos fce54f6
black
JorisCos f9b5c00
extend wertracker
JorisCos 0c58deb
fix metrics
JorisCos 0615f71
Merge branch 'master' into chime4
JorisCos File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| import pandas as pd | ||
| import soundfile as sf | ||
| import torch | ||
| from torch.utils.data import Dataset, DataLoader | ||
| import random as random | ||
| import os | ||
|
|
||
|
|
||
class CHiME4Dataset(Dataset):
    """Dataset class for CHiME4 source separation tasks. Only supports 'real'
    data.

    Args:
        csv_dir (str): The path to the directory holding the metadata file.
            The first file whose name does not contain ``"annotations"`` is
            used as the metadata csv.
        sample_rate (int): The sample rate of the sources and mixtures.
        segment (int, optional): The desired sources and mixtures length in s.
            If ``None``, whole utterances are returned.
        return_id (bool): If ``True``, ``__getitem__`` also returns the WSJ
            utterance id of the mixture.

    References
        Emmanuel Vincent, Shinji Watanabe, Aditya Arie Nugraha, Jon Barker,
        and Ricard Marxer. "An analysis of environment, microphone and data
        simulation mismatches in robust speech recognition".
        Computer Speech and Language, 2017.
    """

    dataset_name = "CHiME4"

    def __init__(self, csv_dir, sample_rate=16000, segment=3, return_id=False):
        self.csv_dir = csv_dir
        self.segment = segment
        self.sample_rate = sample_rate
        self.return_id = return_id
        # Get the csv corresponding to origin: the metadata file is the only
        # csv in the directory that is not an annotation file.
        self.csv_path = [f for f in os.listdir(csv_dir) if "annotations" not in f][0]
        self.df = pd.read_csv(os.path.join(csv_dir, self.csv_path))
        # Get rid of the utterances too short for the requested segment.
        if self.segment is not None:
            initial_len = len(self.df)
            self.seg_len = int(self.segment * self.sample_rate)
            # NOTE(review): "duration" is compared against a length in
            # samples while __getitem__ uses the "length" column — confirm
            # "duration" is expressed in samples, not seconds.
            self.df = self.df[self.df["duration"] >= self.seg_len]
            # Report the drop against the actual csv size instead of a
            # hard-coded dataset total, so any metadata file works.
            print(
                f"Drop {initial_len - len(self.df)} utterances from {initial_len} "
                f"(shorter than {segment} seconds)"
            )
        else:
            self.seg_len = None

    def __len__(self):
        # One example per remaining row of the metadata dataframe.
        return len(self.df)

    def __getitem__(self, idx):
        # Get the row in dataframe.
        row = self.df.iloc[idx]
        # Remember the mixture path so callers (e.g. eval scripts) can read
        # it back after indexing.
        self.mixture_path = row["mixture_path"]
        # If a segment length is set, the start point is chosen randomly.
        if self.seg_len is not None:
            start = random.randint(0, row["length"] - self.seg_len)
            stop = start + self.seg_len
        else:
            start = 0
            stop = None
        # Read the mixture and convert it to a torch tensor.
        mixture, _ = sf.read(self.mixture_path, dtype="float32", start=start, stop=stop)
        mixture = torch.from_numpy(mixture)
        if self.return_id:
            id1 = row["wsj_id"]
            return mixture, [id1]
        return mixture
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| import os | ||
| import random | ||
| import soundfile as sf | ||
| import torch | ||
| import yaml | ||
| import json | ||
| import argparse | ||
| import numpy as np | ||
| import pandas as pd | ||
| from tqdm import tqdm | ||
| from pprint import pprint | ||
|
|
||
| from asteroid.data.chime4_dataset import CHiME4Dataset | ||
| from asteroid import ConvTasNet | ||
| from asteroid.models import save_publishable | ||
| from asteroid.utils import tensors_to_device | ||
| from asteroid.metrics import WERTracker, MockWERTracker | ||
|
|
||
# Command-line interface of the CHiME4 evaluation script. All options are
# plain ints/strings so they can also be forwarded from shell scripts.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--test_dir", type=str, required=True, help="Test directory including the csv files"
)

# 0/1 flags are ints (not store_true) to keep recipe shell scripts simple.
parser.add_argument(
    "--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution"
)
parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root")
parser.add_argument(
    "--n_save_ex", type=int, default=1, help="Number of audio examples to save, -1 means all"
)
parser.add_argument(
    "--compute_wer", type=int, default=1, help="Compute WER using ESPNet's pretrained model"
)
# Selects which pretrained ASR checkpoint is used in main() ("clean" vs "noisy").
parser.add_argument(
    "--asr_type",
    default="noisy",
    help="Choice for the ASR model whether trained on clean or noisy data. One of clean or noisy",
)


# In CHiME 4 only the noisy data are available, hence no metrics.
# "wer" may be appended later by update_compute_metrics().
COMPUTE_METRICS = []
|
|
||
|
|
||
def update_compute_metrics(compute_wer, metric_list):
    """Optionally extend *metric_list* with ``"wer"``.

    Args:
        compute_wer (bool or int): Whether WER computation was requested.
        metric_list (list): Base list of metric names.

    Returns:
        list: The original list when WER is disabled or espnet is not
        installed, otherwise a new list with ``"wer"`` appended.
    """
    if not compute_wer:
        return metric_list
    # Probe for the espnet packages required to run ASR-based WER scoring.
    espnet_available = True
    try:
        from espnet2.bin.asr_inference import Speech2Text  # noqa: F401
        from espnet_model_zoo.downloader import ModelDownloader  # noqa: F401
    except ModuleNotFoundError:
        espnet_available = False
    if not espnet_available:
        import warnings

        warnings.warn("Couldn't find espnet installation. Continuing without.")
        return metric_list
    return metric_list + ["wer"]
|
|
||
|
|
||
def main(conf):
    """Evaluate a pretrained ConvTasNet on the CHiME4 test set.

    Args:
        conf (dict): Configuration with keys ``"asr_type"``, ``"compute_wer"``,
            ``"test_dir"``, ``"exp_dir"``, ``"use_gpu"``, ``"sample_rate"``,
            ``"n_save_ex"`` and ``"train_conf"`` (the loaded training config,
            required for model publishing).

    Side effects:
        Writes per-example wavs/metrics, ``all_metrics.csv``,
        ``final_metrics.json`` (and WER reports when enabled) under
        ``exp_dir/chime4/<asr_type>``, and a publishable model under
        ``exp_dir/publish_dir``.
    """
    # Pick the pretrained ESPnet ASR model depending on its training data.
    if conf["asr_type"] == "noisy":
        asr_model_path = (
            "kamo-naoyuki/chime4_asr_train_asr_transformer3_raw_en_char_sp_valid.acc.ave"
        )
    else:
        asr_model_path = "kamo-naoyuki/wsj_transformer2"

    compute_metrics = update_compute_metrics(conf["compute_wer"], COMPUTE_METRICS)
    # The annotation csv holds the reference transcriptions used for WER.
    annot_path = [f for f in os.listdir(conf["test_dir"]) if "annotations" in f][0]
    anno_df = pd.read_csv(os.path.join(conf["test_dir"], annot_path))
    wer_tracker = (
        MockWERTracker() if not conf["compute_wer"] else WERTracker(asr_model_path, anno_df)
    )
    model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    model = ConvTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = CHiME4Dataset(
        csv_dir=conf["test_dir"],
        sample_rate=conf["sample_rate"],
        segment=None,
        return_id=True,
    )  # Uses all segment length

    # Randomly choose the indexes of sentences to save.
    eval_save_dir = os.path.join(conf["exp_dir"], "chime4", conf["asr_type"])
    ex_save_dir = os.path.join(eval_save_dir, "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
    series_list = []
    # Use a proper context manager instead of torch.no_grad().__enter__(),
    # which was never exited.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, ids = test_set[idx]
            mix = tensors_to_device(mix, device=model_device)
            est_sources = model(mix.unsqueeze(0))
            mix_np = mix.cpu().data.numpy()
            est_sources_np = est_sources.squeeze(0).cpu().data.numpy()
            # Rescale estimates to the mixture's peak amplitude.
            est_sources_np *= np.max(np.abs(mix_np)) / np.max(np.abs(est_sources_np))
            # For each utterance, we get a dictionary with the mixture path,
            # the input and output metrics
            utt_metrics = {"mix_path": test_set.mixture_path}
            utt_metrics.update(
                **wer_tracker(
                    mix=mix_np,
                    clean=None,
                    estimate=est_sources_np,
                    wav_id=ids,
                    sample_rate=conf["sample_rate"],
                )
            )
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np, conf["sample_rate"])
                # Loop over the sources and estimates
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv"))

    # Print and save summary metrics
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()

    print("Overall metrics :")
    pprint(final_results)
    if conf["compute_wer"]:
        print("\nWER report")
        wer_card = wer_tracker.final_report_as_markdown()
        print(wer_card)
        # Save the report
        with open(os.path.join(eval_save_dir, "final_wer.md"), "w") as f:
            f.write(wer_card)
        all_transcriptions = wer_tracker.transcriptions
        with open(os.path.join(eval_save_dir, "all_transcriptions.json"), "w") as f:
            json.dump(all_transcriptions, f, indent=4)

    with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)

    model_dict = torch.load(model_path, map_location="cpu")
    os.makedirs(os.path.join(conf["exp_dir"], "publish_dir"), exist_ok=True)
    # Bug fix: the original referenced the bare global name `train_conf`,
    # which only exists when the script is run via __main__ (NameError when
    # main() is imported and called). The config is already in `conf`.
    publishable = save_publishable(
        os.path.join(conf["exp_dir"], "publish_dir"),
        model_dict,
        metrics=final_results,
        train_conf=conf["train_conf"],
    )
|
|
||
|
|
||
if __name__ == "__main__":
    # Parse CLI options into a plain dict so main() is framework-agnostic.
    args = parser.parse_args()
    arg_dic = dict(vars(args))
    # Load training config from the experiment directory; the sample rate of
    # the evaluation data must match the one the model was trained with.
    conf_path = os.path.join(args.exp_dir, "conf.yml")
    with open(conf_path) as f:
        train_conf = yaml.safe_load(f)
    arg_dic["sample_rate"] = train_conf["data"]["sample_rate"]
    # Full training config is forwarded for model publishing in main().
    arg_dic["train_conf"] = train_conf
    main(arg_dic)
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only print things from `mixture` and `estimates`. And a potential improvement.