Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions asteroid/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .kinect_wsj import KinectWsjMixDataset
from .fuss_dataset import FUSSDataset
from .dampvsep_dataset import DAMPVSEPSinglesDataset
from .chime4_dataset import CHiME4

__all__ = [
"AVSpeechDataset",
Expand All @@ -22,4 +23,5 @@
"KinectWsjMixDataset",
"FUSSDataset",
"DAMPVSEPSinglesDataset",
"CHiME4",
]
72 changes: 72 additions & 0 deletions asteroid/data/chime4_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pandas as pd
import soundfile as sf
import torch
from torch.utils.data import Dataset, DataLoader
import random as random
import os


class CHiME4(Dataset):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call this CHiME4Dataset.

"""Dataset class for CHiME4 source separation tasks. Only supports 'real'
data

Args:
csv_dir (str): The path to the metadata file.
sample_rate (int) : The sample rate of the sources and mixtures.
segment (int) : The desired sources and mixtures length in s.

References
Emmanuel Vincent, Shinji Watanabe, Aditya Arie Nugraha, Jon Barker, and Ricard Marxer
An analysis of environment, microphone and data simulation mismatches in robust speech recognition
Computer Speech and Language, 2017.
"""

dataset_name = "CHiME4"

def __init__(self, csv_dir, sample_rate=16000, segment=3, return_id=False):
self.csv_dir = csv_dir
# Get the csv corresponding to origin
self.segment = segment
self.sample_rate = sample_rate
self.return_id = return_id
self.csv_path = [f for f in os.listdir(csv_dir) if "annotations" not in f][0]
# Open csv file and concatenate them
self.df = pd.read_csv(os.path.join(csv_dir, self.csv_path))
# Get rid of the utterances too short
if self.segment is not None:
max_len = len(self.df)
self.seg_len = int(self.segment * self.sample_rate)
# Ignore the file shorter than the desired_length
self.df = self.df[self.df["duration"] >= self.seg_len]
print(
f"Drop {max_len - len(self.df)} utterances from {max_len} "
f"(shorter than {segment} seconds)"
)
else:
self.seg_len = None

def __len__(self):
return len(self.df)

def __getitem__(self, idx):
# Get the row in dataframe
row = self.df.iloc[idx]
# Get mixture path
self.mixture_path = row["mixture_path"]
# If there is a seg start point is set randomly
if self.seg_len is not None:
start = random.randint(0, row["length"] - self.seg_len)
stop = start + self.seg_len
else:
start = 0
stop = None

# Read the mixture
mixture, _ = sf.read(self.mixture_path, dtype="float32", start=start, stop=stop)
# Convert to torch tensor
mixture = torch.from_numpy(mixture)
mock_source = torch.vstack([mixture])
if self.return_id:
id1 = row.ID
return mixture, mock_source, [id1]
return mixture, mock_source
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not return the mock_source.

185 changes: 185 additions & 0 deletions egs/chime4/ConvTasNet/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import os
import random
import soundfile as sf
import torch
import yaml
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint

from asteroid.data.chime4_dataset import CHiME4
from asteroid import ConvTasNet
from asteroid.models import save_publishable
from asteroid.utils import tensors_to_device
from asteroid.metrics import WERTracker, MockWERTracker

parser = argparse.ArgumentParser()
parser.add_argument(
"--test_dir", type=str, required=True, help="Test directory including the csv files"
)

parser.add_argument(
"--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution"
)
parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root")
parser.add_argument(
"--n_save_ex", type=int, default=10, help="Number of audio examples to save, -1 means all"
)
parser.add_argument(
"--compute_wer", type=int, default=1, help="Compute WER using ESPNet's pretrained model"
)
parser.add_argument(
"--asr_type",
default="noisy",
help="Choice for the ASR model whether trained on clean or noisy data. One of clean or noisy",
)


# In CHiME 4 only the noisy data are available, hence no metrics.
COMPUTE_METRICS = []


def update_compute_metrics(compute_wer, metric_list):
if not compute_wer:
return metric_list
try:
from espnet2.bin.asr_inference import Speech2Text
from espnet_model_zoo.downloader import ModelDownloader
except ModuleNotFoundError:
import warnings

warnings.warn("Couldn't find espnet installation. Continuing without.")
return metric_list
return metric_list + ["wer"]


def main(conf):

if conf["asr_type"] == "noisy":
asr_model_path = (
"kamo-naoyuki/chime4_asr_train_asr_transformer3_raw_en_char_sp_valid.acc.ave"
)
else:
asr_model_path = "kamo-naoyuki/wsj_transformer2"

compute_metrics = update_compute_metrics(conf["compute_wer"], COMPUTE_METRICS)
annot_path = [f for f in os.listdir(conf["test_dir"]) if "annotations" in f][0]
anno_df = pd.read_csv(os.path.join(conf["test_dir"], annot_path))
wer_tracker = (
MockWERTracker() if not conf["compute_wer"] else WERTracker(asr_model_path, anno_df)
)
model_path = os.path.join(conf["exp_dir"], "best_model.pth")
model = ConvTasNet.from_pretrained(model_path)
# Handle device placement
if conf["use_gpu"]:
model.cuda()
model_device = next(model.parameters()).device
test_set = CHiME4(
csv_dir=conf["test_dir"],
sample_rate=conf["sample_rate"],
segment=None,
return_id=True,
) # Uses all segment length
# Used to reorder sources only

# Randomly choose the indexes of sentences to save.
eval_save_dir = os.path.join(conf["exp_dir"], "chime4", conf["asr_type"])
ex_save_dir = os.path.join(eval_save_dir, "examples/")
if conf["n_save_ex"] == -1:
conf["n_save_ex"] = len(test_set)
save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
series_list = []
torch.no_grad().__enter__()
for idx in tqdm(range(len(test_set))):
# Forward the network on the mixture.
mix, sources, ids = test_set[idx]
mix, sources = tensors_to_device([mix, sources], device=model_device)
est_sources = model(mix.unsqueeze(0))
mix_np = mix.cpu().data.numpy()
sources_np = sources.cpu().data.numpy()
est_sources_np = est_sources.squeeze(0).cpu().data.numpy()
est_sources_np *= np.max(np.abs(mix_np)) / np.max(np.abs(est_sources_np))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the code related to sources.

# For each utterance, we get a dictionary with the mixture path,
# the input and output metrics
utt_metrics = {"mix_path": test_set.mixture_path}
utt_metrics.update(
**wer_tracker(
mix=mix_np,
clean=sources_np,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extend WerTracker to work without clean reference. And remove clean.

estimate=est_sources_np,
wav_id=ids,
sample_rate=conf["sample_rate"],
)
)
series_list.append(pd.Series(utt_metrics))

# Save some examples in a folder. Wav files and metrics as text.
if idx in save_idx:
local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
os.makedirs(local_save_dir, exist_ok=True)
sf.write(local_save_dir + "mixture.wav", mix_np, conf["sample_rate"])
# Loop over the sources and estimates
for src_idx, src in enumerate(sources_np):
sf.write(local_save_dir + "s{}.wav".format(src_idx), src, conf["sample_rate"])
for src_idx, est_src in enumerate(est_sources_np):
# est_src *= np.max(np.abs(mix_np)) / np.max(np.abs(est_src))
sf.write(
local_save_dir + "s{}_estimate.wav".format(src_idx),
est_src,
conf["sample_rate"],
)
# Write local metrics to the example folder.
with open(local_save_dir + "metrics.json", "w") as f:
json.dump(utt_metrics, f, indent=0)

# Save all metrics to the experiment folder.
all_metrics_df = pd.DataFrame(series_list)
all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv"))

# Print and save summary metrics
final_results = {}
for metric_name in compute_metrics:
input_metric_name = "input_" + metric_name
ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
final_results[metric_name] = all_metrics_df[metric_name].mean()
final_results[metric_name + "_imp"] = ldf.mean()

print("Overall metrics :")
pprint(final_results)
if conf["compute_wer"]:
print("\nWER report")
wer_card = wer_tracker.final_report_as_markdown()
print(wer_card)
Comment on lines +139 to +152
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only print things from mixture and estimates.
And potential improvement.

# Save the report
with open(os.path.join(eval_save_dir, "final_wer.md"), "w") as f:
f.write(wer_card)
all_transcriptions = wer_tracker.all_transcriptions()
with open(os.path.join(eval_save_dir, "all_transcriptions.json"), "w") as f:
json.dump(all_transcriptions, f, indent=4)

with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f:
json.dump(final_results, f, indent=0)

model_dict = torch.load(model_path, map_location="cpu")
os.makedirs(os.path.join(conf["exp_dir"], "publish_dir"), exist_ok=True)
publishable = save_publishable(
os.path.join(conf["exp_dir"], "publish_dir"),
model_dict,
metrics=final_results,
train_conf=train_conf,
)


if __name__ == "__main__":
args = parser.parse_args()
arg_dic = dict(vars(args))
# Load training config
conf_path = os.path.join(args.exp_dir, "conf.yml")
with open(conf_path) as f:
train_conf = yaml.safe_load(f)
arg_dic["sample_rate"] = train_conf["data"]["sample_rate"]
arg_dic["train_conf"] = train_conf
main(arg_dic)
106 changes: 106 additions & 0 deletions egs/chime4/ConvTasNet/local/create_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os
import argparse
from glob import glob
import pandas as pd
import numpy as np

# Command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--chime3_dir", type=str, default=None, help="Path to CHiME3 root directory")

# Set seed for random generation
SEED = 4
np.random.seed(SEED)


def main(args):
chime3_dir = args.chime3_dir
create_local_metadata(chime3_dir)


def create_local_metadata(chime3_dir):
# Get CHiME-3 annotation files
c3_annot_files = [
f for f in glob(os.path.join(chime3_dir, "data", "annotations", "*real*.json"))
]
# Get CHiME-4 annotation files
c4_annot_files = [
f for f in glob(os.path.join(chime3_dir, "data", "annotations", "*real*.list"))
]
for c3_annot_file_path in c3_annot_files:
# Read CHiME-3 annotation file
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be helpful to move the entire for loop body into a separate function so as to make it obvious that all loop iterations are entirely independent.

c3_annot_file = pd.read_json(c3_annot_file_path)
# subsets : "tr" "dt" "et" origin "real" or "simu"
subset, origin = os.path.split(c3_annot_file_path)[1].replace(".json", "").split("_")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the entire code could benefit from using the pathlib API. Eg here, c3_annot_file_path.with_suffix("").name.split("_")

Also I think it could be helpful to add a comment with a file name example everywhere filenames are parsed or constructed, for example

# Extract subset and origin from /foo/bar/<subset>_<origin>.json

# Look for associated CHiME-4 file
if c3_annot_file_path.replace(".json", "_1ch_track.list") in c4_annot_files:
# Read CHiME-4 annotation file
c4_annot_file = pd.read_csv(
c3_annot_file_path.replace(".json", "_1ch_track.list"), header=None, names=["path"]
)
else:
c4_annot_file = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case isn't handled in the create_dataframe function

df, df_2 = create_dataframe(chime3_dir, c3_annot_file, c4_annot_file, subset, origin)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could there be better names than df and df_2?

write_dataframe(df, df_2, subset, origin)


def create_dataframe(chime3_dir, c3_annot_file, c4_annot_file, subset, origin):
# Empty list for DataFrame creation
row_list = []
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, is there a more helpful name than "row" and "row 2" here?

row_list_2 = []
for row in c3_annot_file.itertuples():
speaker = row.speaker
ID = row.wsj_name
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which type of ID is this? Utterance ID?
Can we change the name to clarify, and put everything lower case?

env = row.environment
# if we are not dealing with tr subset
if "tr" not in subset:
mixture_path = c4_annot_file[c4_annot_file["path"].str.contains(ID + "_" + env)].values[
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, and also in the else case below, it could be helpful to add a comment with an example of what we're looking for in the annotation data frame, like the I suggested with the path splitting above

0
][0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is black, but this is unreadable.
We can call c4_anno and c3_anno instead of c4_annot_file, or smth else, as you want.

mixture_path = os.path.join(chime3_dir, "data/audio/16kHz/isolated/", mixture_path)

# if we are dealing with the tr subset
else:
channel = np.random.randint(1, 7)
mixture_path = os.path.join(
chime3_dir,
"data/audio/16kHz/isolated/",
subset + "_" + env.lower() + "_" + origin,
speaker + "_" + ID + "_" + f".CH{channel}" ".wav",
)
dot = row.dot
duration = row.end - row.start
temp_dict = {
"ID": ID,
"subset": subset,
"origin": origin,
"env": env,
"mixture_path": mixture_path,
"duration": duration,
}
trans_dict = {"utt_id": ID, "text": dot}
row_list.append(temp_dict)
row_list_2.append(trans_dict)
df = pd.DataFrame(row_list)
df_2 = pd.DataFrame(row_list_2)
return df, df_2


def write_dataframe(df, df2, subset, origin):
if "et" in subset:
subdir = "test"
elif "dt" in subset:
subdir = "val"
else:
subdir = "train"
save_dir = os.path.join("data", subdir)
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, origin + "_1_ch_track.csv")
df.to_csv(save_path, index=False)
save_path2 = os.path.join(save_dir, origin + "_1_ch_track_annotations.csv")
df2.to_csv(save_path2, index=False)


if __name__ == "__main__":
args = parser.parse_args()
main(args)
Loading