Commit

use lmdb
youngsheen committed Jun 7, 2024
1 parent f824da3 commit f7a30a0
Showing 7 changed files with 79 additions and 105 deletions.
5 changes: 0 additions & 5 deletions preprocess/data_handler.py
@@ -1,8 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
 import pathlib
 import logging
 import torch
5 changes: 0 additions & 5 deletions preprocess/distributed.py
@@ -1,8 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
 import os
 import subprocess
 from dataclasses import dataclass
4 changes: 2 additions & 2 deletions preprocess/encodec_reader.py
@@ -60,6 +60,6 @@ def get_features(self, x, sr):
         for start in range(0, x.size(-1), self.max_chunk):
             encoded_frames = self.model.encode(x[...,start : start + self.max_chunk])
             codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T]
-            feat.append(codes.squeeze(0))
+            feat.append(codes.squeeze(0).cpu())
 
-        return {"codec": torch.cat(feat, 1).cpu()}
+        return {"codec": torch.cat(feat, dim=1)}
9 changes: 0 additions & 9 deletions preprocess/get_manifest.py
@@ -1,12 +1,3 @@
-#!/usr/bin/env python3
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-"""
-Data pre-processing: build vocabularies and binarize training data.
-"""
-
 import argparse
 import glob
 import os
24 changes: 11 additions & 13 deletions preprocess/run.sh
@@ -1,18 +1,16 @@
-SPLIT=test-clean
+SPLIT=test-other
+ROOT=/data3/yongxinzhu
 
+export TORCH_HOME="/data/yongxinzhu/.cache/torch"
+export FAIRSEQ2_CACHE_DIR="/data/yongxinzhu/.cache/fairseq2"
+
 python preprocess/get_manifest.py \
-    --root datasets/librispeech/LibriSpeech/$SPLIT \
-    --dest datasets/librispeech \
+    --root $ROOT/LibriSpeech/LibriSpeech/$SPLIT \
+    --dest $ROOT/LibriSpeech \
     --ext flac \
     --name $SPLIT
 
-CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 --master_port=6666 \
-    preprocess/transcribe.py \
-    --manifest datasets/librispeech/$SPLIT.tsv \
-    --seamless
-
-
-CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 --master_port=6666 \
-    preprocess/transcribe.py \
-    --manifest datasets/librispeech/$SPLIT.tsv \
-    --codec --bandwidth 6
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node=8 --master_port=6669 \
+    preprocess/transcribe1.py \
+    --manifest $ROOT/LibriSpeech/$SPLIT.tsv \
+    --bandwidth 6 --fp16
20 changes: 8 additions & 12 deletions preprocess/seamless_reader.py
@@ -1,9 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-
 import torch
 import torch.nn.functional as F
 
@@ -22,7 +16,7 @@
 
 class Wav2vecFeatureReader(torch.nn.Module):
     def __init__(
-        self, checkpoint_path, kmeans_path, layer=6, max_chunk=100 * 16_000, lazy_load=False
+        self, checkpoint_path, kmeans_path, layer=None, dtype = torch.float32, max_chunk=100 * 16_000, lazy_load=False
     ):
         super().__init__()
         # NB: fairseq doesn't support pathlib.Path
@@ -34,19 +28,19 @@
         self.out_layer_number = layer - 1
         self.max_chunk = max_chunk
         # this is useful for determining the device
-        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float).cuda())
+        self.register_buffer("_float_tensor", torch.tensor([0], dtype=dtype).cuda())
         if not self.lazy_load:
             self.load_checkpoint_()
 
     @torch.no_grad() # otherwise some non-leaf nodes appear which breaks serialization
     def load_checkpoint_(self):
         wav2vec2_model = load_wav2vec2_model(
-            self.model_name_or_card, device=self.device, dtype=torch.float32
+            self.model_name_or_card, device=self.device, dtype=self._float_tensor.dtype
         )
         wav2vec2_model.eval()
         assert isinstance(wav2vec2_model, Wav2Vec2Model)
         self.model = Wav2Vec2LayerOutputModel(wav2vec2_model)
-        self.kmeans_model = KmeansModel(self.kmeans_url, self.device, torch.float32)
+        self.kmeans_model = KmeansModel(self.kmeans_url, self.device, self._float_tensor.dtype)
         self.collate = Collater(pad_value=1, pad_to_multiple=2)
 
     @property
@@ -69,6 +63,8 @@ def get_features(self, inputs, sr):
         inputs = inputs.view(1, -1)
         inputs = F.layer_norm(inputs, inputs.shape)
 
+        inputs = inputs.type_as(self._float_tensor)
+
         if inputs.size(1) > self.max_chunk:
             print("too long:", inputs.size(1) / 16000, "s")
 
@@ -91,12 +87,12 @@ def get_features(self, inputs, sr):
             batch = SequenceBatch(seqs=seqs, padding_mask=padding_mask)
             features = self.model(batch, self.out_layer_number).squeeze(0)
             units = self.kmeans_model(features)
-            feat.append(units.unsqueeze(0))
+            feat.append(units.unsqueeze(0).cpu())
 
         #units, durations = torch.unique_consecutive(units, return_counts=True)
 
         item = {
-            "units": torch.cat(feat, 1).squeeze(0).cpu(), #no reuduce
+            "units": torch.cat(feat, dim = 1).squeeze(0), #no reuduce
         }
         return item
 
117 changes: 58 additions & 59 deletions preprocess/transcribe.py
@@ -1,13 +1,9 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-
 import torch.distributed as distr
 import torch
 import pathlib
 import numpy as np
+import lmdb
+import shutil
 import os
 import logging
 import tqdm
@@ -34,20 +30,15 @@ def get_args():
         required=True,
         help="Path to the dataset manifest file"
     )
-    parser.add_argument(
-        "--codec",
-        action="store_true",
-    )
     parser.add_argument(
         "--bandwidth",
         type=int,
         default=6,
     )
     parser.add_argument(
-        "--seamless",
+        "--fp16",
         action="store_true",
     )
-
     parser.add_argument("--distributed_port", type=int, default=58554)
 
     args = parser.parse_args()
@@ -56,68 +47,75 @@
     return args
 
 
-def worker_shard_path(fname, suffix, worker_id) -> pathlib.Path:
-    return fname.with_suffix(f".{suffix}_partial_{worker_id}")
-
-
-def transcribe(args, rank, world_size):
+def transcribe_lmdb(args, rank, world_size):
     dataset = ManifestDataset(args.manifest)
 
-    if args.codec:
-        speech_encoder = EncodecFeatureReader(
-            bandwidth = args.bandwidth,
-        )
-        os.makedirs(args.manifest.parent / f"meta24khz_{args.bandwidth}kpbs_codec", exist_ok=True)
-        output_files = jsonlines.open(worker_shard_path(args.manifest.parent / f"meta24khz_{args.bandwidth}kpbs_codec" / args.manifest.stem, "jsonl", rank), mode="w")
-    elif args.seamless:
-        speech_encoder = Wav2vecFeatureReader(
-            checkpoint_path = "xlsr2_1b_v2",
-            kmeans_path = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
-            layer=35
-        )
-        os.makedirs(args.manifest.parent / "xlsr2_unit", exist_ok=True)
-        output_files = jsonlines.open(worker_shard_path(args.manifest.parent / "xlsr2_unit" / args.manifest.stem, "jsonl", rank), mode="w")
-    else:
-        raise NotImplementedError
 
 
+    speech_encoder_enocodec = EncodecFeatureReader(
+        bandwidth = args.bandwidth,
+    )
+    os.makedirs(args.manifest.parent / f"meta24khz_{args.bandwidth}kpbs_codec", exist_ok=True)
+    lmdb_path = worker_shard_path(args.manifest.parent / f"meta24khz_{args.bandwidth}kpbs_codec" / args.manifest.stem, "", rank).as_posix()
+    if os.path.exists(lmdb_path):
+        shutil.rmtree(lmdb_path)
+    output_files_encodec = lmdb.open(lmdb_path, map_size=int(1e12))
+
+    speech_encoder_seamless = Wav2vecFeatureReader(
+        checkpoint_path = "xlsr2_1b_v2",
+        kmeans_path = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
+        layer=35,
+        dtype=torch.float16 if args.fp16 else torch.float32
+    )
+    os.makedirs(args.manifest.parent / "xlsr2_unit", exist_ok=True)
+    lmdb_path = worker_shard_path(args.manifest.parent / "xlsr2_unit" / args.manifest.stem, "", rank).as_posix()
+    if os.path.exists(lmdb_path):
+        shutil.rmtree(lmdb_path)
+    output_files_seamless = lmdb.open(lmdb_path, map_size=int(1e12))
+
+
     for i in tqdm.tqdm(range(rank, len(dataset), world_size)):
         audio_path = dataset[i]
 
         waveform, sr = torchaudio.load(audio_path)
         waveform = waveform.squeeze(0)
-        encoded = speech_encoder(waveform, sr)
+        encoded_encodec = speech_encoder_enocodec(waveform, sr)
+        encoded_seamless = speech_encoder_seamless(waveform, sr)
 
-        if args.codec:
-            stream = encoded['codec'].tolist()
-            stream = {"name": audio_path.as_posix(), "codec": stream}
-
-        if args.seamless:
-            stream = encoded['units'].tolist()
-            stream = {"name": audio_path.as_posix(), "unit": stream}
+        with output_files_encodec.begin(write=True) as txn:
+            stream = encoded_encodec['codec'].numpy()
+            txn.put(audio_path.as_posix().encode('utf-8'), stream.tobytes())
+
-        output_files.write(stream)
-    output_files.close()
+        with output_files_seamless.begin(write=True) as txn:
+            stream = encoded_seamless['units'].numpy()
+            txn.put(audio_path.as_posix().encode('utf-8'), stream.tobytes())
+
+    output_files_encodec.close()
+    output_files_seamless.close()
 
-    if args.codec:
-        return args.manifest.parent / f"meta24khz_{args.bandwidth}kpbs_codec" / args.manifest.stem
-    else:
-        return args.manifest.parent / "xlsr2_unit" / args.manifest.stem
+    return args.manifest.parent / f"meta24khz_{args.bandwidth}kpbs_codec" / args.manifest.stem, args.manifest.parent / "xlsr2_unit" / args.manifest.stem
 
 
-def merge_files(full_output, suffix, n_workers):
-    output = full_output.with_suffix(f".{suffix}")
+def merge_lmdb(full_output, suffix, n_workers):
+    env = lmdb.open(full_output.as_posix(), map_size=int(1e12))
+    with env.begin(write=True) as txn:
+        for worker_id in range(n_workers):
+            partial_path = worker_shard_path(full_output, suffix, worker_id)
+            lmdb_env = lmdb.open(partial_path.as_posix(), readonly=True)
+            with lmdb_env.begin() as lmdb_txn:
+                for key, value in lmdb_txn.cursor():
+                    txn.put(key, value)
+            lmdb_env.close()
+    env.close()
 
-    run_list = ["cat"]
-    for worker_id in range(n_workers):
-        partial_path = worker_shard_path(full_output, suffix, worker_id)
-        run_list.append(partial_path.as_posix())
-    run_list.append(">")
-    run_list.append(output.as_posix())
-    subprocess.run(" ".join(run_list), shell=True, check=True)
     for worker_id in range(n_workers):
         partial_path = worker_shard_path(full_output, suffix, worker_id)
-        partial_path.unlink()
+        shutil.rmtree(partial_path)
         print(f"Deleted: {partial_path}")
 
 
+def worker_shard_path(fname, suffix, worker_id) -> pathlib.Path:
+    return fname.with_suffix(f".{suffix}_partial_{worker_id}")
+
+
 def main(args):
@@ -126,13 +124,14 @@ def main(args):
 
     n_gpus = torch.cuda.device_count()
     with torch.cuda.device(context.local_rank % n_gpus):
-        output_file = transcribe(args, context.rank, context.world_size)
+        output_file_codec, output_file_seamless = transcribe_lmdb(args, context.rank, context.world_size)
 
     if context.world_size > 1:
         distr.barrier()
 
     if context.is_leader:
-        merge_files(output_file, "jsonl", context.world_size)
+        merge_lmdb(output_file_codec, "", context.world_size)
+        merge_lmdb(output_file_seamless, "", context.world_size)
 
 if __name__ == "__main__":
     args = get_args()
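For reference, a minimal sketch of reading back one of the LMDB databases written above (not part of this commit). Keys are the audio file paths and values are the raw bytes of the numpy arrays stored by transcribe.py; the example path and the int64 dtype are illustrative assumptions, not fixed by the diff.

import lmdb
import numpy as np

# Hypothetical path to the merged unit database produced by merge_lmdb(); adjust to your layout.
db_path = "LibriSpeech/xlsr2_unit/test-other"

env = lmdb.open(db_path, readonly=True, lock=False)
with env.begin() as txn:
    for key, value in txn.cursor():
        name = key.decode("utf-8")                    # keys are audio file paths
        units = np.frombuffer(value, dtype=np.int64)  # dtype assumed; values come from numpy().tobytes()
        print(name, units.shape)
env.close()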
