Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions asteroid/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .kinect_wsj import KinectWsjMixDataset
from .fuss_dataset import FUSSDataset
from .dampvsep_dataset import DAMPVSEPSinglesDataset
from .chime4_dataset import CHiME4Dataset

__all__ = [
"AVSpeechDataset",
Expand All @@ -22,4 +23,5 @@
"KinectWsjMixDataset",
"FUSSDataset",
"DAMPVSEPSinglesDataset",
"CHiME4Dataset",
]
71 changes: 71 additions & 0 deletions asteroid/data/chime4_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pandas as pd
import soundfile as sf
import torch
from torch.utils.data import Dataset, DataLoader
import random as random
import os


class CHiME4Dataset(Dataset):
    """Dataset class for CHiME4 source separation tasks. Only supports 'real'
    data.

    Args:
        csv_dir (str): Path to the directory holding the metadata files. The
            first file whose name does not contain ``"annotations"`` is used
            as the metadata csv.
        sample_rate (int): The sample rate of the sources and mixtures.
        segment (int, optional): The desired sources and mixtures length in
            seconds. If ``None``, full-length utterances are returned.
        return_id (bool): If ``True``, ``__getitem__`` also returns the WSJ
            utterance id of the mixture.

    References
        Emmanuel Vincent, Shinji Watanabe, Aditya Arie Nugraha, Jon Barker, and Ricard Marxer
        An analysis of environment, microphone and data simulation mismatches in robust speech recognition
        Computer Speech and Language, 2017.
    """

    dataset_name = "CHiME4"

    def __init__(self, csv_dir, sample_rate=16000, segment=3, return_id=False):
        self.csv_dir = csv_dir
        self.segment = segment
        self.sample_rate = sample_rate
        self.return_id = return_id
        # Get the metadata csv: the only file without "annotations" in its name.
        self.csv_path = [f for f in os.listdir(csv_dir) if "annotations" not in f][0]
        self.df = pd.read_csv(os.path.join(csv_dir, self.csv_path))
        # Get rid of the utterances that are too short for the desired segment.
        if self.segment is not None:
            max_len = len(self.df)
            self.seg_len = int(self.segment * self.sample_rate)
            # NOTE(review): "duration" is compared against a length in samples,
            # so it is assumed to be expressed in samples too — TODO confirm
            # against the metadata generation script.
            self.df = self.df[self.df["duration"] >= self.seg_len]
            # Fix: report the actual pre-filter utterance count instead of the
            # hard-coded "26,010", which is only valid for one specific csv.
            print(
                f"Drop {max_len - len(self.df)} utterances from {max_len} "
                f"(shorter than {segment} seconds)"
            )
        else:
            self.seg_len = None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get the row in dataframe
        row = self.df.iloc[idx]
        # Get mixture path (kept on the instance so eval scripts can read it back).
        self.mixture_path = row["mixture_path"]
        # If a segment length is set, pick the start point randomly.
        if self.seg_len is not None:
            start = random.randint(0, row["length"] - self.seg_len)
            stop = start + self.seg_len
        else:
            start = 0
            stop = None

        # Read the mixture
        mixture, _ = sf.read(self.mixture_path, dtype="float32", start=start, stop=stop)
        # Convert to torch tensor
        mixture = torch.from_numpy(mixture)
        if self.return_id:
            id1 = row.wsj_id
            return mixture, [id1]
        return mixture
61 changes: 37 additions & 24 deletions asteroid/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,19 +290,24 @@ def __call__(
self.mix_counter += out_count
local_mix_counter += out_count
self.input_txt_list.append(dict(utt_id=tmp_id, text=txt))
# Average WER for the clean pair
for i, (wav, tmp_id) in enumerate(zip(clean, wav_id)):
txt = self.predict_hypothesis(wav)
out_count = Counter(
self.hsdi(
truth=self.trans_dic[tmp_id], hypothesis=txt, transformation=self.transformation
if clean is not None:
# Average WER for the clean pair
for i, (wav, tmp_id) in enumerate(zip(clean, wav_id)):
txt = self.predict_hypothesis(wav)
out_count = Counter(
self.hsdi(
truth=self.trans_dic[tmp_id],
hypothesis=txt,
transformation=self.transformation,
)
)
)
self.clean_counter += out_count
local_clean_counter += out_count
self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt))
trans_dict["clean"][f"utt_id_{i}"] = tmp_id
trans_dict["clean"][f"txt_{i}"] = txt
self.clean_counter += out_count
local_clean_counter += out_count
self.clean_txt_list.append(dict(utt_id=tmp_id, text=txt))
trans_dict["clean"][f"utt_id_{i}"] = tmp_id
trans_dict["clean"][f"txt_{i}"] = txt
else:
self.clean_counter = None
# Average WER for the estimate pair
for i, (est, tmp_id) in enumerate(zip(estimate, wav_id)):
txt = self.predict_hypothesis(est)
Expand All @@ -316,12 +321,16 @@ def __call__(
self.output_txt_list.append(dict(utt_id=tmp_id, text=txt))
trans_dict["estimates"][f"utt_id_{i}"] = tmp_id
trans_dict["estimates"][f"txt_{i}"] = txt

self.transcriptions.append(trans_dict)
return dict(
wer_dict = dict(
input_wer=self.wer_from_hsdi(**dict(local_mix_counter)),
clean_wer=self.wer_from_hsdi(**dict(local_clean_counter)),
wer=self.wer_from_hsdi(**dict(local_est_counter)),
)
if clean is not None:
wer_dict["clean_wer"] = self.wer_from_hsdi(**dict(local_clean_counter))

return wer_dict

@staticmethod
def wer_from_hsdi(hits=0, substitutions=0, deletions=0, insertions=0):
Expand Down Expand Up @@ -359,31 +368,35 @@ def _df_to_dict(df):
def final_df(self):
"""Generate a MarkDown table, as done by ESPNet."""
mix_n_word = sum(self.mix_counter[k] for k in ["hits", "substitutions", "deletions"])
clean_n_word = sum(self.clean_counter[k] for k in ["hits", "substitutions", "deletions"])
est_n_word = sum(self.est_counter[k] for k in ["hits", "substitutions", "deletions"])
mix_wer = self.wer_from_hsdi(**dict(self.mix_counter))
clean_wer = self.wer_from_hsdi(**dict(self.clean_counter))
est_wer = self.wer_from_hsdi(**dict(self.est_counter))

mix_hsdi = [
self.mix_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"]
]
clean_hsdi = [
self.clean_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"]
]
est_hsdi = [
self.est_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"]
]
# Snt Wrd HSDI Err S.Err
for_mix = [len(self.mix_counter), mix_n_word] + mix_hsdi + [mix_wer, "-"]
for_clean = [len(self.clean_counter), clean_n_word] + clean_hsdi + [clean_wer, "-"]
for_est = [len(self.est_counter), est_n_word] + est_hsdi + [est_wer, "-"]

table = [
["test_clean / mixture"] + for_mix,
["test_clean / clean"] + for_clean,
["test_clean / separated"] + for_est,
["ground_truth / mixture"] + for_mix,
["ground_truth / separated"] + for_est,
]

if self.clean_counter is not None:
clean_n_word = sum(
self.clean_counter[k] for k in ["hits", "substitutions", "deletions"]
)
clean_wer = self.wer_from_hsdi(**dict(self.clean_counter))
clean_hsdi = [
self.clean_counter[k] for k in ["hits", "substitutions", "deletions", "insertions"]
]
for_clean = [len(self.clean_counter), clean_n_word] + clean_hsdi + [clean_wer, "-"]
table.insert(1, ["ground_truth / clean"] + for_mix)

df = pd.DataFrame(
table, columns=["dataset", "Snt", "Wrd", "Corr", "Sub", "Del", "Ins", "Err", "S.Err"]
)
Expand Down
182 changes: 182 additions & 0 deletions egs/chime4/ConvTasNet/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import os
import random
import soundfile as sf
import torch
import yaml
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint

from asteroid.data.chime4_dataset import CHiME4Dataset
from asteroid import ConvTasNet
from asteroid.models import save_publishable
from asteroid.utils import tensors_to_device
from asteroid.metrics import WERTracker, MockWERTracker

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
_add = parser.add_argument
_add("--test_dir", type=str, required=True, help="Test directory including the csv files")
_add("--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution")
_add("--exp_dir", default="exp/tmp", help="Experiment root")
_add("--n_save_ex", type=int, default=1, help="Number of audio examples to save, -1 means all")
_add("--compute_wer", type=int, default=1, help="Compute WER using ESPNet's pretrained model")
_add(
    "--asr_type",
    default="noisy",
    help="Choice for the ASR model whether trained on clean or noisy data. One of clean or noisy",
)
del _add


# In CHiME 4 only the noisy data are available, hence no metrics.
COMPUTE_METRICS = []


def update_compute_metrics(compute_wer, metric_list):
    """Append ``"wer"`` to the metric list when WER computation is requested
    and an ESPnet installation is available.

    Args:
        compute_wer (bool or int): Whether WER should be computed.
        metric_list (list): Names of the metrics already computed.

    Returns:
        list: ``metric_list`` unchanged when WER is disabled or ESPnet is
        missing, otherwise a new list with ``"wer"`` appended.
    """
    if compute_wer:
        try:
            # Probe for the ESPnet packages needed by the WER tracker.
            from espnet2.bin.asr_inference import Speech2Text  # noqa: F401
            from espnet_model_zoo.downloader import ModelDownloader  # noqa: F401
        except ModuleNotFoundError:
            import warnings

            warnings.warn("Couldn't find espnet installation. Continuing without.")
        else:
            return metric_list + ["wer"]
    return metric_list


def main(conf):
    """Evaluate a trained ConvTasNet model on the CHiME4 'real' test set.

    Args:
        conf (dict): Merged command-line arguments and training configuration.
            Must contain ``test_dir``, ``exp_dir``, ``use_gpu``, ``n_save_ex``,
            ``compute_wer``, ``asr_type``, ``sample_rate`` and ``train_conf``.

    Side effects:
        Writes per-utterance metrics, example wavs, WER reports and a
        publishable model dump under ``conf["exp_dir"]``.
    """
    # Pick the pretrained ASR model matching the requested training condition.
    if conf["asr_type"] == "noisy":
        asr_model_path = (
            "kamo-naoyuki/chime4_asr_train_asr_transformer3_raw_en_char_sp_valid.acc.ave"
        )
    else:
        asr_model_path = "kamo-naoyuki/wsj_transformer2"

    compute_metrics = update_compute_metrics(conf["compute_wer"], COMPUTE_METRICS)
    # The annotation csv holds the ground-truth transcriptions used for WER.
    annot_path = [f for f in os.listdir(conf["test_dir"]) if "annotations" in f][0]
    anno_df = pd.read_csv(os.path.join(conf["test_dir"], annot_path))
    wer_tracker = (
        MockWERTracker() if not conf["compute_wer"] else WERTracker(asr_model_path, anno_df)
    )
    model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    model = ConvTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = CHiME4Dataset(
        csv_dir=conf["test_dir"],
        sample_rate=conf["sample_rate"],
        segment=None,
        return_id=True,
    )  # Uses all segment length

    # Randomly choose the indexes of sentences to save.
    eval_save_dir = os.path.join(conf["exp_dir"], "chime4", conf["asr_type"])
    ex_save_dir = os.path.join(eval_save_dir, "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
    series_list = []
    # Fix: use a real `with` block instead of calling __enter__ by hand, so
    # gradient tracking is restored even if evaluation raises.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, ids = test_set[idx]
            mix = tensors_to_device(mix, device=model_device)
            est_sources = model(mix.unsqueeze(0))
            mix_np = mix.cpu().data.numpy()
            est_sources_np = est_sources.squeeze(0).cpu().data.numpy()
            # Rescale the estimates to the mixture's peak amplitude.
            est_sources_np *= np.max(np.abs(mix_np)) / np.max(np.abs(est_sources_np))
            # For each utterance, we get a dictionary with the mixture path,
            # the input and output metrics
            utt_metrics = {"mix_path": test_set.mixture_path}
            utt_metrics.update(
                **wer_tracker(
                    mix=mix_np,
                    clean=None,  # CHiME4 'real' data has no clean reference.
                    estimate=est_sources_np,
                    wav_id=ids,
                    sample_rate=conf["sample_rate"],
                )
            )
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np, conf["sample_rate"])
                # Loop over the sources and estimates
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv"))

    # Print and save summary metrics
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()

    print("Overall metrics :")
    pprint(final_results)
    if conf["compute_wer"]:
        print("\nWER report")
        wer_card = wer_tracker.final_report_as_markdown()
        print(wer_card)
        # Save the report
        with open(os.path.join(eval_save_dir, "final_wer.md"), "w") as f:
            f.write(wer_card)
        all_transcriptions = wer_tracker.transcriptions
        with open(os.path.join(eval_save_dir, "all_transcriptions.json"), "w") as f:
            json.dump(all_transcriptions, f, indent=4)

    with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)

    model_dict = torch.load(model_path, map_location="cpu")
    os.makedirs(os.path.join(conf["exp_dir"], "publish_dir"), exist_ok=True)
    # Fix: read the training config from `conf` instead of relying on the
    # module-level `train_conf` global, which only exists when this file is
    # executed as a script (calling main() from an import raised NameError).
    save_publishable(
        os.path.join(conf["exp_dir"], "publish_dir"),
        model_dict,
        metrics=final_results,
        train_conf=conf["train_conf"],
    )


if __name__ == "__main__":
    # Parse CLI options and turn the Namespace into a plain dict.
    args = parser.parse_args()
    arg_dic = dict(vars(args))
    # Load training config saved by the training script in the experiment dir.
    conf_path = os.path.join(args.exp_dir, "conf.yml")
    with open(conf_path) as f:
        train_conf = yaml.safe_load(f)
    # Evaluation must use the same sample rate the model was trained with.
    arg_dic["sample_rate"] = train_conf["data"]["sample_rate"]
    # Carry the full training config along for the publishable model dump.
    arg_dic["train_conf"] = train_conf
    main(arg_dic)
Loading