-
Notifications
You must be signed in to change notification settings - Fork 447
Chime4 #423
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Chime4 #423
Changes from 7 commits
96183f4
13704af
72039d2
6f2149d
4723e2a
dc2da64
fce54f6
f9b5c00
0c58deb
0615f71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| import pandas as pd | ||
| import soundfile as sf | ||
| import torch | ||
| from torch.utils.data import Dataset, DataLoader | ||
| import random as random | ||
| import os | ||
|
|
||
|
|
||
| class CHiME4(Dataset): | ||
| """Dataset class for CHiME4 source separation tasks. Only supports 'real' | ||
| data | ||
|
|
||
| Args: | ||
| csv_dir (str): The path to the metadata file. | ||
| sample_rate (int) : The sample rate of the sources and mixtures. | ||
| segment (int) : The desired sources and mixtures length in s. | ||
|
|
||
| References | ||
| Emmanuel Vincent, Shinji Watanabe, Aditya Arie Nugraha, Jon Barker, and Ricard Marxer | ||
| An analysis of environment, microphone and data simulation mismatches in robust speech recognition | ||
| Computer Speech and Language, 2017. | ||
| """ | ||
|
|
||
| dataset_name = "CHiME4" | ||
|
|
||
| def __init__(self, csv_dir, sample_rate=16000, segment=3, return_id=False): | ||
| self.csv_dir = csv_dir | ||
| # Get the csv corresponding to origin | ||
| self.segment = segment | ||
| self.sample_rate = sample_rate | ||
| self.return_id = return_id | ||
| self.csv_path = [f for f in os.listdir(csv_dir) if "annotations" not in f][0] | ||
| # Open csv file and concatenate them | ||
| self.df = pd.read_csv(os.path.join(csv_dir, self.csv_path)) | ||
| # Get rid of the utterances too short | ||
| if self.segment is not None: | ||
| max_len = len(self.df) | ||
| self.seg_len = int(self.segment * self.sample_rate) | ||
| # Ignore the file shorter than the desired_length | ||
| self.df = self.df[self.df["duration"] >= self.seg_len] | ||
| print( | ||
| f"Drop {max_len - len(self.df)} utterances from {max_len} " | ||
| f"(shorter than {segment} seconds)" | ||
| ) | ||
| else: | ||
| self.seg_len = None | ||
|
|
||
| def __len__(self): | ||
| return len(self.df) | ||
|
|
||
| def __getitem__(self, idx): | ||
| # Get the row in dataframe | ||
| row = self.df.iloc[idx] | ||
| # Get mixture path | ||
| self.mixture_path = row["mixture_path"] | ||
| # If there is a seg start point is set randomly | ||
| if self.seg_len is not None: | ||
| start = random.randint(0, row["length"] - self.seg_len) | ||
| stop = start + self.seg_len | ||
| else: | ||
| start = 0 | ||
| stop = None | ||
|
|
||
| # Read the mixture | ||
| mixture, _ = sf.read(self.mixture_path, dtype="float32", start=start, stop=stop) | ||
| # Convert to torch tensor | ||
| mixture = torch.from_numpy(mixture) | ||
| mock_source = torch.vstack([mixture]) | ||
| if self.return_id: | ||
| id1 = row.ID | ||
| return mixture, mock_source, [id1] | ||
| return mixture, mock_source | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not return the mock_source. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,185 @@ | ||
| import os | ||
| import random | ||
| import soundfile as sf | ||
| import torch | ||
| import yaml | ||
| import json | ||
| import argparse | ||
| import numpy as np | ||
| import pandas as pd | ||
| from tqdm import tqdm | ||
| from pprint import pprint | ||
|
|
||
| from asteroid.data.chime4_dataset import CHiME4 | ||
| from asteroid import ConvTasNet | ||
| from asteroid.models import save_publishable | ||
| from asteroid.utils import tensors_to_device | ||
| from asteroid.metrics import WERTracker, MockWERTracker | ||
|
|
||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| "--test_dir", type=str, required=True, help="Test directory including the csv files" | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution" | ||
| ) | ||
| parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root") | ||
| parser.add_argument( | ||
| "--n_save_ex", type=int, default=10, help="Number of audio examples to save, -1 means all" | ||
| ) | ||
| parser.add_argument( | ||
| "--compute_wer", type=int, default=1, help="Compute WER using ESPNet's pretrained model" | ||
| ) | ||
| parser.add_argument( | ||
| "--asr_type", | ||
| default="noisy", | ||
| help="Choice for the ASR model whether trained on clean or noisy data. One of clean or noisy", | ||
| ) | ||
|
|
||
|
|
||
| # In CHiME 4 only the noisy data are available, hence no metrics. | ||
| COMPUTE_METRICS = [] | ||
|
|
||
|
|
||
| def update_compute_metrics(compute_wer, metric_list): | ||
| if not compute_wer: | ||
| return metric_list | ||
| try: | ||
| from espnet2.bin.asr_inference import Speech2Text | ||
| from espnet_model_zoo.downloader import ModelDownloader | ||
| except ModuleNotFoundError: | ||
| import warnings | ||
|
|
||
| warnings.warn("Couldn't find espnet installation. Continuing without.") | ||
| return metric_list | ||
| return metric_list + ["wer"] | ||
|
|
||
|
|
||
| def main(conf): | ||
|
|
||
| if conf["asr_type"] == "noisy": | ||
| asr_model_path = ( | ||
| "kamo-naoyuki/chime4_asr_train_asr_transformer3_raw_en_char_sp_valid.acc.ave" | ||
| ) | ||
| else: | ||
| asr_model_path = "kamo-naoyuki/wsj_transformer2" | ||
|
|
||
| compute_metrics = update_compute_metrics(conf["compute_wer"], COMPUTE_METRICS) | ||
| annot_path = [f for f in os.listdir(conf["test_dir"]) if "annotations" in f][0] | ||
| anno_df = pd.read_csv(os.path.join(conf["test_dir"], annot_path)) | ||
| wer_tracker = ( | ||
| MockWERTracker() if not conf["compute_wer"] else WERTracker(asr_model_path, anno_df) | ||
| ) | ||
| model_path = os.path.join(conf["exp_dir"], "best_model.pth") | ||
| model = ConvTasNet.from_pretrained(model_path) | ||
| # Handle device placement | ||
| if conf["use_gpu"]: | ||
| model.cuda() | ||
| model_device = next(model.parameters()).device | ||
| test_set = CHiME4( | ||
| csv_dir=conf["test_dir"], | ||
| sample_rate=conf["sample_rate"], | ||
| segment=None, | ||
| return_id=True, | ||
| ) # Uses all segment length | ||
| # Used to reorder sources only | ||
|
|
||
| # Randomly choose the indexes of sentences to save. | ||
| eval_save_dir = os.path.join(conf["exp_dir"], "chime4", conf["asr_type"]) | ||
| ex_save_dir = os.path.join(eval_save_dir, "examples/") | ||
| if conf["n_save_ex"] == -1: | ||
| conf["n_save_ex"] = len(test_set) | ||
| save_idx = random.sample(range(len(test_set)), conf["n_save_ex"]) | ||
| series_list = [] | ||
| torch.no_grad().__enter__() | ||
| for idx in tqdm(range(len(test_set))): | ||
| # Forward the network on the mixture. | ||
| mix, sources, ids = test_set[idx] | ||
| mix, sources = tensors_to_device([mix, sources], device=model_device) | ||
| est_sources = model(mix.unsqueeze(0)) | ||
| mix_np = mix.cpu().data.numpy() | ||
| sources_np = sources.cpu().data.numpy() | ||
| est_sources_np = est_sources.squeeze(0).cpu().data.numpy() | ||
| est_sources_np *= np.max(np.abs(mix_np)) / np.max(np.abs(est_sources_np)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the code related to |
||
| # For each utterance, we get a dictionary with the mixture path, | ||
| # the input and output metrics | ||
| utt_metrics = {"mix_path": test_set.mixture_path} | ||
| utt_metrics.update( | ||
| **wer_tracker( | ||
| mix=mix_np, | ||
| clean=sources_np, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extend |
||
| estimate=est_sources_np, | ||
| wav_id=ids, | ||
| sample_rate=conf["sample_rate"], | ||
| ) | ||
| ) | ||
| series_list.append(pd.Series(utt_metrics)) | ||
|
|
||
| # Save some examples in a folder. Wav files and metrics as text. | ||
| if idx in save_idx: | ||
| local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx)) | ||
| os.makedirs(local_save_dir, exist_ok=True) | ||
| sf.write(local_save_dir + "mixture.wav", mix_np, conf["sample_rate"]) | ||
| # Loop over the sources and estimates | ||
| for src_idx, src in enumerate(sources_np): | ||
| sf.write(local_save_dir + "s{}.wav".format(src_idx), src, conf["sample_rate"]) | ||
| for src_idx, est_src in enumerate(est_sources_np): | ||
| # est_src *= np.max(np.abs(mix_np)) / np.max(np.abs(est_src)) | ||
| sf.write( | ||
| local_save_dir + "s{}_estimate.wav".format(src_idx), | ||
| est_src, | ||
| conf["sample_rate"], | ||
| ) | ||
| # Write local metrics to the example folder. | ||
| with open(local_save_dir + "metrics.json", "w") as f: | ||
| json.dump(utt_metrics, f, indent=0) | ||
|
|
||
| # Save all metrics to the experiment folder. | ||
| all_metrics_df = pd.DataFrame(series_list) | ||
| all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv")) | ||
|
|
||
| # Print and save summary metrics | ||
| final_results = {} | ||
| for metric_name in compute_metrics: | ||
| input_metric_name = "input_" + metric_name | ||
| ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name] | ||
| final_results[metric_name] = all_metrics_df[metric_name].mean() | ||
| final_results[metric_name + "_imp"] = ldf.mean() | ||
|
|
||
| print("Overall metrics :") | ||
| pprint(final_results) | ||
| if conf["compute_wer"]: | ||
| print("\nWER report") | ||
| wer_card = wer_tracker.final_report_as_markdown() | ||
| print(wer_card) | ||
|
Comment on lines
+139
to
+152
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only print things from |
||
| # Save the report | ||
| with open(os.path.join(eval_save_dir, "final_wer.md"), "w") as f: | ||
| f.write(wer_card) | ||
| all_transcriptions = wer_tracker.all_transcriptions() | ||
| with open(os.path.join(eval_save_dir, "all_transcriptions.json"), "w") as f: | ||
| json.dump(all_transcriptions, f, indent=4) | ||
|
|
||
| with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f: | ||
| json.dump(final_results, f, indent=0) | ||
|
|
||
| model_dict = torch.load(model_path, map_location="cpu") | ||
| os.makedirs(os.path.join(conf["exp_dir"], "publish_dir"), exist_ok=True) | ||
| publishable = save_publishable( | ||
| os.path.join(conf["exp_dir"], "publish_dir"), | ||
| model_dict, | ||
| metrics=final_results, | ||
| train_conf=train_conf, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| args = parser.parse_args() | ||
| arg_dic = dict(vars(args)) | ||
| # Load training config | ||
| conf_path = os.path.join(args.exp_dir, "conf.yml") | ||
| with open(conf_path) as f: | ||
| train_conf = yaml.safe_load(f) | ||
| arg_dic["sample_rate"] = train_conf["data"]["sample_rate"] | ||
| arg_dic["train_conf"] = train_conf | ||
| main(arg_dic) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| import os | ||
| import argparse | ||
| from glob import glob | ||
| import pandas as pd | ||
| import numpy as np | ||
|
|
||
| # Command line arguments | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--chime3_dir", type=str, default=None, help="Path to CHiME3 root directory") | ||
|
|
||
| # Set seed for random generation | ||
| SEED = 4 | ||
| np.random.seed(SEED) | ||
|
|
||
|
|
||
| def main(args): | ||
| chime3_dir = args.chime3_dir | ||
| create_local_metadata(chime3_dir) | ||
|
|
||
|
|
||
| def create_local_metadata(chime3_dir): | ||
| # Get CHiME-3 annotation files | ||
| c3_annot_files = [ | ||
| f for f in glob(os.path.join(chime3_dir, "data", "annotations", "*real*.json")) | ||
| ] | ||
| # Get CHiME-4 annotation files | ||
| c4_annot_files = [ | ||
| f for f in glob(os.path.join(chime3_dir, "data", "annotations", "*real*.list")) | ||
| ] | ||
| for c3_annot_file_path in c3_annot_files: | ||
| # Read CHiME-3 annotation file | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it could be helpful to move the entire |
||
| c3_annot_file = pd.read_json(c3_annot_file_path) | ||
| # subsets : "tr" "dt" "et" origin "real" or "simu" | ||
| subset, origin = os.path.split(c3_annot_file_path)[1].replace(".json", "").split("_") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the entire code could benefit from using the Also I think it could be helpful to add a comment with a file name example everywhere filenames are parsed or constructed, for example |
||
| # Look for associated CHiME-4 file | ||
| if c3_annot_file_path.replace(".json", "_1ch_track.list") in c4_annot_files: | ||
| # Read CHiME-4 annotation file | ||
| c4_annot_file = pd.read_csv( | ||
| c3_annot_file_path.replace(".json", "_1ch_track.list"), header=None, names=["path"] | ||
| ) | ||
| else: | ||
| c4_annot_file = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This case isn't handled in the |
||
| df, df_2 = create_dataframe(chime3_dir, c3_annot_file, c4_annot_file, subset, origin) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could there be better names than |
||
| write_dataframe(df, df_2, subset, origin) | ||
|
|
||
|
|
||
| def create_dataframe(chime3_dir, c3_annot_file, c4_annot_file, subset, origin): | ||
| # Empty list for DataFrame creation | ||
| row_list = [] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, is there a more helpful name than "row" and "row 2" here? |
||
| row_list_2 = [] | ||
| for row in c3_annot_file.itertuples(): | ||
| speaker = row.speaker | ||
| ID = row.wsj_name | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which type of ID is this? Utterance ID? |
||
| env = row.environment | ||
| # if we are not dealing with tr subset | ||
| if "tr" not in subset: | ||
| mixture_path = c4_annot_file[c4_annot_file["path"].str.contains(ID + "_" + env)].values[ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, and also in the |
||
| 0 | ||
| ][0] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is black, but this is unreadable. |
||
| mixture_path = os.path.join(chime3_dir, "data/audio/16kHz/isolated/", mixture_path) | ||
|
|
||
| # if we are dealing with the tr subset | ||
| else: | ||
| channel = np.random.randint(1, 7) | ||
| mixture_path = os.path.join( | ||
| chime3_dir, | ||
| "data/audio/16kHz/isolated/", | ||
| subset + "_" + env.lower() + "_" + origin, | ||
| speaker + "_" + ID + "_" + f".CH{channel}" ".wav", | ||
| ) | ||
| dot = row.dot | ||
| duration = row.end - row.start | ||
| temp_dict = { | ||
| "ID": ID, | ||
| "subset": subset, | ||
| "origin": origin, | ||
| "env": env, | ||
| "mixture_path": mixture_path, | ||
| "duration": duration, | ||
| } | ||
| trans_dict = {"utt_id": ID, "text": dot} | ||
| row_list.append(temp_dict) | ||
| row_list_2.append(trans_dict) | ||
| df = pd.DataFrame(row_list) | ||
| df_2 = pd.DataFrame(row_list_2) | ||
| return df, df_2 | ||
|
|
||
|
|
||
| def write_dataframe(df, df2, subset, origin): | ||
| if "et" in subset: | ||
| subdir = "test" | ||
| elif "dt" in subset: | ||
| subdir = "val" | ||
| else: | ||
| subdir = "train" | ||
| save_dir = os.path.join("data", subdir) | ||
| os.makedirs(save_dir, exist_ok=True) | ||
| save_path = os.path.join(save_dir, origin + "_1_ch_track.csv") | ||
| df.to_csv(save_path, index=False) | ||
| save_path2 = os.path.join(save_dir, origin + "_1_ch_track_annotations.csv") | ||
| df2.to_csv(save_path2, index=False) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| args = parser.parse_args() | ||
| main(args) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's call this
CHiME4Dataset.