Facing issue at the time of Zipformer training #1867

Open
mukherjeesougata opened this issue Jan 18, 2025 · 2 comments

@mukherjeesougata

I am trying to train the Zipformer model on my custom dataset. The steps I have followed are given below:

  • I prepared the data by running the command lhotse kaldi import {train, dev, test}/ 16000 manifests/{train, dev, test}_manifest.
  • I completed the fbank extraction stage (stage 3) of the prepare.sh script, which generated the files and folders shown in the figure below:

[Screenshot: files and folders generated by the fbank extraction stage]

  • After this I prepared the BPE-based lang, which generated the folder lang_bpe_500 containing the files bpe.model, tokens.txt, transcript_word.txt, unigram_500.model, and unigram_500.vocab.
  • Finally I ran the training command given below:
    ./pruned_transducer_stateless7_streaming/train.py --world-size 2 --num-epochs 30 --start-epoch 1 --use-fp16 1 --exp-dir pruned_transducer_stateless7_streaming/exp --max-duration 200 --enable-musan False

I am getting the following error:

  File "/DATA/Sougata/icefall_toolkit/icefall/egs/Hindi/ASR/./pruned_transducer_stateless7_streaming/train.py", line 1273, in <module>
    main()
  File "/DATA/Sougata/icefall_toolkit/icefall/egs/Hindi/ASR/./pruned_transducer_stateless7_streaming/train.py", line 1264, in main
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 281, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 237, in start_processes
    while not context.join():
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 188, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 75, in _wrap
    fn(i, *args)
  File "/DATA/Sougata/icefall_toolkit/icefall/egs/Hindi/ASR/pruned_transducer_stateless7_streaming/train.py", line 1144, in run
    train_one_epoch(
  File "/DATA/Sougata/icefall_toolkit/icefall/egs/Hindi/ASR/pruned_transducer_stateless7_streaming/train.py", line 814, in train_one_epoch
    loss, loss_info = compute_loss(
  File "/DATA/Sougata/icefall_toolkit/icefall/egs/Hindi/ASR/pruned_transducer_stateless7_streaming/train.py", line 685, in compute_loss
    simple_loss, pruned_loss = model(
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1593, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1411, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/DATA/anaconda3/envs/icefall/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/DATA/Sougata/icefall_toolkit/icefall/egs/Hindi/ASR/pruned_transducer_stateless7_streaming/model.py", line 121, in forward
    assert torch.all(x_lens > 0)
AssertionError
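For context on this assertion: the encoder's convolutional subsampling shortens every utterance by a fixed amount, so an utterance with only a few feature frames can end up with a non-positive length. The toy sketch below illustrates the effect; the (n - 7) // 2 formula is a stand-in for illustration, not the model's exact arithmetic.

```python
def subsampled_len(n_frames: int) -> int:
    # Hypothetical subsampling-length formula: convolutional frontends
    # in this family of models lose a handful of boundary frames and
    # then halve (or quarter) the frame count. (n - 7) // 2 is used
    # here purely to illustrate the failure mode.
    return (n_frames - 7) // 2

for n in (3, 100, 500):
    print(n, "->", subsampled_len(n))
# 3 -> -2   (non-positive: `assert torch.all(x_lens > 0)` would fire)
# 100 -> 46
# 500 -> 246
```

So even a single near-empty utterance in a batch is enough to trip the assertion, which is why the checks below focus on very short files.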

After this, I ran the following Python script to check whether there was any issue related to audio file length and whether there were any empty or invalid audio files.

import os
import soundfile as sf
from concurrent.futures import ThreadPoolExecutor, as_completed

def is_empty_audio_file(file_path):
    """
    Check if the audio file is empty or invalid.

    Parameters:
        file_path (str): Path to the audio file.

    Returns:
        str or None: The file path if empty or invalid, otherwise None.
    """
    try:
        with sf.SoundFile(file_path) as audio:
            if len(audio) == 0:
                return file_path
    except Exception:
        return file_path  # Add to the list if unreadable
    return None

def read_wav_scp_files(wav_scp_files):
    """
    Read paths from wav.scp files.

    Parameters:
        wav_scp_files (list): List of paths to wav.scp files.

    Returns:
        list: List of audio file paths.
    """
    audio_files = []
    for wav_scp in wav_scp_files:
        with open(wav_scp, 'r') as f:
            for line in f:
                parts = line.strip().split(maxsplit=1)  # Split by whitespace
                if len(parts) == 2:  # Ensure the line has an ID and a path
                    audio_files.append(parts[1])
    return audio_files

def detect_empty_audio_files_from_scp(wav_scp_files, num_workers=8):
    """
    Detect empty audio files based on paths listed in wav.scp files.

    Parameters:
        wav_scp_files (list): List of paths to wav.scp files.
        num_workers (int): Number of threads to use for parallel processing.

    Returns:
        list: A list of file paths for audio files with no valid data.
    """
    empty_files = []

    # Read audio file paths from wav.scp files
    audio_files = read_wav_scp_files(wav_scp_files)

    # Process files in parallel
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        future_to_file = {executor.submit(is_empty_audio_file, file): file for file in audio_files}
        for future in as_completed(future_to_file):
            result = future.result()
            if result:
                empty_files.append(result)

    return empty_files

# Usage
wav_scp_files = [
    "/DATA/Sougata/Zipformer_dataset_files/train/wav.scp",
    "/DATA/Sougata/Zipformer_dataset_files/test_m/wav.scp",
    "/DATA/Sougata/Zipformer_dataset_files/dev/wav.scp"
]

empty_audio_files = detect_empty_audio_files_from_scp(wav_scp_files)

if empty_audio_files:
    print(f"Found {len(empty_audio_files)} empty or invalid audio files:")
    for empty_file in empty_audio_files:
        print(empty_file)
else:
    print("No empty audio files detected.")

The output of the above Python code was No empty audio files detected.
After that, I checked the cut statistics of the dataset for the train, dev, and test sets. The cut statistics are given below:

Cut statistics of train set:
[Screenshot: cut statistics for the train set]

Cut statistics of dev set:
[Screenshot: cut statistics for the dev set]

Cut statistics of test set:
[Screenshot: cut statistics for the test set]

From the cut statistics we can see that many files have a duration of more than 30 seconds.
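Since the assertion is about lengths rather than empty files, it may also help to scan the cut manifests directly for overly short cuts. A minimal helper is sketched below with stand-in cut objects; with lhotse you would instead pass it load_manifest_lazy("data/fbank/cuts_train.jsonl.gz") (that path is only an assumption based on icefall's usual layout).

```python
from collections import namedtuple

def find_short_cuts(cuts, min_duration=1.0):
    """Return the ids of cuts shorter than min_duration seconds."""
    return [c.id for c in cuts if c.duration < min_duration]

# Demo with stand-in cuts; real lhotse cuts expose the same
# .id and .duration attributes used here.
Cut = namedtuple("Cut", "id duration")
demo = [Cut("utt1", 0.3), Cut("utt2", 5.0), Cut("utt3", 0.8)]
print(find_short_cuts(demo))  # ['utt1', 'utt3']
```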

Kindly help me resolve this issue.

@csukuangfj (Collaborator)

The output of the above Python code was No empty audio files detected.

Please check for files that are nearly empty, e.g., 0.1 second or so long.


By the way, if you use the train.py from our librispeech recipe, it saves the batch that causes the exception.

You can use torch.load() to load the saved batch and check it.
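That inspection might look like the sketch below. The key names ("supervisions", "num_frames") and the saved-file path are assumptions about the batch layout; print batch.keys() after torch.load() to see the real structure.

```python
def shortest_supervisions(batch, k=5):
    """Return the k smallest per-utterance frame counts in a saved batch.

    Assumes a layout where batch["supervisions"]["num_frames"] holds
    the frame counts; adjust the keys to whatever torch.load() yields.
    """
    frames = batch["supervisions"]["num_frames"]
    # Accept either a torch tensor (via .tolist()) or a plain list.
    frames = frames.tolist() if hasattr(frames, "tolist") else list(frames)
    return sorted(frames)[:k]

# Stand-in for `batch = torch.load("<exp-dir>/batch-<id>.pt")`:
dummy = {"supervisions": {"num_frames": [812, 3, 450, 990]}}
print(shortest_supervisions(dummy, k=2))  # [3, 450]
```

A 3-frame entry in the output would point directly at the utterance tripping the assertion.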

@csukuangfj (Collaborator)

Please enable this line (i.e., uncomment it):

# train_cuts = train_cuts.filter(remove_short_and_long_utt)

You need to read

# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:

and change the threshold accordingly.
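Putting that advice together, a filter adapted to the statistics reported above might look like the sketch below. The 1.0 s / 20.0 s bounds are the librispeech defaults quoted above; the 30.0 s upper bound here is only a guess from the reported durations and should be chosen from your own display_manifest_statistics.py output.

```python
from collections import namedtuple

def remove_short_and_long_utt(c) -> bool:
    # Keep only cuts inside a duration window chosen from your own
    # manifest statistics (30.0 s is an assumed upper bound here).
    return 1.0 <= c.duration <= 30.0

# Applied as: train_cuts = train_cuts.filter(remove_short_and_long_utt)

# Demo with stand-in cuts exposing the same .duration attribute:
Cut = namedtuple("Cut", "duration")
print([remove_short_and_long_utt(Cut(d)) for d in (0.4, 12.0, 35.0)])
# [False, True, False]
```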
