diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a25e79102..a7c308cb3c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -127,6 +127,22 @@ def _get_pattern(): return ret +def reset_mpl(gallery_conf, fname): + from sphinx_gallery.scrapers import _reset_matplotlib + + _reset_matplotlib(gallery_conf, fname) + import matplotlib + + matplotlib.rcParams.update( + { + "image.interpolation": "none", + "figure.figsize": (9.6, 4.8), + "font.size": 8.0, + "axes.axisbelow": True, + } + ) + + sphinx_gallery_conf = { "examples_dirs": [ "../../examples/tutorials", @@ -139,6 +155,7 @@ def _get_pattern(): "promote_jupyter_magic": True, "first_notebook_cell": None, "doc_module": ("torchaudio",), + "reset_modules": (reset_mpl, "seaborn"), } autosummary_generate = True diff --git a/docs/source/index.rst b/docs/source/index.rst index 61e9fdc0d6..23ed1ba7a6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -71,8 +71,8 @@ model implementations and application components. tutorials/online_asr_tutorial tutorials/device_asr tutorials/device_avsr - tutorials/forced_alignment_for_multilingual_data_tutorial tutorials/forced_alignment_tutorial + tutorials/forced_alignment_for_multilingual_data_tutorial tutorials/tacotron2_pipeline_tutorial tutorials/mvdr_tutorial tutorials/hybrid_demucs_tutorial @@ -147,6 +147,13 @@ Tutorials .. customcardstart:: +.. customcarditem:: + :header: On device audio-visual automatic speech recognition + :card_description: Learn how to stream audio and video from laptop webcam and perform audio-visual automatic speech recognition using Emformer-RNNT model. + :image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif + :link: tutorials/device_avsr.html + :tags: I/O,Pipelines,RNNT + .. customcarditem:: :header: Loading waveform Tensors from files and saving them :card_description: Learn how to query/load audio files and save waveform tensors to files, using torchaudio.info, torchaudio.load and torchaudio.save functions. 
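For reference (an illustrative sketch, not part of this patch): Sphinx-Gallery calls each ``reset_modules`` entry with the gallery configuration and the example file name, which is the contract the new ``reset_mpl`` hook in ``conf.py`` above follows. The standalone copy below only touches ``matplotlib.rcParams`` and uses placeholder arguments.

import matplotlib


def reset_mpl_sketch(gallery_conf, fname):
    # Simplified stand-in for the ``reset_mpl`` hook added above; it only
    # restores the rcParams that the docs build pins between examples.
    matplotlib.rcParams.update(
        {
            "image.interpolation": "none",
            "figure.figsize": (9.6, 4.8),
            "font.size": 8.0,
            "axes.axisbelow": True,
        }
    )


reset_mpl_sketch(gallery_conf={}, fname="dummy_example.py")
print(matplotlib.rcParams["figure.figsize"])  # [9.6, 4.8]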
diff --git a/examples/tutorials/additive_synthesis_tutorial.py b/examples/tutorials/additive_synthesis_tutorial.py index d6407f95bc..329611918a 100644 --- a/examples/tutorials/additive_synthesis_tutorial.py +++ b/examples/tutorials/additive_synthesis_tutorial.py @@ -85,7 +85,7 @@ # -def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): +def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): t = torch.arange(waveform.size(0)) / sample_rate fig, axes = plt.subplots(4, 1, sharex=True) @@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): for i in range(4): axes[i].grid(True) pos = axes[2].get_position() - plt.tight_layout() + fig.tight_layout() if zoom is not None: ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0]) @@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate): freq0 = torch.full((NUM_FRAMES, 1), F0) amp0 = torch.ones((NUM_FRAMES, 1)) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) -show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) +plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) ###################################################################### # @@ -183,7 +183,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate): freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) -show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) +plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) ###################################################################### # Square wave @@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate): freq0 = torch.full((NUM_FRAMES, 1), F0) amp0 = torch.ones((NUM_FRAMES, 1)) freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) -show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) +plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) ###################################################################### # Triangle wave @@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate): # freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) -show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) +plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) ###################################################################### # Inharmonic Paritials @@ -296,7 +296,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate): waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) -show(freq, amp, waveform, SAMPLE_RATE, vol=0.4) +plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4) ###################################################################### # @@ -308,7 +308,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate): freq = extend_pitch(freq0, num_tones) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) -show(freq, amp, waveform, SAMPLE_RATE) +plot(freq, amp, waveform, SAMPLE_RATE) ###################################################################### # References diff --git a/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py b/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py index 154f8589f7..955dc3c029 100644 --- a/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py +++ b/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py @@ -407,30 +407,45 @@ def forward(self, emission: torch.Tensor) -> List[str]: # 
-def plot_alignments(waveform, emission, tokens, timesteps): - fig, ax = plt.subplots(figsize=(32, 10)) - - ax.plot(waveform) - - ratio = waveform.shape[0] / emission.shape[1] - word_start = 0 - - for i in range(len(tokens)): - if i != 0 and tokens[i - 1] == "|": - word_start = timesteps[i] - if tokens[i] != "|": - plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14) - elif i != 0: - word_end = timesteps[i] - ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red") - - xticks = ax.get_xticks() - plt.xticks(xticks, xticks / bundle.sample_rate) - ax.set_xlabel("time (sec)") - ax.set_xlim(0, waveform.shape[0]) - - -plot_alignments(waveform[0], emission, predicted_tokens, timesteps) +def plot_alignments(waveform, emission, tokens, timesteps, sample_rate): + + t = torch.arange(waveform.size(0)) / sample_rate + ratio = waveform.size(0) / emission.size(1) / sample_rate + + chars = [] + words = [] + word_start = None + for token, timestep in zip(tokens, timesteps * ratio): + if token == "|": + if word_start is not None: + words.append((word_start, timestep)) + word_start = None + else: + chars.append((token, timestep)) + if word_start is None: + word_start = timestep + + fig, axes = plt.subplots(3, 1) + + def _plot(ax, xlim): + ax.plot(t, waveform) + for token, timestep in chars: + ax.annotate(token.upper(), (timestep, 0.5)) + for word_start, word_end in words: + ax.axvspan(word_start, word_end, alpha=0.1, color="red") + ax.set_ylim(-0.6, 0.7) + ax.set_yticks([0]) + ax.grid(True, axis="y") + ax.set_xlim(xlim) + + _plot(axes[0], (0.3, 2.5)) + _plot(axes[1], (2.5, 4.7)) + _plot(axes[2], (4.7, 6.9)) + axes[2].set_xlabel("time (sec)") + fig.tight_layout() + + +plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate) ###################################################################### diff --git a/examples/tutorials/audio_data_augmentation_tutorial.py b/examples/tutorials/audio_data_augmentation_tutorial.py index cbe53b5326..3d9d3922ee 100644 --- a/examples/tutorials/audio_data_augmentation_tutorial.py +++ b/examples/tutorials/audio_data_augmentation_tutorial.py @@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None): if xlim: axes[c].set_xlim(xlim) figure.suptitle(title) - plt.show(block=False) ###################################################################### @@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): if xlim: axes[c].set_xlim(xlim) figure.suptitle(title) - plt.show(block=False) ###################################################################### diff --git a/examples/tutorials/audio_datasets_tutorial.py b/examples/tutorials/audio_datasets_tutorial.py index d3c16ffddb..2d540b78fe 100644 --- a/examples/tutorials/audio_datasets_tutorial.py +++ b/examples/tutorials/audio_datasets_tutorial.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ Audio Datasets ============== @@ -10,10 +9,6 @@ available datasets. """ -# When running this tutorial in Google Colab, install the required packages -# with the following. -# !pip install torchaudio - import torch import torchaudio @@ -21,22 +16,13 @@ print(torchaudio.__version__) ###################################################################### -# Preparing data and utility functions (skip this section) -# -------------------------------------------------------- # -# @title Prepare data and utility functions. 
{display-mode: "form"} -# @markdown -# @markdown You do not need to look into this cell. -# @markdown Just execute once and you are good to go. - -# ------------------------------------------------------------------------------- -# Preparation of data and helper functions. -# ------------------------------------------------------------------------------- import os +import IPython + import matplotlib.pyplot as plt -from IPython.display import Audio, display _SAMPLE_DIR = "_assets" @@ -44,34 +30,13 @@ os.makedirs(YESNO_DATASET_PATH, exist_ok=True) -def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): +def plot_specgram(waveform, sample_rate, title="Spectrogram"): waveform = waveform.numpy() - num_channels, _ = waveform.shape - - figure, axes = plt.subplots(num_channels, 1) - if num_channels == 1: - axes = [axes] - for c in range(num_channels): - axes[c].specgram(waveform[c], Fs=sample_rate) - if num_channels > 1: - axes[c].set_ylabel(f"Channel {c+1}") - if xlim: - axes[c].set_xlim(xlim) + figure, ax = plt.subplots() + ax.specgram(waveform[0], Fs=sample_rate) figure.suptitle(title) - plt.show(block=False) - - -def play_audio(waveform, sample_rate): - waveform = waveform.numpy() - - num_channels, _ = waveform.shape - if num_channels == 1: - display(Audio(waveform[0], rate=sample_rate)) - elif num_channels == 2: - display(Audio((waveform[0], waveform[1]), rate=sample_rate)) - else: - raise ValueError("Waveform with more than 2 channels are not supported.") + figure.tight_layout() ###################################################################### @@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate): # :py:class:`torchaudio.datasets.YESNO` dataset. # - dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True) -for i in [1, 3, 5]: - waveform, sample_rate, label = dataset[i] - plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") - play_audio(waveform, sample_rate) +###################################################################### +# +i = 1 +waveform, sample_rate, label = dataset[i] +plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") +IPython.display.Audio(waveform, rate=sample_rate) + +###################################################################### +# +i = 3 +waveform, sample_rate, label = dataset[i] +plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") +IPython.display.Audio(waveform, rate=sample_rate) + +###################################################################### +# +i = 5 +waveform, sample_rate, label = dataset[i] +plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") +IPython.display.Audio(waveform, rate=sample_rate) diff --git a/examples/tutorials/audio_feature_augmentation_tutorial.py b/examples/tutorials/audio_feature_augmentation_tutorial.py index 197d03d04b..6e69ef5056 100644 --- a/examples/tutorials/audio_feature_augmentation_tutorial.py +++ b/examples/tutorials/audio_feature_augmentation_tutorial.py @@ -19,25 +19,19 @@ print(torchaudio.__version__) ###################################################################### -# Preparing data and utility functions (skip this section) -# -------------------------------------------------------- +# Preparation +# ----------- # -# @title Prepare data and utility functions. {display-mode: "form"} -# @markdown -# @markdown You do not need to look into this cell. -# @markdown Just execute once and you are good to go. 
-# @markdown -# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/), -# @markdown which is licensed under Creative Commos BY 4.0. - -# ------------------------------------------------------------------------------- -# Preparation of data and helper functions. -# ------------------------------------------------------------------------------- import librosa import matplotlib.pyplot as plt from torchaudio.utils import download_asset +###################################################################### +# In this tutorial, we will use a speech data from +# `VOiCES dataset `__, +# which is licensed under Creative Commos BY 4.0. + SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") @@ -75,16 +69,9 @@ def get_spectrogram( return spectrogram(waveform) -def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None): - fig, axs = plt.subplots(1, 1) - axs.set_title(title or "Spectrogram (db)") - axs.set_ylabel(ylabel) - axs.set_xlabel("frame") - im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect) - if xmax: - axs.set_xlim((0, xmax)) - fig.colorbar(im, ax=axs) - plt.show(block=False) +def plot_spec(ax, spec, title, ylabel="freq_bin"): + ax.set_title(title) + ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto") ###################################################################### @@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No spec = get_spectrogram(power=None) stretch = T.TimeStretch() -rate = 1.2 -spec_ = stretch(spec, rate) -plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304) - -plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304) - -rate = 0.9 -spec_ = stretch(spec, rate) -plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304) +spec_12 = stretch(spec, overriding_rate=1.2) +spec_09 = stretch(spec, overriding_rate=0.9) ###################################################################### -# TimeMasking -# ----------- # -torch.random.manual_seed(4) -spec = get_spectrogram() -plot_spectrogram(spec[0], title="Original") +def plot(): + fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) + plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2") + plot_spec(axes[1], torch.abs(spec[0]), title="Original") + plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9") + fig.tight_layout() -masking = T.TimeMasking(time_mask_param=80) -spec = masking(spec) -plot_spectrogram(spec[0], title="Masked along time axis") +plot() ###################################################################### -# FrequencyMasking -# ---------------- +# Time and Frequency Masking +# -------------------------- # - torch.random.manual_seed(4) +time_masking = T.TimeMasking(time_mask_param=80) +freq_masking = T.FrequencyMasking(freq_mask_param=80) + spec = get_spectrogram() -plot_spectrogram(spec[0], title="Original") +time_masked = time_masking(spec) +freq_masked = freq_masking(spec) + +###################################################################### +# + + +def plot(): + fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) + plot_spec(axes[0], spec[0], title="Original") + plot_spec(axes[1], time_masked[0], title="Masked along time axis") + plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis") + fig.tight_layout() -masking = 
T.FrequencyMasking(freq_mask_param=80) -spec = masking(spec) -plot_spectrogram(spec[0], title="Masked along frequency axis") +plot() diff --git a/examples/tutorials/audio_feature_extractions_tutorial.py b/examples/tutorials/audio_feature_extractions_tutorial.py index 63b71bc14a..eb43c6dca8 100644 --- a/examples/tutorials/audio_feature_extractions_tutorial.py +++ b/examples/tutorials/audio_feature_extractions_tutorial.py @@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None): ax.grid(True) ax.set_xlim([0, time_axis[-1]]) ax.set_title(title) - plt.show(block=False) def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): @@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): ax.set_title(title) ax.set_ylabel(ylabel) ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest") - plt.show(block=False) def plot_fbank(fbank, title=None): @@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None): axs.imshow(fbank, aspect="auto") axs.set_ylabel("frequency bin") axs.set_xlabel("mel bin") - plt.show(block=False) ###################################################################### @@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch): axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green") axis2.legend(loc=0) - plt.show(block=False) plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch) diff --git a/examples/tutorials/audio_io_tutorial.py b/examples/tutorials/audio_io_tutorial.py index 6fd0f1f2e9..15ef25cc6e 100644 --- a/examples/tutorials/audio_io_tutorial.py +++ b/examples/tutorials/audio_io_tutorial.py @@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate): if num_channels > 1: axes[c].set_ylabel(f"Channel {c+1}") figure.suptitle("waveform") - plt.show(block=False) ###################################################################### @@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"): if num_channels > 1: axes[c].set_ylabel(f"Channel {c+1}") figure.suptitle(title) - plt.show(block=False) ###################################################################### diff --git a/examples/tutorials/audio_resampling_tutorial.py b/examples/tutorials/audio_resampling_tutorial.py index 33b1ffec53..1398e1d69a 100644 --- a/examples/tutorials/audio_resampling_tutorial.py +++ b/examples/tutorials/audio_resampling_tutorial.py @@ -105,7 +105,6 @@ def plot_sweep( axis.yaxis.grid(True, alpha=0.67) figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)") plt.colorbar(cax) - plt.show(block=True) ###################################################################### diff --git a/examples/tutorials/ctc_forced_alignment_api_tutorial.py b/examples/tutorials/ctc_forced_alignment_api_tutorial.py index a0d3d7acb7..7d6f02a7f4 100644 --- a/examples/tutorials/ctc_forced_alignment_api_tutorial.py +++ b/examples/tutorials/ctc_forced_alignment_api_tutorial.py @@ -5,254 +5,277 @@ **Author**: `Xiaohui Zhang `__ -This tutorial shows how to align transcripts to speech with -``torchaudio``'s CTC forced alignment API proposed in the paper -`“Scaling Speech Technology to 1,000+ -Languages” `__, -and one advanced usage, i.e. dealing with transcription errors with a token. - -Though there’s some overlap in visualization -diagrams, the scope here is different from the `“Forced Alignment with -Wav2Vec2” `__ -tutorial, which focuses on a step-by-step demonstration of the forced -alignment generation algorithm (without using an API) described in the -`paper `__ with a Wav2Vec2 model. 
- +This tutorial shows how to align transcripts to speech using +:py:func:`torchaudio.functional.forced_align` +which was developed along the work of +`Scaling Speech Technology to 1,000+ Languages `__. + +The forced alignment is a process to align transcript with speech. +We cover the basics of forced alignment in `Forced Alignment with +Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified +step-by-step Python implementations. + +:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA +implementations which are more performant than the vanilla Python +implementation above, and are more accurate. +It can also handle missing transcript with special token. + +For examples of aligning multiple languages, please refer to +`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__. """ import torch import torchaudio + print(torch.__version__) print(torchaudio.__version__) +###################################################################### +# -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print(device) - +from dataclasses import dataclass +from typing import List -try: - from torchaudio.functional import forced_align -except ModuleNotFoundError: - print( - "Failed to import the forced alignment API. " - "Please install torchaudio nightly builds. " - "Please refer to https://pytorch.org/get-started/locally " - "for instructions to install a nightly build." - ) - raise +import IPython +import matplotlib.pyplot as plt ###################################################################### -# Basic usages -# ------------ -# -# In this section, we cover the following content: -# -# 1. Generate frame-wise class probabilites from audio waveform from a CTC -# acoustic model. -# 2. Compute frame-level alignments using TorchAudio’s forced alignment -# API. -# 3. Obtain token-level alignments from frame-level alignments. -# 4. Obtain word-level alignments from token-level alignments. # +from torchaudio.functional import forced_align + +torch.random.manual_seed(0) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(device) ###################################################################### # Preparation -# ~~~~~~~~~~~ +# ----------- # -# First we import the necessary packages, and fetch data that we work on. +# First we prepare the speech data and the transcript we area going +# to use. # -# %matplotlib inline -from dataclasses import dataclass - -import IPython -import matplotlib -import matplotlib.pyplot as plt - -matplotlib.rcParams["figure.figsize"] = [16.0, 4.8] - -torch.random.manual_seed(0) - SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") -sample_rate = 16000 - +TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT" ###################################################################### -# Generate frame-wise class posteriors from a CTC acoustic model -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Generating emissions and tokens +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# :py:func:`~torchaudio.functional.forced_align` takes emission and +# token sequences and outputs timestaps of the tokens and their scores. # -# The first step is to generate the class probabilities (i.e. posteriors) -# of each audio frame using a CTC model. -# Here we use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`. 
+# Emission reperesents the frame-wise probability distribution over +# tokens, and it can be obtained by passing waveform to an acoustic +# model. +# Tokens are numerical expression of transcripts. It can be obtained by +# simply mapping each character to the index of token list. +# The emission and the token sequences must be using the same set of tokens. +# +# We can use pre-trained Wav2Vec2 model to obtain emission from speech, +# and map transcript to tokens. +# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`, +# which bandles pre-trained model weights with associated labels. # bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H model = bundle.get_model().to(device) -labels = bundle.get_labels() with torch.inference_mode(): waveform, _ = torchaudio.load(SPEECH_FILE) - emissions, _ = model(waveform.to(device)) - emissions = torch.log_softmax(emissions, dim=-1) + emission, _ = model(waveform.to(device)) + emission = torch.log_softmax(emission, dim=-1) + + +###################################################################### +# + -emission = emissions.cpu().detach() -dictionary = {c: i for i, c in enumerate(labels)} +def plot_emission(emission): + plt.imshow(emission.cpu().T) + plt.title("Frame-wise class probabilities") + plt.xlabel("Time") + plt.ylabel("Labels") + plt.tight_layout() -print(dictionary) +plot_emission(emission[0]) ###################################################################### -# Visualization -# ^^^^^^^^^^^^^ -# +# We create a dictionary, which maps each label into token. + +labels = bundle.get_labels() +DICTIONARY = {c: i for i, c in enumerate(labels)} + +for k, v in DICTIONARY.items(): + print(f"{k}: {v}") + +###################################################################### +# converting transcript to tokens is as simple as -plt.imshow(emission[0].T) -plt.colorbar() -plt.title("Frame-wise class probabilities") -plt.xlabel("Time") -plt.ylabel("Labels") -plt.show() +tokens = [DICTIONARY[c] for c in TRANSCRIPT] +print(" ".join(str(t) for t in tokens)) ###################################################################### # Computing frame-level alignments -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# -------------------------------- # -# Then we call TorchAudio’s forced alignment API to compute the -# frame-level alignment between each audio frame and each token in the -# transcript. We first explain the inputs and outputs of the API -# ``functional.forced_align``. Note that this API works on both CPU and -# GPU. In the current tutorial we demonstrate it on CPU. +# Now we call TorchAudio’s forced alignment API to compute the +# frame-level alignment. For the detail of function signature, please +# refer to :py:func:`~torchaudio.functional.forced_align`. # -# **Inputs**: # -# ``emission``: a 2D tensor of size :math:`T \times N`, where :math:`T` is -# the number of frames (after sub-sampling by the acoustic model, if any), -# and :math:`N` is the vocabulary size. -# -# ``targets``: a 1D tensor vector of size :math:`M`, where :math:`M` is -# the length of the transcript, and each element is a token ID looked up -# from the vocabulary. For example, the ``targets`` tensor repsenting the -# transcript “i had…” is :math:`[5, 18, 4, 16, ...]`. -# -# ``input lengths``: :math:`T`. -# -# ``target lengths``: :math:`M`. -# -# **Outputs**: -# -# ``frame_alignment``: a 1D tensor of size :math:`T` storing the aligned -# token index (looked up from the vocabulary) of each frame, e.g. 
for the -# segment corresponding to “i had” in the given example , the -# frame_alignment is -# :math:`[...0, 0, 5, 0, 0, 18, 18, 4, 0, 0, 0, 16,...]`, where :math:`0` -# represents the blank symbol. + + +def align(emission, tokens): + alignments, scores = forced_align( + emission, + targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device), + input_lengths=torch.tensor([emission.size(1)], device=emission.device), + target_lengths=torch.tensor([len(tokens)], device=emission.device), + blank=0, + ) + + scores = scores.exp() # convert back to probability + alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity + return alignments.tolist(), scores.tolist() + + +frame_alignment, frame_scores = align(emission, tokens) + +###################################################################### +# Now let's look at the output. +# Notice that the alignment is expressed in the frame cordinate of +# emission, which is different from the original waveform. + +for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)): + print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}") + +###################################################################### # -# ``frame_scores``: a 1D tensor of size :math:`T` storing the confidence -# score (0 to 1) for each each frame. For each frame, the score should be -# close to one if the alignment quality is good. +# The ``Frame`` instance represents the most likely token at each frame +# with its confidence. +# +# When interpreting it, one must remember that the meaning of blank token +# and repeated token are context dependent. +# +# .. note:: +# +# When same token occured after blank tokens, it is not treated as +# a repeat, but as a new occurrence. +# +# .. code-block:: +# +# a a a b -> a b +# a - - b -> a b +# a a - b -> a b +# a - a b -> a a b +# ^^^ ^^^ +# +# .. code-block:: +# +# 29: 0 [-], 1.00 +# 30: 7 [I], 1.00 # Start of "I" +# 31: 0 [-], 0.98 # repeat (blank token) +# 32: 0 [-], 1.00 # repeat (blank token) +# 33: 1 [|], 0.85 # Start of "|" (word boundary) +# 34: 1 [|], 1.00 # repeat (same token) +# 35: 0 [-], 0.61 # repeat (blank token) +# 36: 8 [H], 1.00 # Start of "H" +# 37: 0 [-], 1.00 # repeat (blank token) +# 38: 4 [A], 1.00 # Start of "A" +# 39: 0 [-], 0.99 # repeat (blank token) +# 40: 11 [D], 0.92 # Start of "D" +# 41: 0 [-], 0.93 # repeat (blank token) +# 42: 1 [|], 0.98 # Start of "|" +# 43: 1 [|], 1.00 # repeat (same token) +# 44: 3 [T], 1.00 # Start of "T" +# 45: 3 [T], 0.90 # repeat (same token) +# 46: 8 [H], 1.00 # Start of "H" +# 47: 0 [-], 1.00 # repeat (blank token) + +###################################################################### +# Resolve blank and repeated tokens +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# From the outputs ``frame_alignment`` and ``frame_scores``, we generate a +# Next step is to resolve the repetation. So that what alignment represents +# do not depend on previous alignments. +# From the outputs ``alignment`` and ``scores``, we generate a # list called ``frames`` storing information of all frames aligned to -# non-blank tokens. Each element contains 1) token_index: the aligned -# token’s index in the transcript 2) time_index: the current frame’s index -# in the input audio (or more precisely, the row dimension of the emission -# matrix) 3) the confidence scores of the current frame. 
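######################################################################
# Illustrative sketch (not part of this patch; the alignment indices
# below are made up): the collapse rules quoted above in code form.
# Blanks (index 0) are dropped, a repeated index counts once, and the
# same index separated by a blank starts a new occurrence.
#

LABELS = ["-", "|", "A", "B"]
alignment_sketch = [2, 2, 0, 0, 2, 3, 3, 0, 3]  # "A A - - A B B - B"

collapsed = []
prev = 0
for idx in alignment_sketch:
    if idx != 0 and idx != prev:
        collapsed.append(LABELS[idx])
    prev = idx

print("".join(collapsed))  # "AABB": two occurrences of "A" and two of "B" survive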
-# -# For the given example, the first few elements of the list ``frames`` -# corresponding to “i had” looks as the following: -# -# ``Frame(token_index=0, time_index=32, score=0.9994410872459412)`` +# non-blank tokens. # -# ``Frame(token_index=1, time_index=35, score=0.9980823993682861)`` +# Each element contains the following # -# ``Frame(token_index=1, time_index=36, score=0.9295750260353088)`` -# -# ``Frame(token_index=2, time_index=37, score=0.9997448325157166)`` -# -# ``Frame(token_index=3, time_index=41, score=0.9991760849952698)`` -# -# ``...`` -# -# The interpretation is: -# -# The token with index :math:`0` in the transcript, i.e. “i”, is aligned -# to the :math:`32`\ th audio frame, with confidence :math:`0.9994`. The -# token with index :math:`1` in the transcript, i.e. “h”, is aligned to -# the :math:`35`\ th and :math:`36`\ th audio frames, with confidence -# :math:`0.9981` and :math:`0.9296` respectively. The token with index -# :math:`2` in the transcript, i.e. “a”, is aligned to the :math:`35`\ th -# and :math:`36`\ th audio frames, with confidence :math:`0.9997`. The -# token with index :math:`3` in the transcript, i.e. “d”, is aligned to -# the :math:`41`\ th audio frame, with confidence :math:`0.9992`. -# -# From such information stored in the ``frames`` list, we’ll compute -# token-level and word-level alignments easily. +# - ``token_index``: the aligned token’s index **in the transcript** +# - ``time_index``: the current frame’s index in emission +# - ``score``: scores of the current frame. # +# ``token_index`` is the index of each token in the transcript, +# i.e. the current frame aligns to the N-th character from the transcript. @dataclass class Frame: - # This is the index of each token in the transcript, - # i.e. the current frame aligns to the N-th character from the transcript. token_index: int time_index: int score: float -def compute_alignments(transcript, dictionary, emission): - frames = [] - tokens = [dictionary[c] for c in transcript.replace(" ", "")] - - targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0) - input_lengths = torch.tensor([emission.shape[1]]) - target_lengths = torch.tensor([targets.shape[1]]) +###################################################################### +# - # This is the key step, where we call the forced alignment API functional.forced_align to compute alignments. 
- frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0) - assert frame_alignment.shape[1] == input_lengths[0].item() - assert targets.shape[1] == target_lengths[0].item() +def obtain_token_level_alignments(alignments, scores) -> List[Frame]: + assert len(alignments) == len(scores) token_index = -1 prev_hyp = 0 - for i in range(frame_alignment.shape[1]): - if frame_alignment[0][i].item() == 0: + frames = [] + for i, (ali, score) in enumerate(zip(alignments, scores)): + if ali == 0: prev_hyp = 0 continue - if frame_alignment[0][i].item() != prev_hyp: + if ali != prev_hyp: token_index += 1 - frames.append(Frame(token_index, i, frame_scores[0][i].exp().item())) - prev_hyp = frame_alignment[0][i].item() - return frames, frame_alignment, frame_scores - - -transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT" -frames, frame_alignment, frame_scores = compute_alignments(transcript, dictionary, emission) + frames.append(Frame(token_index, i, score)) + prev_hyp = ali + return frames ###################################################################### -# Obtain token-level alignments and confidence scores -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +frames = obtain_token_level_alignments(frame_alignment, frame_scores) + +print("Time\tLabel\tScore") +for f in frames: + print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}") + ###################################################################### +# Obtain token-level alignments and confidence scores +# --------------------------------------------------- +# # The frame-level alignments contains repetations for the same labels. # Another format “token-level alignment”, which specifies the aligned # frame ranges for each transcript token, contains the same information, # while being more convenient to apply to some downstream tasks -# (e.g. computing word-level alignments). +# (e.g. computing word-level alignments). # # Now we demonstrate how to obtain token-level alignments and confidence # scores by simply merging frame-level alignments and averaging # frame-level confidence scores. # +###################################################################### +# The following class represents the label, its score and the time span +# of its occurance. 
+# + -# Merge the labels @dataclass class Segment: label: str @@ -261,13 +284,16 @@ class Segment: score: float def __repr__(self): - return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" + return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})" - @property - def length(self): + def __len__(self): return self.end - self.start +###################################################################### +# + + def merge_repeats(frames, transcript): transcript_nospace = transcript.replace(" ", "") i1, i2 = 0, 0 @@ -288,29 +314,31 @@ def merge_repeats(frames, transcript): return segments -segments = merge_repeats(frames, transcript) +###################################################################### +# +segments = merge_repeats(frames, TRANSCRIPT) for seg in segments: print(seg) ###################################################################### # Visualization -# ^^^^^^^^^^^^^ +# ~~~~~~~~~~~~~ # def plot_label_prob(segments, transcript): - fig, ax2 = plt.subplots(figsize=(16, 4)) + fig, ax = plt.subplots() - ax2.set_title("frame-level and token-level confidence scores") + ax.set_title("frame-level and token-level confidence scores") xs, hs, ws = [], [], [] for seg in segments: if seg.label != "|": xs.append((seg.end + seg.start) / 2 + 0.4) hs.append(seg.score) ws.append(seg.end - seg.start) - ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") - ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") + ax.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") + ax.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") xs, hs = [], [] for p in frames: @@ -319,27 +347,28 @@ def plot_label_prob(segments, transcript): xs.append(p.time_index + 1) hs.append(p.score) - ax2.bar(xs, hs, width=0.5, alpha=0.5) - ax2.axhline(0, color="black") - ax2.set_ylim(-0.1, 1.1) + ax.bar(xs, hs, width=0.5, alpha=0.5) + ax.set_ylim(-0.1, 1.1) + ax.grid(True, axis="y") + fig.tight_layout() -plot_label_prob(segments, transcript) -plt.tight_layout() -plt.show() +plot_label_prob(segments, TRANSCRIPT) ###################################################################### # From the visualized scores, we can see that, for tokens spanning over -# more multiple frames, e.g. “T” in “THAT, the token-level confidence +# more multiple frames, e.g. “T” in “THAT, the token-level confidence # score is the average of frame-level confidence scores. To make this # clearer, we don’t plot confidence scores for blank frames, which was # plotted in the”Label probability with and without repeatation” figure in -# the previous tutorial `“Forced Alignment with -# Wav2Vec2” `__. +# the previous tutorial +# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__. 
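######################################################################
# A small worked example of the merging and averaging described above,
# using hypothetical ``(token_index, time_index, score)`` frames rather
# than the tutorial's actual output: consecutive frames that point at
# the same transcript index collapse into one segment whose score is
# the mean of its frame-level scores.
#

transcript_sketch = "IHAD"
frames_sketch = [(0, 32, 0.99), (1, 35, 0.99), (1, 36, 0.93), (2, 37, 0.99), (3, 41, 0.98)]

i = 0
while i < len(frames_sketch):
    j = i
    while j < len(frames_sketch) and frames_sketch[j][0] == frames_sketch[i][0]:
        j += 1
    group = frames_sketch[i:j]
    score = sum(f[2] for f in group) / len(group)
    label = transcript_sketch[group[0][0]]
    print(f"{label} ({score:4.2f}): [{group[0][1]:4d}, {group[-1][1] + 1:4d})")
    i = j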
# + +###################################################################### # Obtain word-level alignments and confidence scores -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# -------------------------------------------------- # @@ -367,7 +396,7 @@ def merge_words(transcript, segments, separator=" "): s = 0 segs = segments[i1 + s : i2 + s] word = "".join([seg.label for seg in segs]) - score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) + score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs) words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score)) i1 = i2 else: @@ -376,59 +405,43 @@ def merge_words(transcript, segments, separator=" "): return words -word_segments = merge_words(transcript, segments, "|") +word_segments = merge_words(TRANSCRIPT, segments, "|") ###################################################################### # Visualization -# ^^^^^^^^^^^^^ +# ~~~~~~~~~~~~~ # -def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10): - fig, ax2 = plt.subplots(figsize=(64, 12)) - plt.rcParams.update({"font.size": 30}) +def plot_alignments(waveform, emission, segments, word_segments, sample_rate=bundle.sample_rate): + fig, ax = plt.subplots() - # The original waveform - ratio = waveform.size(1) / input_lengths - ax2.plot(waveform) - ax2.set_ylim(-1.0 * scale, 1.0 * scale) - ax2.set_xlim(0, waveform.size(-1)) + ax.specgram(waveform[0], Fs=sample_rate) + # The original waveform + ratio = waveform.size(1) / sample_rate / emission.size(1) for word in word_segments: - x0 = ratio * word.start - x1 = ratio * word.end - ax2.axvspan(x0, x1, alpha=0.1, color="red") - ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale)) + t0, t1 = ratio * word.start, ratio * word.end + ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white") + ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False) for seg in segments: if seg.label != "|": - ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale)) + ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False) - xticks = ax2.get_xticks() - plt.xticks(xticks, xticks / sample_rate, fontsize=50) - ax2.set_xlabel("time [second]", fontsize=40) - ax2.set_yticks([]) + ax.set_xlabel("time [second]") + fig.tight_layout() -plot_alignments( - segments, - word_segments, - waveform, - emission.shape[1], - 1, -) -plt.show() +plot_alignments(waveform, emission, segments, word_segments) ###################################################################### -# A trick to embed the resulting audio to the generated file. -# `IPython.display.Audio` has to be the last call in a cell, -# and there should be only one call par cell. 
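######################################################################
# Illustrative sketch (made-up sizes, not the tutorial's real values):
# both the spectrogram plot above and the playback helper below rely on
# the same conversion from emission-frame indices to the waveform's
# time axis, via the ratio of waveform length to emission length.
#

sample_rate_sketch = 16000
num_samples_sketch = 54400  # length of a hypothetical waveform
num_frames_sketch = 169     # length of the corresponding emission

seconds_per_frame = num_samples_sketch / sample_rate_sketch / num_frames_sketch
samples_per_frame = num_samples_sketch / num_frames_sketch

word_start, word_end = 30, 34  # frame span of a hypothetical word
print(f"{word_start * seconds_per_frame:.2f}s - {word_end * seconds_per_frame:.2f}s")
print(f"samples [{int(word_start * samples_per_frame)}, {int(word_end * samples_per_frame)})")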
-def display_segment(i, waveform, word_segments, frame_alignment): - ratio = waveform.size(1) / frame_alignment.size(1) +def display_segment(i, waveform, word_segments, frame_alignment, sample_rate=bundle.sample_rate): + ratio = waveform.size(1) / len(frame_alignment) word = word_segments[i] x0 = int(ratio * word.start) x1 = int(ratio * word.end) @@ -437,8 +450,10 @@ def display_segment(i, waveform, word_segments, frame_alignment): return IPython.display.Audio(segment.numpy(), rate=sample_rate) +###################################################################### + # Generate the audio for each segment -print(transcript) +print(TRANSCRIPT) IPython.display.Audio(SPEECH_FILE) ###################################################################### @@ -488,62 +503,71 @@ def display_segment(i, waveform, word_segments, frame_alignment): ###################################################################### -# Advanced usage: Dealing with missing transcripts using the token -# --------------------------------------------------------------------------- +# Advanced: Handling transcripts with ```` token +# ---------------------------------------------------- # # Now let’s look at when the transcript is partially missing, how can we -# improve alignment quality using the token, which is capable of modeling +# improve alignment quality using the ```` token, which is capable of modeling # any token. # # Here we use the same English example as used above. But we remove the -# beginning text “i had that curiosity beside me at” from the transcript. +# beginning text ``“i had that curiosity beside me at”`` from the transcript. # Aligning audio with such transcript results in wrong alignments of the # existing word “this”. However, this issue can be mitigated by using the -# token to model the missing text. +# ```` token to model the missing text. # -# Reload the emission tensor in order to add the extra dimension corresponding to the token. -with torch.inference_mode(): - waveform, _ = torchaudio.load(SPEECH_FILE) - emissions, _ = model(waveform.to(device)) - emissions = torch.log_softmax(emissions, dim=-1) +###################################################################### +# First, we extend the dictionary to include the ```` token. - # Append the extra dimension corresponding to the token - extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1) - emissions = torch.cat((emissions.cpu(), extra_dim), 2) - emission = emissions.detach() +DICTIONARY["*"] = len(DICTIONARY) -# Extend the dictionary to include the token. -dictionary["*"] = 29 +###################################################################### +# Next, we extend the emission tensor with the extra dimension +# corresponding to the ```` token. +# -assert len(dictionary) == emission.shape[2] +extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device) +emission = torch.cat((emission, extra_dim), 2) + +assert len(DICTIONARY) == emission.shape[2] + + +###################################################################### +# The following function combines all the processes, and compute +# word segments from emission in one-go. 
def compute_and_plot_alignments(transcript, dictionary, emission, waveform): - frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission) + tokens = [dictionary[c] for c in transcript] + alignment, scores = align(emission, tokens) + frames = obtain_token_level_alignments(alignment, scores) segments = merge_repeats(frames, transcript) word_segments = merge_words(transcript, segments, "|") - plot_alignments(segments, word_segments, waveform, emission.shape[1], 1) - plt.show() - return word_segments, frame_alignment - + plot_alignments(waveform, emission, segments, word_segments) + plt.xlim([0, None]) -# original: -word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform) ###################################################################### +# **Original** -# Demonstrate the effect of token for dealing with deletion errors -# ("i had that curiosity beside me at" missing from the transcript): -transcript = "THIS|MOMENT" -word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform) +compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform) ###################################################################### +# **With token** +# +# Now we replace the first part of the transcript with the ```` token. -# Replacing the missing transcript with the token: -transcript = "*|THIS|MOMENT" -word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform) +compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform) + +###################################################################### +# **Without token** +# +# As a comparison, the following aligns the partial transcript +# without using ```` token. +# It demonstrates the effect of ```` token for dealing with deletion errors. +compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform) ###################################################################### # Conclusion @@ -551,7 +575,7 @@ def compute_and_plot_alignments(transcript, dictionary, emission, waveform): # # In this tutorial, we looked at how to use torchaudio’s forced alignment # API to align and segment speech files, and demonstrated one advanced usage: -# How introducing a token could improve alignment accuracy when +# How introducing a ```` token could improve alignment accuracy when # transcription errors exist. # diff --git a/examples/tutorials/device_avsr.py b/examples/tutorials/device_avsr.py index f7a12ec943..0bb7a5792d 100644 --- a/examples/tutorials/device_avsr.py +++ b/examples/tutorials/device_avsr.py @@ -69,7 +69,7 @@ # ------------------- # # Firstly, we define the function to collect videos from microphone and -# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader` +# camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader` # class for the purpose of data collection, which supports capturing # audio/video from microphone and camera. 
For the detailed usage of this # class, please refer to the diff --git a/examples/tutorials/filter_design_tutorial.py b/examples/tutorials/filter_design_tutorial.py index 944a7df3f8..1637eb0cc2 100644 --- a/examples/tutorials/filter_design_tutorial.py +++ b/examples/tutorials/filter_design_tutorial.py @@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff): num_filts, window_size = irs.shape half = window_size // 2 - fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5)) + fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8)) t = torch.linspace(-half, half - 1, window_size) for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors): ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}") @@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff): "(Frequencies are relative to Nyquist frequency)" ) axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)]) - plt.tight_layout() + fig.tight_layout() ###################################################################### @@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False): num_filts, num_fft = frs.shape num_ticks = num_filts + 1 if band else num_filts - fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5)) + fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8)) for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors): ax.grid(True) ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}") @@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False): "Frequency response of sinc low-pass filter for different cut-off frequencies\n" "(Frequencies are relative to Nyquist frequency)" ) - plt.tight_layout() + fig.tight_layout() ###################################################################### @@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048): axes[i].grid(True) axes[1].set(title="Frequency Response") axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency") - axes[2].legend(loc="lower right") + axes[2].legend(loc="center right") fig.tight_layout() diff --git a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py index 01333d7175..6f78b0e5d3 100644 --- a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py +++ b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py @@ -6,15 +6,14 @@ This tutorial shows how to compute forced alignments for speech data from multiple non-English languages using ``torchaudio``'s CTC forced alignment -API described in `“CTC forced alignment -tutorial” `__ -and the multilingual Wav2vec2 model proposed in the paper `“Scaling +API described in `CTC forced alignment tutorial <./forced_alignment_tutorial.html>`__ +and the multilingual Wav2vec2 model proposed in the paper `Scaling Speech Technology to 1,000+ -Languages” `__. +Languages `__. + The model was trained on 23K of audio data from 1100+ languages using -the `“uroman vocabulary” `__ +the `uroman vocabulary `__ as targets. - """ import torch @@ -23,53 +22,46 @@ print(torch.__version__) print(torchaudio.__version__) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device) - -try: - from torchaudio.functional import forced_align -except ModuleNotFoundError: - print( - "Failed to import the forced alignment API. " - "Please install torchaudio nightly builds. 
" - "Please refer to https://pytorch.org/get-started/locally " - "for instructions to install a nightly build." - ) - raise - ###################################################################### # Preparation # ----------- # -# Here we import necessary packages, and define utility functions for -# computing the frame-level alignments (using the API -# ``functional.forced_align``), token-level and word-level alignments, and -# also alignment visualization utilities. -# -# %matplotlib inline from dataclasses import dataclass import IPython - import matplotlib.pyplot as plt +from torchaudio.functional import forced_align -torch.random.manual_seed(0) -sample_rate = 16000 +###################################################################### +# + +SAMPLE_RATE = 16000 + + +###################################################################### +# +# Here we define utility functions for computing the frame-level +# alignments (using the API :py:func:`torchaudio.functional.forced_align`), +# token-level and word-level alignments. +# For the detail of these functions please refer to +# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__. +# @dataclass class Frame: - # This is the index of each token in the transcript, - # i.e. the current frame aligns to the N-th character from the transcript. token_index: int time_index: int score: float +###################################################################### +# @dataclass class Segment: label: str @@ -78,39 +70,42 @@ class Segment: score: float def __repr__(self): - return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" + return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})" - @property - def length(self): + def __len__(self): return self.end - self.start -# compute frame-level and word-level alignments using torchaudio's forced alignment API +###################################################################### +# + + def compute_alignments(transcript, dictionary, emission): - frames = [] tokens = [dictionary[c] for c in transcript.replace(" ", "")] - targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0) - input_lengths = torch.tensor([emission.shape[1]]) - target_lengths = torch.tensor([targets.shape[1]]) + targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device) + input_lengths = torch.tensor([emission.shape[1]], device=emission.device) + target_lengths = torch.tensor([targets.shape[1]], device=emission.device) + + alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0) - # This is the key step, where we call the forced alignment API functional.forced_align to compute frame alignments. 
- frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0) + scores = scores.exp() # convert back to probability + alignment, scores = alignment[0].tolist(), scores[0].tolist() - assert frame_alignment.shape[1] == input_lengths[0].item() - assert targets.shape[1] == target_lengths[0].item() + assert len(alignment) == len(scores) == emission.size(1) token_index = -1 prev_hyp = 0 - for i in range(frame_alignment.shape[1]): - if frame_alignment[0][i].item() == 0: + frames = [] + for i, (ali, score) in enumerate(zip(alignment, scores)): + if ali == 0: prev_hyp = 0 continue - if frame_alignment[0][i].item() != prev_hyp: + if ali != prev_hyp: token_index += 1 - frames.append(Frame(token_index, i, frame_scores[0][i].exp().item())) - prev_hyp = frame_alignment[0][i].item() + frames.append(Frame(token_index, i, score)) + prev_hyp = ali # compute frame alignments from token alignments transcript_nospace = transcript.replace(" ", "") @@ -140,52 +135,59 @@ def compute_alignments(transcript, dictionary, emission): if i1 != i2: if i3 == len(transcript) - 1: i2 += 1 - s = 0 - segs = segments[i1 + s : i2 + s] - word = "".join([seg.label for seg in segs]) - score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) - words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score)) + segs = segments[i1:i2] + word = "".join([s.label for s in segs]) + score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs) + words.append(Segment(word, segs[0].start, segs[-1].end + 1, score)) i1 = i2 else: i2 += 1 i3 += 1 + return segments, words - num_frames = frame_alignment.shape[1] - return segments, words, num_frames +###################################################################### +# + + +def plot_emission(emission): + fig, ax = plt.subplots() + ax.imshow(emission.T, aspect="auto") + ax.set_title("Emission") + fig.tight_layout() -# utility function for plotting word alignments -def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10): - fig, ax2 = plt.subplots(figsize=(64, 12)) - plt.rcParams.update({"font.size": 30}) - # The original waveform - ratio = waveform.size(1) / input_lengths - ax2.plot(waveform) - ax2.set_ylim(-1.0 * scale, 1.0 * scale) - ax2.set_xlim(0, waveform.size(-1)) +###################################################################### +# + +# utility function for plotting word alignments +def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE): + fig, ax = plt.subplots() + ax.specgram(waveform[0], Fs=sample_rate) + xlim = ax.get_xlim() + ratio = waveform.size(1) / sample_rate / emission.size(1) for word in word_segments: - x0 = ratio * word.start - x1 = ratio * word.end - ax2.axvspan(x0, x1, alpha=0.1, color="red") - ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale)) + t0, t1 = word.start * ratio, word.end * ratio + ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white") + ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False) for seg in segments: if seg.label != "|": - ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale)) + ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False) - xticks = ax2.get_xticks() - plt.xticks(xticks, xticks / sample_rate, fontsize=50) - ax2.set_xlabel("time [second]", fontsize=40) - ax2.set_yticks([]) + ax.set_xlabel("time [second]") + ax.set_xlim(xlim) + fig.tight_layout() + return IPython.display.Audio(waveform, rate=sample_rate) + + 
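######################################################################
# Illustrative sketch (hypothetical segments, not the tutorial's
# output): the word grouping in ``compute_alignments`` above weights
# each character segment's score by the number of frames it spans, and
# starts a new word at each separator segment (a "|" boundary in this
# sketch).
#

segments_sketch = [
    ("I", 30, 31, 1.00),
    ("|", 33, 35, 0.95),
    ("H", 36, 37, 1.00),
    ("A", 38, 39, 0.99),
    ("D", 40, 42, 0.92),
]

word, spans, scores = "", [], []
for label, start, end, score in segments_sketch + [("|", 0, 0, 0.0)]:
    if label == "|":
        if word:
            total = sum(e - s for s, e in spans)
            weighted = sum(sc * (e - s) for (s, e), sc in zip(spans, scores)) / total
            print(f"{word} ({weighted:4.2f}): [{spans[0][0]}, {spans[-1][1]})")
        word, spans, scores = "", [], []
    else:
        word += label
        spans.append((start, end))
        scores.append(score)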
+###################################################################### +# # utility function for playing audio segments. -# A trick to embed the resulting audio to the generated file. -# `IPython.display.Audio` has to be the last call in a cell, -# and there should be only one call par cell. -def display_segment(i, waveform, word_segments, num_frames): +def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE): ratio = waveform.size(1) / num_frames word = word_segments[i] x0 = int(ratio * word.start) @@ -241,26 +243,21 @@ def display_segment(i, waveform, word_segments, num_frames): ) ) model.eval() +model.to(device) def get_emission(waveform): - # NOTE: this step is essential - waveform = torch.nn.functional.layer_norm(waveform, waveform.shape) - - emissions, _ = model(waveform) - emissions = torch.log_softmax(emissions, dim=-1) - emission = emissions.cpu().detach() - - # Append the extra dimension corresponding to the token - extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1) - emissions = torch.cat((emissions.cpu(), extra_dim), 2) - emission = emissions.detach() - return emission, waveform + with torch.inference_mode(): + # NOTE: this step is essential + waveform = torch.nn.functional.layer_norm(waveform, waveform.shape) + emission, _ = model(waveform) + return torch.log_softmax(emission, dim=-1) # Construct the dictionary -# '@' represents the OOV token, '*' represents the token. +# '@' represents the OOV token # and are fairseq's legacy tokens, which're not used. +# token is omitted as we do not use it in this tutorial dictionary = { "": 0, "": 1, @@ -293,7 +290,6 @@ def get_emission(waveform): "'": 28, "q": 29, "x": 30, - "*": 31, } @@ -304,11 +300,8 @@ def get_emission(waveform): # romanizer and using it to obtain romanized transcripts, and PyThon # commands required for further normalizing the romanized transcript. # - -# %% # .. 
code-block:: bash # -# %%bash # Save the raw transcript to a file # echo 'raw text' > text.txt # git clone https://github.com/isi-nlp/uroman @@ -334,141 +327,77 @@ def get_emission(waveform): ###################################################################### -# German example: -# ~~~~~~~~~~~~~~~~ +# German +# ~~~~~~ -text_raw = ( - "aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid" -) -text_normalized = ( - "aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid" -) speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False) -waveform, _ = torchaudio.load(speech_file) - -emission, waveform = get_emission(waveform) -assert len(dictionary) == emission.shape[2] -transcript = text_normalized - -segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) -plot_alignments(segments, word_segments, waveform, emission.shape[1]) +text_raw = "aber seit ich bei ihnen das brot hole" +text_normalized = "aber seit ich bei ihnen das brot hole" print("Raw Transcript: ", text_raw) print("Normalized Transcript: ", text_normalized) -IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # -display_segment(0, waveform, word_segments, num_frames) +waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE)) +emission = get_emission(waveform.to(device)) +num_frames = emission.size(1) +plot_emission(emission[0].cpu()) ###################################################################### # -display_segment(1, waveform, word_segments, num_frames) +segments, word_segments = compute_alignments(text_normalized, dictionary, emission) -###################################################################### -# - -display_segment(2, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(3, waveform, word_segments, num_frames) - - -###################################################################### -# - -display_segment(4, waveform, word_segments, num_frames) +plot_alignments(waveform, emission, segments, word_segments) ###################################################################### # -display_segment(5, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(6, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(7, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(8, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(9, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(10, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(11, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(12, waveform, word_segments, num_frames) +display_segment(0, waveform, word_segments, num_frames) ###################################################################### # -display_segment(13, 
waveform, word_segments, num_frames) +display_segment(1, waveform, word_segments, num_frames) ###################################################################### # -display_segment(14, waveform, word_segments, num_frames) +display_segment(2, waveform, word_segments, num_frames) ###################################################################### # -display_segment(15, waveform, word_segments, num_frames) - -###################################################################### -# +display_segment(3, waveform, word_segments, num_frames) -display_segment(16, waveform, word_segments, num_frames) ###################################################################### # -display_segment(17, waveform, word_segments, num_frames) +display_segment(4, waveform, word_segments, num_frames) ###################################################################### # -display_segment(18, waveform, word_segments, num_frames) +display_segment(5, waveform, word_segments, num_frames) ###################################################################### # -display_segment(19, waveform, word_segments, num_frames) +display_segment(6, waveform, word_segments, num_frames) ###################################################################### # -display_segment(20, waveform, word_segments, num_frames) - +display_segment(7, waveform, word_segments, num_frames) ###################################################################### -# Chinese example: -# ~~~~~~~~~~~~~~~~ +# Chinese +# ~~~~~~~ # # Chinese is a character-based language, and there is not explicit word-level # tokenization (separated by spaces) in its raw written form. In order to @@ -478,98 +407,36 @@ def get_emission(waveform): # However this is not needed if you only want character-level alignments. # -text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面" -text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian" speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False) -waveform, _ = torchaudio.load(speech_file) -waveform = waveform[0:1] - -emission, waveform = get_emission(waveform) - -transcript = text_normalized -segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) -plot_alignments(segments, word_segments, waveform, emission.shape[1]) +text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面" +text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian" print("Raw Transcript: ", text_raw) print("Normalized Transcript: ", text_normalized) -IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # -display_segment(0, waveform, word_segments, num_frames) - - -###################################################################### -# - -display_segment(1, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(2, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(3, waveform, word_segments, num_frames) - - -###################################################################### -# - -display_segment(4, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(5, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(6, waveform, word_segments, num_frames) - 
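Regarding the romanization step described a few cells above: after running uroman, the romanized transcript still needs light clean-up before it can be matched against the model's lowercase Latin-letter dictionary. The helper below only illustrates that kind of normalization and is not the tutorial's exact code: lowercase the text, keep letters, apostrophes and spaces, and collapse repeated whitespace.

import re


def normalize_romanized(text):
    text = text.lower()
    text = re.sub(r"[^a-z' ]", " ", text)  # drop digits and punctuation
    return re.sub(r" +", " ", text).strip()  # collapse runs of spaces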
-###################################################################### -# +waveform, _ = torchaudio.load(speech_file) +waveform = waveform[0:1] -display_segment(7, waveform, word_segments, num_frames) +emission = get_emission(waveform.to(device)) +num_frames = emission.size(1) +plot_emission(emission[0].cpu()) ###################################################################### # -display_segment(8, waveform, word_segments, num_frames) - - -###################################################################### -# Polish example: -# ~~~~~~~~~~~~~~~ - - -text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę dlaczego mi nie powiedziałeś szepnąłem ze łzami" -text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane dlaczego mi nie powiedziales szepnalem ze lzami" -speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False) -waveform, _ = torchaudio.load(speech_file) - -emission, waveform = get_emission(waveform) - -transcript = text_normalized +segments, word_segments = compute_alignments(text_normalized, dictionary, emission) -segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) -plot_alignments(segments, word_segments, waveform, emission.shape[1]) - -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) -IPython.display.Audio(waveform, rate=sample_rate) +plot_alignments(waveform, emission, segments, word_segments) ###################################################################### # display_segment(0, waveform, word_segments, num_frames) - ###################################################################### # @@ -585,7 +452,6 @@ def get_emission(waveform): display_segment(3, waveform, word_segments, num_frames) - ###################################################################### # @@ -611,68 +477,40 @@ def get_emission(waveform): display_segment(8, waveform, word_segments, num_frames) -###################################################################### -# - -display_segment(9, waveform, word_segments, num_frames) ###################################################################### -# +# Polish +# ~~~~~~ -display_segment(10, waveform, word_segments, num_frames) +speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False) -###################################################################### -# +text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę" +text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane" -display_segment(11, waveform, word_segments, num_frames) +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) ###################################################################### # -display_segment(12, waveform, word_segments, num_frames) - -###################################################################### -# +waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE)) -display_segment(13, waveform, word_segments, num_frames) +emission = get_emission(waveform.to(device)) +num_frames = emission.size(1) +plot_emission(emission[0].cpu()) ###################################################################### # -display_segment(14, waveform, word_segments, num_frames) +segments, word_segments = compute_alignments(text_normalized, dictionary, emission) - -###################################################################### -# Portuguese example: -# ~~~~~~~~~~~~~~~~~~~ - - -text_raw = ( - "mas na imensa extensão onde se 
esconde o inconsciente imortal só me responde um bramido um queixume e nada mais" -) -text_normalized = ( - "mas na imensa extensao onde se esconde o inconsciente imortal so me responde um bramido um queixume e nada mais" -) -speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False) -waveform, _ = torchaudio.load(speech_file) - -emission, waveform = get_emission(waveform) - -transcript = text_normalized - -segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) -plot_alignments(segments, word_segments, waveform, emission.shape[1]) - -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) -IPython.display.Audio(waveform, rate=sample_rate) +plot_alignments(waveform, emission, segments, word_segments) ###################################################################### # display_segment(0, waveform, word_segments, num_frames) - ###################################################################### # @@ -688,7 +526,6 @@ def get_emission(waveform): display_segment(3, waveform, word_segments, num_frames) - ###################################################################### # @@ -710,94 +547,38 @@ def get_emission(waveform): display_segment(7, waveform, word_segments, num_frames) ###################################################################### -# - -display_segment(8, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(9, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(10, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(11, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(12, waveform, word_segments, num_frames) +# Portuguese +# ~~~~~~~~~~ -###################################################################### -# - -display_segment(13, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(14, waveform, word_segments, num_frames) - -###################################################################### -# - -display_segment(15, waveform, word_segments, num_frames) +speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False) -###################################################################### -# +text_raw = "na imensa extensão onde se esconde o inconsciente imortal" +text_normalized = "na imensa extensao onde se esconde o inconsciente imortal" -display_segment(16, waveform, word_segments, num_frames) +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) ###################################################################### # -display_segment(17, waveform, word_segments, num_frames) +waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE)) -###################################################################### -# - -display_segment(18, waveform, word_segments, num_frames) +emission = get_emission(waveform.to(device)) +num_frames = emission.size(1) +plot_emission(emission[0].cpu()) ###################################################################### # -display_segment(19, waveform, word_segments, num_frames) +segments, word_segments = 
compute_alignments(text_normalized, dictionary, emission) - -###################################################################### -# Italian example: -# ~~~~~~~~~~~~~~~~ - -text_raw = "elle giacean per terra tutte quante fuor d'una ch'a seder si levò ratto ch'ella ci vide passarsi davante" -text_normalized = ( - "elle giacean per terra tutte quante fuor d'una ch'a seder si levo ratto ch'ella ci vide passarsi davante" -) -speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False) -waveform, _ = torchaudio.load(speech_file) - -emission, waveform = get_emission(waveform) - -transcript = text_normalized - -segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) -plot_alignments(segments, word_segments, waveform, emission.shape[1]) - -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) -IPython.display.Audio(waveform, rate=sample_rate) +plot_alignments(waveform, emission, segments, word_segments) ###################################################################### # display_segment(0, waveform, word_segments, num_frames) - ###################################################################### # @@ -813,7 +594,6 @@ def get_emission(waveform): display_segment(3, waveform, word_segments, num_frames) - ###################################################################### # @@ -840,50 +620,62 @@ def get_emission(waveform): display_segment(8, waveform, word_segments, num_frames) ###################################################################### -# +# Italian +# ~~~~~~~ + +speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False) + +text_raw = "elle giacean per terra tutte quante" +text_normalized = "elle giacean per terra tutte quante" -display_segment(9, waveform, word_segments, num_frames) +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) ###################################################################### # -display_segment(10, waveform, word_segments, num_frames) +waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE)) + +emission = get_emission(waveform.to(device)) +num_frames = emission.size(1) +plot_emission(emission[0].cpu()) ###################################################################### # -display_segment(11, waveform, word_segments, num_frames) +segments, word_segments = compute_alignments(text_normalized, dictionary, emission) + +plot_alignments(waveform, emission, segments, word_segments) ###################################################################### # -display_segment(12, waveform, word_segments, num_frames) +display_segment(0, waveform, word_segments, num_frames) ###################################################################### # -display_segment(13, waveform, word_segments, num_frames) +display_segment(1, waveform, word_segments, num_frames) ###################################################################### # -display_segment(14, waveform, word_segments, num_frames) +display_segment(2, waveform, word_segments, num_frames) ###################################################################### # -display_segment(15, waveform, word_segments, num_frames) +display_segment(3, waveform, word_segments, num_frames) ###################################################################### # -display_segment(16, waveform, word_segments, num_frames) +display_segment(4, waveform, word_segments, num_frames) 
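Each display_segment call in these examples maps a word's span in emission frames back to waveform samples before handing the slice to IPython.display.Audio. A stripped-down sketch of that index arithmetic (the helper name is illustrative; the logic mirrors display_segment defined earlier):

def word_to_samples(word, waveform, num_frames):
    # One emission frame covers waveform.size(1) / num_frames samples.
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    return waveform[:, x0:x1]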
###################################################################### # -display_segment(17, waveform, word_segments, num_frames) - +display_segment(5, waveform, word_segments, num_frames) ###################################################################### # Conclusion @@ -894,7 +686,6 @@ def get_emission(waveform): # speech data to transcripts in five languages. # - ###################################################################### # Acknowledgement # --------------- diff --git a/examples/tutorials/forced_alignment_tutorial.py b/examples/tutorials/forced_alignment_tutorial.py index ab98908559..fef58e2e06 100644 --- a/examples/tutorials/forced_alignment_tutorial.py +++ b/examples/tutorials/forced_alignment_tutorial.py @@ -56,16 +56,11 @@ # First we import the necessary packages, and fetch data that we work on. # -# %matplotlib inline - from dataclasses import dataclass import IPython -import matplotlib import matplotlib.pyplot as plt -matplotlib.rcParams["figure.figsize"] = [16.0, 4.8] - torch.random.manual_seed(0) SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") @@ -99,17 +94,22 @@ emission = emissions[0].cpu().detach() +print(labels) + ################################################################################ # Visualization -################################################################################ -print(labels) -plt.imshow(emission.T) -plt.colorbar() -plt.title("Frame-wise class probability") -plt.xlabel("Time") -plt.ylabel("Labels") -plt.show() +# ~~~~~~~~~~~~~ + +def plot(): + plt.imshow(emission.T) + plt.colorbar() + plt.title("Frame-wise class probability") + plt.xlabel("Time") + plt.ylabel("Labels") + + +plot() ###################################################################### # Generate alignment probability (trellis) @@ -181,12 +181,17 @@ def get_trellis(emission, tokens, blank_id=0): ################################################################################ # Visualization -################################################################################ -plt.imshow(trellis.T, origin="lower") -plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5)) -plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3)) -plt.colorbar() -plt.show() +# ~~~~~~~~~~~~~ + + +def plot(): + plt.imshow(trellis.T, origin="lower") + plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5)) + plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3)) + plt.colorbar() + + +plot() ###################################################################### # In the above visualization, we can see that there is a trace of high @@ -266,7 +271,9 @@ def backtrack(trellis, emission, tokens, blank_id=0): ################################################################################ # Visualization -################################################################################ +# ~~~~~~~~~~~~~ + + def plot_trellis_with_path(trellis, path): # To plot trellis with path, we take advantage of 'nan' value trellis_with_path = trellis.clone() @@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path): plot_trellis_with_path(trellis, path) plt.title("The path found by backtracking") -plt.show() ###################################################################### -# Looking good. Now this path contains repetations for the same labels, so +# Looking good. 
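A note for readers following the trellis plots above: get_trellis itself is not modified in this hunk (only its visualization changes), and its per-frame update follows the usual formulation used in this tutorial, where each step either stays on the current token by emitting a blank or advances to the next token. A minimal sketch of one such step, assuming prev_row is the previous trellis row and tokens is a list or tensor of token indices:

import torch


def trellis_step(prev_row, frame_emission, tokens, blank_id=0):
    # Stay on the same token (emit blank) or advance to the next token
    # (emit its label); keep whichever path scores higher.
    stay = prev_row[1:] + frame_emission[blank_id]
    advance = prev_row[:-1] + frame_emission[tokens[1:]]
    return torch.maximum(stay, advance)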
+ +###################################################################### +# Segment the path +# ---------------- +# Now this path contains repetations for the same labels, so # let’s merge them to make it close to the original transcript. # # When merging the multiple path points, we simply take the average @@ -297,7 +308,7 @@ class Segment: score: float def __repr__(self): - return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" + return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" @property def length(self): @@ -330,7 +341,9 @@ def merge_repeats(path): ################################################################################ # Visualization -################################################################################ +# ~~~~~~~~~~~~~ + + def plot_trellis_with_segments(trellis, segments, transcript): # To plot trellis with path, we take advantage of 'nan' value trellis_with_path = trellis.clone() @@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript): if seg.label != "|": trellis_with_path[seg.start : seg.end, i] = float("nan") - fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) + fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True) ax1.set_title("Path, label and probability for each label") - ax1.imshow(trellis_with_path.T, origin="lower") - ax1.set_xticks([]) + ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto") for i, seg in enumerate(segments): if seg.label != "|": - ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold") - ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3)) + ax1.annotate(seg.label, (seg.start, i - 0.7), size="small") + ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small") ax2.set_title("Label probability with and without repetation") xs, hs, ws = [], [], [] @@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript): xs.append((seg.end + seg.start) / 2 + 0.4) hs.append(seg.score) ws.append(seg.end - seg.start) - ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") + ax2.annotate(seg.label, (seg.start + 0.8, -0.07)) ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") xs, hs = [], [] @@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript): ax2.bar(xs, hs, width=0.5, alpha=0.5) ax2.axhline(0, color="black") - ax2.set_xlim(ax1.get_xlim()) + ax2.grid(True, axis="y") ax2.set_ylim(-0.1, 1.1) + fig.tight_layout() plot_trellis_with_segments(trellis, segments, transcript) -plt.tight_layout() -plt.show() ###################################################################### -# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'`` +# Looks good. + +###################################################################### +# Merge the segments into words +# ----------------------------- +# Now let’s merge the words. The Wav2Vec2 model uses ``'|'`` # as the word boundary, so we merge the segments before each occurance of # ``'|'``. 
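The merging described just above amounts to: walk the character segments, cut a word at every "|" separator, and score each word with a length-weighted average of its characters' scores. A simplified sketch, assuming the Segment(label, start, end, score) field order defined in this tutorial (the tutorial's merge_words, whose body is not shown in this hunk, may differ in bookkeeping details):

def merge_words_sketch(segments, separator="|"):
    words, chunk = [], []
    for seg in segments + [Segment(separator, 0, 0, 0.0)]:  # sentinel flushes the last word
        if seg.label == separator:
            if chunk:
                label = "".join(s.label for s in chunk)
                score = sum(s.score * s.length for s in chunk) / sum(s.length for s in chunk)
                words.append(Segment(label, chunk[0].start, chunk[-1].end, score))
            chunk = []
        else:
            chunk.append(seg)
    return words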
# @@ -410,16 +426,16 @@ def merge_words(segments, separator="|"): ################################################################################ # Visualization -################################################################################ +# ~~~~~~~~~~~~~ def plot_alignments(trellis, segments, word_segments, waveform): trellis_with_path = trellis.clone() for i, seg in enumerate(segments): if seg.label != "|": trellis_with_path[seg.start : seg.end, i] = float("nan") - fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) + fig, [ax1, ax2] = plt.subplots(2, 1) - ax1.imshow(trellis_with_path.T, origin="lower") + ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto") ax1.set_xticks([]) ax1.set_yticks([]) @@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform): for i, seg in enumerate(segments): if seg.label != "|": - ax1.annotate(seg.label, (seg.start, i - 0.7)) - ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8) + ax1.annotate(seg.label, (seg.start, i - 0.7), size="small") + ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small") # The original waveform ratio = waveform.size(0) / trellis.size(0) @@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform): ax2.set_yticks([]) ax2.set_ylim(-1.0, 1.0) ax2.set_xlim(0, waveform.size(-1)) + fig.tight_layout() plot_alignments( @@ -458,7 +475,6 @@ def plot_alignments(trellis, segments, word_segments, waveform): word_segments, waveform[0], ) -plt.show() ################################################################################ diff --git a/examples/tutorials/hybrid_demucs_tutorial.py b/examples/tutorials/hybrid_demucs_tutorial.py index 8be6c9903b..081534bfe4 100644 --- a/examples/tutorials/hybrid_demucs_tutorial.py +++ b/examples/tutorials/hybrid_demucs_tutorial.py @@ -162,11 +162,10 @@ def separate_sources( def plot_spectrogram(stft, title="Spectrogram"): magnitude = stft.abs() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() - figure, axis = plt.subplots(1, 1) - img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto") - figure.suptitle(title) - plt.colorbar(img, ax=axis) - plt.show() + _, axis = plt.subplots(1, 1) + axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto") + axis.set_title(title) + plt.tight_layout() ###################################################################### @@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor "SDR score is:", separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(), ) - plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}") + plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}") return Audio(predicted_source, rate=sample_rate) @@ -294,7 +293,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor # # Mixture Clip -plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture") +plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture") Audio(mix_spec, rate=sample_rate) ###################################################################### diff --git a/examples/tutorials/mvdr_tutorial.py b/examples/tutorials/mvdr_tutorial.py index 7c9013d180..442f6234a6 100644 --- a/examples/tutorials/mvdr_tutorial.py +++ b/examples/tutorials/mvdr_tutorial.py @@ -98,23 +98,21 @@ # -def plot_spectrogram(stft, title="Spectrogram", xlim=None): +def plot_spectrogram(stft, title="Spectrogram"): 
magnitude = stft.abs() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() figure, axis = plt.subplots(1, 1) img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto") - figure.suptitle(title) + axis.set_title(title) plt.colorbar(img, ax=axis) - plt.show() -def plot_mask(mask, title="Mask", xlim=None): +def plot_mask(mask, title="Mask"): mask = mask.numpy() figure, axis = plt.subplots(1, 1) img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto") - figure.suptitle(title) + axis.set_title(title) plt.colorbar(img, ax=axis) - plt.show() def si_snr(estimate, reference, epsilon=1e-8): diff --git a/examples/tutorials/nvdec_tutorial.py b/examples/tutorials/nvdec_tutorial.py index efeca53975..459b690a3f 100644 --- a/examples/tutorials/nvdec_tutorial.py +++ b/examples/tutorials/nvdec_tutorial.py @@ -33,12 +33,9 @@ import os import time -import matplotlib import matplotlib.pyplot as plt from torchaudio.io import StreamReader -matplotlib.rcParams["image.interpolation"] = "none" - ###################################################################### # # Check the prerequisites diff --git a/examples/tutorials/speech_recognition_pipeline_tutorial.py b/examples/tutorials/speech_recognition_pipeline_tutorial.py index 79bbae14c2..2d815a2e8e 100644 --- a/examples/tutorials/speech_recognition_pipeline_tutorial.py +++ b/examples/tutorials/speech_recognition_pipeline_tutorial.py @@ -160,8 +160,7 @@ ax[i].set_title(f"Feature from transformer layer {i+1}") ax[i].set_xlabel("Feature dimension") ax[i].set_ylabel("Frame (time-axis)") -plt.tight_layout() -plt.show() +fig.tight_layout() ###################################################################### @@ -190,7 +189,7 @@ plt.title("Classification result") plt.xlabel("Frame (time-axis)") plt.ylabel("Class") -plt.show() +plt.tight_layout() print("Class labels:", bundle.get_labels()) diff --git a/examples/tutorials/squim_tutorial.py b/examples/tutorials/squim_tutorial.py index 640e2e79b8..9b9b55ac2e 100644 --- a/examples/tutorials/squim_tutorial.py +++ b/examples/tutorials/squim_tutorial.py @@ -82,19 +82,23 @@ from pystoi import stoi from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE except ImportError: - import google.colab # noqa: F401 - - print( - """ - To enable running this notebook in Google Colab, install nightly - torch and torchaudio builds by adding the following code block to the top - of the notebook before running it: - !pip3 uninstall -y torch torchvision torchaudio - !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu - !pip3 install pesq - !pip3 install pystoi - """ - ) + try: + import google.colab # noqa: F401 + + print( + """ + To enable running this notebook in Google Colab, install nightly + torch and torchaudio builds by adding the following code block to the top + of the notebook before running it: + !pip3 uninstall -y torch torchvision torchaudio + !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu + !pip3 install pesq + !pip3 install pystoi + """ + ) + except Exception: + pass + raise import matplotlib.pyplot as plt @@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8): return si_snr.item() -def plot_waveform(waveform, title): +def plot(waveform, title, sample_rate=16000): wav_numpy = waveform.numpy() sample_size = waveform.shape[1] - time_axis = torch.arange(0, sample_size) / 16000 - - figure, axes = plt.subplots(1, 1) - axes = 
figure.gca() - axes.plot(time_axis, wav_numpy[0], linewidth=1) - axes.grid(True) - figure.suptitle(title) - plt.show(block=False) + time_axis = torch.arange(0, sample_size) / sample_rate - -def plot_specgram(waveform, sample_rate, title): - wav_numpy = waveform.numpy() - figure, axes = plt.subplots(1, 1) - axes = figure.gca() - axes.specgram(wav_numpy[0], Fs=sample_rate) + figure, axes = plt.subplots(2, 1) + axes[0].plot(time_axis, wav_numpy[0], linewidth=1) + axes[0].grid(True) + axes[1].specgram(wav_numpy[0], Fs=sample_rate) figure.suptitle(title) - plt.show(block=False) ###################################################################### @@ -238,32 +232,28 @@ def plot_specgram(waveform, sample_rate, title): # Visualize speech sample # -plot_waveform(WAVEFORM_SPEECH, "Clean Speech") -plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram") +plot(WAVEFORM_SPEECH, "Clean Speech") ###################################################################### # Visualize noise sample # -plot_waveform(WAVEFORM_NOISE, "Noise") -plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram") +plot(WAVEFORM_NOISE, "Noise") ###################################################################### # Visualize distorted speech with 20dB SNR # -plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR") -plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR") +plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR") ###################################################################### # Visualize distorted speech with -5dB SNR # -plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR") -plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR") +plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR") ###################################################################### diff --git a/examples/tutorials/streamreader_advanced_tutorial.py b/examples/tutorials/streamreader_advanced_tutorial.py index 84f113cc16..7d3a4bd09e 100644 --- a/examples/tutorials/streamreader_advanced_tutorial.py +++ b/examples/tutorials/streamreader_advanced_tutorial.py @@ -355,13 +355,14 @@ def _display(i): print("filter_desc:", streamer.get_out_stream_info(i).filter_description) - _, axs = plt.subplots(2, 1) + fig, axs = plt.subplots(2, 1) waveform = chunks[i][:, 0] axs[0].plot(waveform) axs[0].grid(True) axs[0].set_ylim([-1, 1]) plt.setp(axs[0].get_xticklabels(), visible=False) axs[1].specgram(waveform, Fs=sample_rate) + fig.tight_layout() return IPython.display.Audio(chunks[i].T, rate=sample_rate) @@ -440,7 +441,6 @@ def _display(i): axs[j].imshow(chunk[10 * j + 1].permute(1, 2, 0)) axs[j].set_axis_off() plt.tight_layout() - plt.show(block=False) ###################################################################### diff --git a/examples/tutorials/streamreader_basic_tutorial.py b/examples/tutorials/streamreader_basic_tutorial.py index 29ba36aabf..ce94088c79 100644 --- a/examples/tutorials/streamreader_basic_tutorial.py +++ b/examples/tutorials/streamreader_basic_tutorial.py @@ -592,7 +592,6 @@ if i == 0 and j == 0: ax.set_ylabel("Stream 2") plt.tight_layout() -plt.show(block=False) ###################################################################### # diff --git a/examples/tutorials/tacotron2_pipeline_tutorial.py b/examples/tutorials/tacotron2_pipeline_tutorial.py index 586cdb7d09..00687166e9 100644 --- a/examples/tutorials/tacotron2_pipeline_tutorial.py +++ 
b/examples/tutorials/tacotron2_pipeline_tutorial.py @@ -7,10 +7,6 @@ """ -import IPython -import matplotlib -import matplotlib.pyplot as plt - ###################################################################### # Overview # -------- @@ -65,8 +61,6 @@ import torch import torchaudio -matplotlib.rcParams["figure.figsize"] = [16.0, 4.8] - torch.random.manual_seed(0) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -75,6 +69,13 @@ print(device) +###################################################################### +# + +import IPython +import matplotlib.pyplot as plt + + ###################################################################### # Text Processing # --------------- @@ -226,13 +227,17 @@ def text_to_sequence(text): # therefor, the process of generating the spectrogram incurs randomness. # -fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3)) -for i in range(3): - with torch.inference_mode(): - spec, spec_lengths, _ = tacotron2.infer(processed, lengths) - print(spec[0].shape) - ax[i].imshow(spec[0].cpu().detach(), origin="lower", aspect="auto") -plt.show() + +def plot(): + fig, ax = plt.subplots(3, 1) + for i in range(3): + with torch.inference_mode(): + spec, spec_lengths, _ = tacotron2.infer(processed, lengths) + print(spec[0].shape) + ax[i].imshow(spec[0].cpu().detach(), origin="lower", aspect="auto") + + +plot() ###################################################################### @@ -270,11 +275,22 @@ def text_to_sequence(text): spec, spec_lengths, _ = tacotron2.infer(processed, lengths) waveforms, lengths = vocoder(spec, spec_lengths) -fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9)) -ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto") -ax2.plot(waveforms[0].cpu().detach()) +###################################################################### +# + + +def plot(waveforms, spec, sample_rate): + waveforms = waveforms.cpu().detach() -IPython.display.Audio(waveforms[0:1].cpu(), rate=vocoder.sample_rate) + fig, [ax1, ax2] = plt.subplots(2, 1) + ax1.plot(waveforms[0]) + ax1.set_xlim(0, waveforms.size(-1)) + ax1.grid(True) + ax2.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto") + return IPython.display.Audio(waveforms[0:1], rate=sample_rate) + + +plot(waveforms, spec, vocoder.sample_rate) ###################################################################### @@ -300,11 +316,10 @@ def text_to_sequence(text): spec, spec_lengths, _ = tacotron2.infer(processed, lengths) waveforms, lengths = vocoder(spec, spec_lengths) -fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9)) -ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto") -ax2.plot(waveforms[0].cpu().detach()) +###################################################################### +# -IPython.display.Audio(waveforms[0:1].cpu(), rate=vocoder.sample_rate) +plot(waveforms, spec, vocoder.sample_rate) ###################################################################### @@ -339,8 +354,7 @@ def text_to_sequence(text): with torch.no_grad(): waveforms = waveglow.infer(spec) -fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9)) -ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto") -ax2.plot(waveforms[0].cpu().detach()) +###################################################################### +# -IPython.display.Audio(waveforms[0:1].cpu(), rate=22050) +plot(waveforms, spec, 22050)
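A pattern that recurs throughout this patch (plot_alignments in the forced-alignment tutorials, the Tacotron2 plot helper above): the plotting function draws the figure and returns an IPython.display.Audio object as its last expression, so the generated gallery page shows both the figure and an audio player from a single cell. A generic sketch of the pattern (the helper name is illustrative, not part of the patch):

import IPython
import matplotlib.pyplot as plt


def plot_and_play(waveform, sample_rate):
    # waveform: (channels, num_samples) CPU tensor
    fig, (ax_wave, ax_spec) = plt.subplots(2, 1)
    ax_wave.plot(waveform[0])
    ax_wave.grid(True)
    ax_spec.specgram(waveform[0], Fs=sample_rate)
    fig.tight_layout()
    # Returning the Audio object, rather than calling plt.show(), is what lets
    # sphinx-gallery embed an audio player next to the figure.
    return IPython.display.Audio(waveform.numpy(), rate=sample_rate)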
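The SQUIM tutorial change above also shows a small convention for optional dependencies: print the Colab installation hint when the Colab runtime can be detected, but always re-raise, presumably so a missing package still fails loudly instead of leaving later cells with undefined names. The same shape in general form (the printed message here is a placeholder):

try:
    from pesq import pesq  # optional dependency used by the tutorial
except ImportError:
    try:
        import google.colab  # noqa: F401  # present only on Colab runtimes
        print("Running on Colab: install the missing packages with pip, then restart the runtime.")
    except Exception:
        pass
    raise  # re-raise so the failure is not silently swallowed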