diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7a25e79102..a7c308cb3c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -127,6 +127,22 @@ def _get_pattern():
return ret
+def reset_mpl(gallery_conf, fname):
+ from sphinx_gallery.scrapers import _reset_matplotlib
+
+ _reset_matplotlib(gallery_conf, fname)
+ import matplotlib
+
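+    # Apply consistent figure defaults (size, font, interpolation) to all
+    # tutorial figures, on top of sphinx-gallery's standard matplotlib reset.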
+ matplotlib.rcParams.update(
+ {
+ "image.interpolation": "none",
+ "figure.figsize": (9.6, 4.8),
+ "font.size": 8.0,
+ "axes.axisbelow": True,
+ }
+ )
+
+
sphinx_gallery_conf = {
"examples_dirs": [
"../../examples/tutorials",
@@ -139,6 +155,7 @@ def _get_pattern():
"promote_jupyter_magic": True,
"first_notebook_cell": None,
"doc_module": ("torchaudio",),
+ "reset_modules": (reset_mpl, "seaborn"),
}
autosummary_generate = True
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 61e9fdc0d6..23ed1ba7a6 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -71,8 +71,8 @@ model implementations and application components.
tutorials/online_asr_tutorial
tutorials/device_asr
tutorials/device_avsr
- tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/forced_alignment_tutorial
+ tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/tacotron2_pipeline_tutorial
tutorials/mvdr_tutorial
tutorials/hybrid_demucs_tutorial
@@ -147,6 +147,13 @@ Tutorials
.. customcardstart::
+.. customcarditem::
+ :header: On device audio-visual automatic speech recognition
+   :card_description: Learn how to stream audio and video from a laptop webcam and perform audio-visual automatic speech recognition using the Emformer-RNNT model.
+ :image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
+ :link: tutorials/device_avsr.html
+ :tags: I/O,Pipelines,RNNT
+
.. customcarditem::
:header: Loading waveform Tensors from files and saving them
:card_description: Learn how to query/load audio files and save waveform tensors to files, using torchaudio.info
, torchaudio.load
and torchaudio.save
functions.
diff --git a/examples/tutorials/additive_synthesis_tutorial.py b/examples/tutorials/additive_synthesis_tutorial.py
index d6407f95bc..329611918a 100644
--- a/examples/tutorials/additive_synthesis_tutorial.py
+++ b/examples/tutorials/additive_synthesis_tutorial.py
@@ -85,7 +85,7 @@
#
-def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
+def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True)
@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
for i in range(4):
axes[i].grid(True)
pos = axes[2].get_position()
- plt.tight_layout()
+ fig.tight_layout()
if zoom is not None:
ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
#
@@ -183,7 +183,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Square wave
@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Triangle wave
@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
#
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Inharmonic Paritials
@@ -296,7 +296,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
+plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
######################################################################
#
@@ -308,7 +308,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
freq = extend_pitch(freq0, num_tones)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE)
+plot(freq, amp, waveform, SAMPLE_RATE)
######################################################################
# References
diff --git a/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py b/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py
index 154f8589f7..955dc3c029 100644
--- a/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py
+++ b/examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py
@@ -407,30 +407,45 @@ def forward(self, emission: torch.Tensor) -> List[str]:
#
-def plot_alignments(waveform, emission, tokens, timesteps):
- fig, ax = plt.subplots(figsize=(32, 10))
-
- ax.plot(waveform)
-
- ratio = waveform.shape[0] / emission.shape[1]
- word_start = 0
-
- for i in range(len(tokens)):
- if i != 0 and tokens[i - 1] == "|":
- word_start = timesteps[i]
- if tokens[i] != "|":
- plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
- elif i != 0:
- word_end = timesteps[i]
- ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
-
- xticks = ax.get_xticks()
- plt.xticks(xticks, xticks / bundle.sample_rate)
- ax.set_xlabel("time (sec)")
- ax.set_xlim(0, waveform.shape[0])
-
-
-plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
+def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
+
+ t = torch.arange(waveform.size(0)) / sample_rate
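+    # seconds of audio covered by one emission frame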
+ ratio = waveform.size(0) / emission.size(1) / sample_rate
+
+ chars = []
+ words = []
+ word_start = None
+ for token, timestep in zip(tokens, timesteps * ratio):
+ if token == "|":
+ if word_start is not None:
+ words.append((word_start, timestep))
+ word_start = None
+ else:
+ chars.append((token, timestep))
+ if word_start is None:
+ word_start = timestep
+
+ fig, axes = plt.subplots(3, 1)
+
+ def _plot(ax, xlim):
+ ax.plot(t, waveform)
+ for token, timestep in chars:
+ ax.annotate(token.upper(), (timestep, 0.5))
+ for word_start, word_end in words:
+ ax.axvspan(word_start, word_end, alpha=0.1, color="red")
+ ax.set_ylim(-0.6, 0.7)
+ ax.set_yticks([0])
+ ax.grid(True, axis="y")
+ ax.set_xlim(xlim)
+
+ _plot(axes[0], (0.3, 2.5))
+ _plot(axes[1], (2.5, 4.7))
+ _plot(axes[2], (4.7, 6.9))
+ axes[2].set_xlabel("time (sec)")
+ fig.tight_layout()
+
+
+plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)
######################################################################
diff --git a/examples/tutorials/audio_data_augmentation_tutorial.py b/examples/tutorials/audio_data_augmentation_tutorial.py
index cbe53b5326..3d9d3922ee 100644
--- a/examples/tutorials/audio_data_augmentation_tutorial.py
+++ b/examples/tutorials/audio_data_augmentation_tutorial.py
@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
- plt.show(block=False)
######################################################################
@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
- plt.show(block=False)
######################################################################
diff --git a/examples/tutorials/audio_datasets_tutorial.py b/examples/tutorials/audio_datasets_tutorial.py
index d3c16ffddb..2d540b78fe 100644
--- a/examples/tutorials/audio_datasets_tutorial.py
+++ b/examples/tutorials/audio_datasets_tutorial.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""
Audio Datasets
==============
@@ -10,10 +9,6 @@
available datasets.
"""
-# When running this tutorial in Google Colab, install the required packages
-# with the following.
-# !pip install torchaudio
-
import torch
import torchaudio
@@ -21,22 +16,13 @@
print(torchaudio.__version__)
######################################################################
-# Preparing data and utility functions (skip this section)
-# --------------------------------------------------------
#
-# @title Prepare data and utility functions. {display-mode: "form"}
-# @markdown
-# @markdown You do not need to look into this cell.
-# @markdown Just execute once and you are good to go.
-
-# -------------------------------------------------------------------------------
-# Preparation of data and helper functions.
-# -------------------------------------------------------------------------------
import os
+import IPython
+
import matplotlib.pyplot as plt
-from IPython.display import Audio, display
_SAMPLE_DIR = "_assets"
@@ -44,34 +30,13 @@
os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
-def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
+def plot_specgram(waveform, sample_rate, title="Spectrogram"):
waveform = waveform.numpy()
- num_channels, _ = waveform.shape
-
- figure, axes = plt.subplots(num_channels, 1)
- if num_channels == 1:
- axes = [axes]
- for c in range(num_channels):
- axes[c].specgram(waveform[c], Fs=sample_rate)
- if num_channels > 1:
- axes[c].set_ylabel(f"Channel {c+1}")
- if xlim:
- axes[c].set_xlim(xlim)
+ figure, ax = plt.subplots()
+ ax.specgram(waveform[0], Fs=sample_rate)
figure.suptitle(title)
- plt.show(block=False)
-
-
-def play_audio(waveform, sample_rate):
- waveform = waveform.numpy()
-
- num_channels, _ = waveform.shape
- if num_channels == 1:
- display(Audio(waveform[0], rate=sample_rate))
- elif num_channels == 2:
- display(Audio((waveform[0], waveform[1]), rate=sample_rate))
- else:
- raise ValueError("Waveform with more than 2 channels are not supported.")
+ figure.tight_layout()
######################################################################
@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
# :py:class:`torchaudio.datasets.YESNO` dataset.
#
-
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
-for i in [1, 3, 5]:
- waveform, sample_rate, label = dataset[i]
- plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
- play_audio(waveform, sample_rate)
+######################################################################
+#
+i = 1
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
+
+######################################################################
+#
+i = 3
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
+
+######################################################################
+#
+i = 5
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
diff --git a/examples/tutorials/audio_feature_augmentation_tutorial.py b/examples/tutorials/audio_feature_augmentation_tutorial.py
index 197d03d04b..6e69ef5056 100644
--- a/examples/tutorials/audio_feature_augmentation_tutorial.py
+++ b/examples/tutorials/audio_feature_augmentation_tutorial.py
@@ -19,25 +19,19 @@
print(torchaudio.__version__)
######################################################################
-# Preparing data and utility functions (skip this section)
-# --------------------------------------------------------
+# Preparation
+# -----------
#
-# @title Prepare data and utility functions. {display-mode: "form"}
-# @markdown
-# @markdown You do not need to look into this cell.
-# @markdown Just execute once and you are good to go.
-# @markdown
-# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
-# @markdown which is licensed under Creative Commos BY 4.0.
-
-# -------------------------------------------------------------------------------
-# Preparation of data and helper functions.
-# -------------------------------------------------------------------------------
import librosa
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
+######################################################################
+# In this tutorial, we will use speech data from the
+# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
+# which is licensed under Creative Commons BY 4.0.
+
SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
@@ -75,16 +69,9 @@ def get_spectrogram(
return spectrogram(waveform)
-def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
- fig, axs = plt.subplots(1, 1)
- axs.set_title(title or "Spectrogram (db)")
- axs.set_ylabel(ylabel)
- axs.set_xlabel("frame")
- im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
- if xmax:
- axs.set_xlim((0, xmax))
- fig.colorbar(im, ax=axs)
- plt.show(block=False)
+def plot_spec(ax, spec, title, ylabel="freq_bin"):
+ ax.set_title(title)
+ ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
######################################################################
@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
spec = get_spectrogram(power=None)
stretch = T.TimeStretch()
-rate = 1.2
-spec_ = stretch(spec, rate)
-plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
-
-plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
-
-rate = 0.9
-spec_ = stretch(spec, rate)
-plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
+spec_12 = stretch(spec, overriding_rate=1.2)
+spec_09 = stretch(spec, overriding_rate=0.9)
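+# Rates above 1.0 speed the signal up (fewer spectrogram frames);
+# rates below 1.0 slow it down (more frames).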
######################################################################
-# TimeMasking
-# -----------
#
-torch.random.manual_seed(4)
-spec = get_spectrogram()
-plot_spectrogram(spec[0], title="Original")
+def plot():
+ fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
+ plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
+ plot_spec(axes[1], torch.abs(spec[0]), title="Original")
+ plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
+ fig.tight_layout()
-masking = T.TimeMasking(time_mask_param=80)
-spec = masking(spec)
-plot_spectrogram(spec[0], title="Masked along time axis")
+plot()
######################################################################
-# FrequencyMasking
-# ----------------
+# Time and Frequency Masking
+# --------------------------
#
-
torch.random.manual_seed(4)
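+# time_mask_param / freq_mask_param set the maximum possible length of the
+# mask; the actual length is sampled uniformly below that value.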
+time_masking = T.TimeMasking(time_mask_param=80)
+freq_masking = T.FrequencyMasking(freq_mask_param=80)
+
spec = get_spectrogram()
-plot_spectrogram(spec[0], title="Original")
+time_masked = time_masking(spec)
+freq_masked = freq_masking(spec)
+
+######################################################################
+#
+
+
+def plot():
+ fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
+ plot_spec(axes[0], spec[0], title="Original")
+ plot_spec(axes[1], time_masked[0], title="Masked along time axis")
+ plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
+ fig.tight_layout()
-masking = T.FrequencyMasking(freq_mask_param=80)
-spec = masking(spec)
-plot_spectrogram(spec[0], title="Masked along frequency axis")
+plot()
diff --git a/examples/tutorials/audio_feature_extractions_tutorial.py b/examples/tutorials/audio_feature_extractions_tutorial.py
index 63b71bc14a..eb43c6dca8 100644
--- a/examples/tutorials/audio_feature_extractions_tutorial.py
+++ b/examples/tutorials/audio_feature_extractions_tutorial.py
@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True)
ax.set_xlim([0, time_axis[-1]])
ax.set_title(title)
- plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title)
ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
- plt.show(block=False)
def plot_fbank(fbank, title=None):
@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin")
- plt.show(block=False)
######################################################################
@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
- plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
diff --git a/examples/tutorials/audio_io_tutorial.py b/examples/tutorials/audio_io_tutorial.py
index 6fd0f1f2e9..15ef25cc6e 100644
--- a/examples/tutorials/audio_io_tutorial.py
+++ b/examples/tutorials/audio_io_tutorial.py
@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform")
- plt.show(block=False)
######################################################################
@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title)
- plt.show(block=False)
######################################################################
diff --git a/examples/tutorials/audio_resampling_tutorial.py b/examples/tutorials/audio_resampling_tutorial.py
index 33b1ffec53..1398e1d69a 100644
--- a/examples/tutorials/audio_resampling_tutorial.py
+++ b/examples/tutorials/audio_resampling_tutorial.py
@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax)
- plt.show(block=True)
######################################################################
diff --git a/examples/tutorials/ctc_forced_alignment_api_tutorial.py b/examples/tutorials/ctc_forced_alignment_api_tutorial.py
index a0d3d7acb7..7d6f02a7f4 100644
--- a/examples/tutorials/ctc_forced_alignment_api_tutorial.py
+++ b/examples/tutorials/ctc_forced_alignment_api_tutorial.py
@@ -5,254 +5,277 @@
**Author**: `Xiaohui Zhang `__
-This tutorial shows how to align transcripts to speech with
-``torchaudio``'s CTC forced alignment API proposed in the paper
-`“Scaling Speech Technology to 1,000+
-Languages” `__,
-and one advanced usage, i.e. dealing with transcription errors with a token.
-
-Though there’s some overlap in visualization
-diagrams, the scope here is different from the `“Forced Alignment with
-Wav2Vec2” `__
-tutorial, which focuses on a step-by-step demonstration of the forced
-alignment generation algorithm (without using an API) described in the
-`paper `__ with a Wav2Vec2 model.
-
+This tutorial shows how to align transcripts to speech using
+:py:func:`torchaudio.functional.forced_align`
+which was developed along with the work of
+`Scaling Speech Technology to 1,000+ Languages `__.
+
+Forced alignment is the process of aligning a transcript with speech.
+We cover the basics of forced alignment in `Forced Alignment with
+Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified
+step-by-step Python implementations.
+
+:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
+implementations which are more performant and more accurate than the
+vanilla Python implementation above.
+It can also handle missing transcripts with the special ``<star>`` token.
+
+For examples of aligning multiple languages, please refer to
+`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__.
"""
import torch
import torchaudio
+
print(torch.__version__)
print(torchaudio.__version__)
+######################################################################
+#
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(device)
-
+from dataclasses import dataclass
+from typing import List
-try:
- from torchaudio.functional import forced_align
-except ModuleNotFoundError:
- print(
- "Failed to import the forced alignment API. "
- "Please install torchaudio nightly builds. "
- "Please refer to https://pytorch.org/get-started/locally "
- "for instructions to install a nightly build."
- )
- raise
+import IPython
+import matplotlib.pyplot as plt
######################################################################
-# Basic usages
-# ------------
-#
-# In this section, we cover the following content:
-#
-# 1. Generate frame-wise class probabilites from audio waveform from a CTC
-# acoustic model.
-# 2. Compute frame-level alignments using TorchAudio’s forced alignment
-# API.
-# 3. Obtain token-level alignments from frame-level alignments.
-# 4. Obtain word-level alignments from token-level alignments.
#
+from torchaudio.functional import forced_align
+
+torch.random.manual_seed(0)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(device)
######################################################################
# Preparation
-# ~~~~~~~~~~~
+# -----------
#
-# First we import the necessary packages, and fetch data that we work on.
+# First we prepare the speech data and the transcript we are going
+# to use.
#
-# %matplotlib inline
-from dataclasses import dataclass
-
-import IPython
-import matplotlib
-import matplotlib.pyplot as plt
-
-matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
-
-torch.random.manual_seed(0)
-
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
-sample_rate = 16000
-
+TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
######################################################################
-# Generate frame-wise class posteriors from a CTC acoustic model
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Generating emissions and tokens
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# :py:func:`~torchaudio.functional.forced_align` takes emission and
+# token sequences and outputs timestamps of the tokens and their scores.
#
-# The first step is to generate the class probabilities (i.e. posteriors)
-# of each audio frame using a CTC model.
-# Here we use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
+# Emission represents the frame-wise probability distribution over
+# tokens, and it can be obtained by passing the waveform to an acoustic
+# model.
+# Tokens are the numerical expression of the transcript. They can be
+# obtained by simply mapping each character to its index in the token list.
+# The emission and the token sequences must use the same set of tokens.
+#
+# We can use a pre-trained Wav2Vec2 model to obtain the emission from speech
+# and map the transcript to tokens.
+# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`,
+# which bundles pre-trained model weights with associated labels.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
-labels = bundle.get_labels()
with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE)
- emissions, _ = model(waveform.to(device))
- emissions = torch.log_softmax(emissions, dim=-1)
+ emission, _ = model(waveform.to(device))
+ emission = torch.log_softmax(emission, dim=-1)
+
+
+######################################################################
+#
+
-emission = emissions.cpu().detach()
-dictionary = {c: i for i, c in enumerate(labels)}
+def plot_emission(emission):
+ plt.imshow(emission.cpu().T)
+ plt.title("Frame-wise class probabilities")
+ plt.xlabel("Time")
+ plt.ylabel("Labels")
+ plt.tight_layout()
-print(dictionary)
+plot_emission(emission[0])
######################################################################
-# Visualization
-# ^^^^^^^^^^^^^
-#
+# We create a dictionary, which maps each label to a token.
+
+labels = bundle.get_labels()
+DICTIONARY = {c: i for i, c in enumerate(labels)}
+
+for k, v in DICTIONARY.items():
+ print(f"{k}: {v}")
+
+######################################################################
+# Converting the transcript to tokens is as simple as
-plt.imshow(emission[0].T)
-plt.colorbar()
-plt.title("Frame-wise class probabilities")
-plt.xlabel("Time")
-plt.ylabel("Labels")
-plt.show()
+tokens = [DICTIONARY[c] for c in TRANSCRIPT]
+print(" ".join(str(t) for t in tokens))
######################################################################
# Computing frame-level alignments
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# --------------------------------
#
-# Then we call TorchAudio’s forced alignment API to compute the
-# frame-level alignment between each audio frame and each token in the
-# transcript. We first explain the inputs and outputs of the API
-# ``functional.forced_align``. Note that this API works on both CPU and
-# GPU. In the current tutorial we demonstrate it on CPU.
+# Now we call TorchAudio’s forced alignment API to compute the
+# frame-level alignment. For the detail of function signature, please
+# refer to :py:func:`~torchaudio.functional.forced_align`.
#
-# **Inputs**:
#
-# ``emission``: a 2D tensor of size :math:`T \times N`, where :math:`T` is
-# the number of frames (after sub-sampling by the acoustic model, if any),
-# and :math:`N` is the vocabulary size.
-#
-# ``targets``: a 1D tensor vector of size :math:`M`, where :math:`M` is
-# the length of the transcript, and each element is a token ID looked up
-# from the vocabulary. For example, the ``targets`` tensor repsenting the
-# transcript “i had…” is :math:`[5, 18, 4, 16, ...]`.
-#
-# ``input lengths``: :math:`T`.
-#
-# ``target lengths``: :math:`M`.
-#
-# **Outputs**:
-#
-# ``frame_alignment``: a 1D tensor of size :math:`T` storing the aligned
-# token index (looked up from the vocabulary) of each frame, e.g. for the
-# segment corresponding to “i had” in the given example , the
-# frame_alignment is
-# :math:`[...0, 0, 5, 0, 0, 18, 18, 4, 0, 0, 0, 16,...]`, where :math:`0`
-# represents the blank symbol.
+
+
+def align(emission, tokens):
+ alignments, scores = forced_align(
+ emission,
+ targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device),
+ input_lengths=torch.tensor([emission.size(1)], device=emission.device),
+ target_lengths=torch.tensor([len(tokens)], device=emission.device),
+ blank=0,
+ )
+
+ scores = scores.exp() # convert back to probability
+ alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
+ return alignments.tolist(), scores.tolist()
+
+
+frame_alignment, frame_scores = align(emission, tokens)
+
+######################################################################
+# Now let's look at the output.
+# Notice that the alignment is expressed in the frame coordinate of the
+# emission, which is different from that of the original waveform.
+
+for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
+ print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}")
+
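+######################################################################
+# As a quick sanity check (a small illustrative sketch, not part of the
+# alignment pipeline), we can convert a frame index into seconds using
+# the ratio between the waveform length and the number of emission frames:
+
+frame_duration = waveform.size(1) / bundle.sample_rate / emission.size(1)
+print(f"One emission frame covers roughly {frame_duration:.3f} seconds")
+print(f"Frame 30 therefore starts at roughly {30 * frame_duration:.2f} seconds")
+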
+######################################################################
#
-# ``frame_scores``: a 1D tensor of size :math:`T` storing the confidence
-# score (0 to 1) for each each frame. For each frame, the score should be
-# close to one if the alignment quality is good.
+# The ``Frame`` instance represents the most likely token at each frame
+# with its confidence.
+#
+# When interpreting it, one must remember that the meanings of the blank
+# token and repeated tokens are context dependent.
+#
+# .. note::
+#
+#    When the same token occurs after a blank token, it is not treated as
+#    a repeat, but as a new occurrence.
+#
+# .. code-block::
+#
+# a a a b -> a b
+# a - - b -> a b
+# a a - b -> a b
+# a - a b -> a a b
+# ^^^ ^^^
+#
+# .. code-block::
+#
+# 29: 0 [-], 1.00
+# 30: 7 [I], 1.00 # Start of "I"
+# 31: 0 [-], 0.98 # repeat (blank token)
+# 32: 0 [-], 1.00 # repeat (blank token)
+# 33: 1 [|], 0.85 # Start of "|" (word boundary)
+# 34: 1 [|], 1.00 # repeat (same token)
+# 35: 0 [-], 0.61 # repeat (blank token)
+# 36: 8 [H], 1.00 # Start of "H"
+# 37: 0 [-], 1.00 # repeat (blank token)
+# 38: 4 [A], 1.00 # Start of "A"
+# 39: 0 [-], 0.99 # repeat (blank token)
+# 40: 11 [D], 0.92 # Start of "D"
+# 41: 0 [-], 0.93 # repeat (blank token)
+# 42: 1 [|], 0.98 # Start of "|"
+# 43: 1 [|], 1.00 # repeat (same token)
+# 44: 3 [T], 1.00 # Start of "T"
+# 45: 3 [T], 0.90 # repeat (same token)
+# 46: 8 [H], 1.00 # Start of "H"
+# 47: 0 [-], 1.00 # repeat (blank token)
+
+######################################################################
+# Resolve blank and repeated tokens
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
-# From the outputs ``frame_alignment`` and ``frame_scores``, we generate a
+# The next step is to resolve the repetition, so that what each alignment
+# represents no longer depends on previous alignments.
+# From the outputs ``alignment`` and ``scores``, we generate a
# list called ``frames`` storing information of all frames aligned to
-# non-blank tokens. Each element contains 1) token_index: the aligned
-# token’s index in the transcript 2) time_index: the current frame’s index
-# in the input audio (or more precisely, the row dimension of the emission
-# matrix) 3) the confidence scores of the current frame.
-#
-# For the given example, the first few elements of the list ``frames``
-# corresponding to “i had” looks as the following:
-#
-# ``Frame(token_index=0, time_index=32, score=0.9994410872459412)``
+# non-blank tokens.
#
-# ``Frame(token_index=1, time_index=35, score=0.9980823993682861)``
+# Each element contains the following
#
-# ``Frame(token_index=1, time_index=36, score=0.9295750260353088)``
-#
-# ``Frame(token_index=2, time_index=37, score=0.9997448325157166)``
-#
-# ``Frame(token_index=3, time_index=41, score=0.9991760849952698)``
-#
-# ``...``
-#
-# The interpretation is:
-#
-# The token with index :math:`0` in the transcript, i.e. “i”, is aligned
-# to the :math:`32`\ th audio frame, with confidence :math:`0.9994`. The
-# token with index :math:`1` in the transcript, i.e. “h”, is aligned to
-# the :math:`35`\ th and :math:`36`\ th audio frames, with confidence
-# :math:`0.9981` and :math:`0.9296` respectively. The token with index
-# :math:`2` in the transcript, i.e. “a”, is aligned to the :math:`35`\ th
-# and :math:`36`\ th audio frames, with confidence :math:`0.9997`. The
-# token with index :math:`3` in the transcript, i.e. “d”, is aligned to
-# the :math:`41`\ th audio frame, with confidence :math:`0.9992`.
-#
-# From such information stored in the ``frames`` list, we’ll compute
-# token-level and word-level alignments easily.
+# - ``token_index``: the aligned token’s index **in the transcript**
+# - ``time_index``: the current frame’s index in emission
+# - ``score``: the score of the current frame.
#
+# ``token_index`` is the index of each token in the transcript,
+# i.e. the current frame aligns to the N-th character from the transcript.
@dataclass
class Frame:
- # This is the index of each token in the transcript,
- # i.e. the current frame aligns to the N-th character from the transcript.
token_index: int
time_index: int
score: float
-def compute_alignments(transcript, dictionary, emission):
- frames = []
- tokens = [dictionary[c] for c in transcript.replace(" ", "")]
-
- targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
- input_lengths = torch.tensor([emission.shape[1]])
- target_lengths = torch.tensor([targets.shape[1]])
+######################################################################
+#
- # This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
- frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
- assert frame_alignment.shape[1] == input_lengths[0].item()
- assert targets.shape[1] == target_lengths[0].item()
+def obtain_token_level_alignments(alignments, scores) -> List[Frame]:
+ assert len(alignments) == len(scores)
token_index = -1
prev_hyp = 0
- for i in range(frame_alignment.shape[1]):
- if frame_alignment[0][i].item() == 0:
+ frames = []
+ for i, (ali, score) in enumerate(zip(alignments, scores)):
+ if ali == 0:
prev_hyp = 0
continue
- if frame_alignment[0][i].item() != prev_hyp:
+ if ali != prev_hyp:
token_index += 1
- frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
- prev_hyp = frame_alignment[0][i].item()
- return frames, frame_alignment, frame_scores
-
-
-transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
-frames, frame_alignment, frame_scores = compute_alignments(transcript, dictionary, emission)
+ frames.append(Frame(token_index, i, score))
+ prev_hyp = ali
+ return frames
######################################################################
-# Obtain token-level alignments and confidence scores
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
+frames = obtain_token_level_alignments(frame_alignment, frame_scores)
+
+print("Time\tLabel\tScore")
+for f in frames:
+ print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}")
+
######################################################################
+# Obtain token-level alignments and confidence scores
+# ---------------------------------------------------
+#
# The frame-level alignments contains repetations for the same labels.
# Another format “token-level alignment”, which specifies the aligned
# frame ranges for each transcript token, contains the same information,
# while being more convenient to apply to some downstream tasks
-# (e.g. computing word-level alignments).
+# (e.g. computing word-level alignments).
#
# Now we demonstrate how to obtain token-level alignments and confidence
# scores by simply merging frame-level alignments and averaging
# frame-level confidence scores.
#
+######################################################################
+# The following class represents the label, its score and the time span
+# of its occurrence.
+#
+
-# Merge the labels
@dataclass
class Segment:
label: str
@@ -261,13 +284,16 @@ class Segment:
score: float
def __repr__(self):
- return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
+ return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
- @property
- def length(self):
+ def __len__(self):
return self.end - self.start
+######################################################################
+#
+
+
def merge_repeats(frames, transcript):
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
@@ -288,29 +314,31 @@ def merge_repeats(frames, transcript):
return segments
-segments = merge_repeats(frames, transcript)
+######################################################################
+#
+segments = merge_repeats(frames, TRANSCRIPT)
for seg in segments:
print(seg)
######################################################################
# Visualization
-# ^^^^^^^^^^^^^
+# ~~~~~~~~~~~~~
#
def plot_label_prob(segments, transcript):
- fig, ax2 = plt.subplots(figsize=(16, 4))
+ fig, ax = plt.subplots()
- ax2.set_title("frame-level and token-level confidence scores")
+ ax.set_title("frame-level and token-level confidence scores")
xs, hs, ws = [], [], []
for seg in segments:
if seg.label != "|":
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
- ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
- ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
+ ax.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
+ ax.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
for p in frames:
@@ -319,27 +347,28 @@ def plot_label_prob(segments, transcript):
xs.append(p.time_index + 1)
hs.append(p.score)
- ax2.bar(xs, hs, width=0.5, alpha=0.5)
- ax2.axhline(0, color="black")
- ax2.set_ylim(-0.1, 1.1)
+ ax.bar(xs, hs, width=0.5, alpha=0.5)
+ ax.set_ylim(-0.1, 1.1)
+ ax.grid(True, axis="y")
+ fig.tight_layout()
-plot_label_prob(segments, transcript)
-plt.tight_layout()
-plt.show()
+plot_label_prob(segments, TRANSCRIPT)
######################################################################
# From the visualized scores, we can see that, for tokens spanning over
-# more multiple frames, e.g. “T” in “THAT, the token-level confidence
+# multiple frames, e.g. “T” in “THAT”, the token-level confidence
# score is the average of frame-level confidence scores. To make this
# clearer, we don’t plot confidence scores for blank frames, which was
# plotted in the”Label probability with and without repeatation” figure in
-# the previous tutorial `“Forced Alignment with
-# Wav2Vec2” `__.
+# the previous tutorial
+# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__.
#
+
+######################################################################
# Obtain word-level alignments and confidence scores
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# --------------------------------------------------
#
@@ -367,7 +396,7 @@ def merge_words(transcript, segments, separator=" "):
s = 0
segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs])
- score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
+ score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
i1 = i2
else:
@@ -376,59 +405,43 @@ def merge_words(transcript, segments, separator=" "):
return words
-word_segments = merge_words(transcript, segments, "|")
+word_segments = merge_words(TRANSCRIPT, segments, "|")
######################################################################
# Visualization
-# ^^^^^^^^^^^^^
+# ~~~~~~~~~~~~~
#
-def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
- fig, ax2 = plt.subplots(figsize=(64, 12))
- plt.rcParams.update({"font.size": 30})
+def plot_alignments(waveform, emission, segments, word_segments, sample_rate=bundle.sample_rate):
+ fig, ax = plt.subplots()
- # The original waveform
- ratio = waveform.size(1) / input_lengths
- ax2.plot(waveform)
- ax2.set_ylim(-1.0 * scale, 1.0 * scale)
- ax2.set_xlim(0, waveform.size(-1))
+ ax.specgram(waveform[0], Fs=sample_rate)
+    # conversion factor from emission frames to seconds
+ ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
- x0 = ratio * word.start
- x1 = ratio * word.end
- ax2.axvspan(x0, x1, alpha=0.1, color="red")
- ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
+ t0, t1 = ratio * word.start, ratio * word.end
+ ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
+ ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
- ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale))
+ ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
- xticks = ax2.get_xticks()
- plt.xticks(xticks, xticks / sample_rate, fontsize=50)
- ax2.set_xlabel("time [second]", fontsize=40)
- ax2.set_yticks([])
+ ax.set_xlabel("time [second]")
+ fig.tight_layout()
-plot_alignments(
- segments,
- word_segments,
- waveform,
- emission.shape[1],
- 1,
-)
-plt.show()
+plot_alignments(waveform, emission, segments, word_segments)
######################################################################
-# A trick to embed the resulting audio to the generated file.
-# `IPython.display.Audio` has to be the last call in a cell,
-# and there should be only one call par cell.
-def display_segment(i, waveform, word_segments, frame_alignment):
- ratio = waveform.size(1) / frame_alignment.size(1)
+def display_segment(i, waveform, word_segments, frame_alignment, sample_rate=bundle.sample_rate):
+ ratio = waveform.size(1) / len(frame_alignment)
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
@@ -437,8 +450,10 @@ def display_segment(i, waveform, word_segments, frame_alignment):
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
+######################################################################
+
# Generate the audio for each segment
-print(transcript)
+print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
######################################################################
@@ -488,62 +503,71 @@ def display_segment(i, waveform, word_segments, frame_alignment):
######################################################################
-# Advanced usage: Dealing with missing transcripts using the token
-# ---------------------------------------------------------------------------
+# Advanced: Handling transcripts with ``<star>`` token
+# -----------------------------------------------------
#
# Now let’s look at when the transcript is partially missing, how can we
-# improve alignment quality using the token, which is capable of modeling
+# improve alignment quality using the ``<star>`` token, which is capable of modeling
# any token.
#
# Here we use the same English example as used above. But we remove the
-# beginning text “i had that curiosity beside me at” from the transcript.
+# beginning text ``“i had that curiosity beside me at”`` from the transcript.
# Aligning audio with such transcript results in wrong alignments of the
# existing word “this”. However, this issue can be mitigated by using the
-# token to model the missing text.
+# ``<star>`` token to model the missing text.
#
-# Reload the emission tensor in order to add the extra dimension corresponding to the token.
-with torch.inference_mode():
- waveform, _ = torchaudio.load(SPEECH_FILE)
- emissions, _ = model(waveform.to(device))
- emissions = torch.log_softmax(emissions, dim=-1)
+######################################################################
+# First, we extend the dictionary to include the ``<star>`` token.
- # Append the extra dimension corresponding to the token
- extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
- emissions = torch.cat((emissions.cpu(), extra_dim), 2)
- emission = emissions.detach()
+DICTIONARY["*"] = len(DICTIONARY)
-# Extend the dictionary to include the token.
-dictionary["*"] = 29
+######################################################################
+# Next, we extend the emission tensor with the extra dimension
+# corresponding to the ``<star>`` token.
+#
-assert len(dictionary) == emission.shape[2]
+extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device)
+emission = torch.cat((emission, extra_dim), 2)
+
+assert len(DICTIONARY) == emission.shape[2]
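+
+######################################################################
+# ``emission`` stores log-probabilities, so a column of zeros gives the
+# ``<star>`` token a probability of one at every frame, which is what
+# lets it stand in for the missing part of the transcript.
+# A quick check (illustrative only):
+
+print(emission[0, :, DICTIONARY["*"]].exp().unique())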
+
+
+######################################################################
+# The following function combines all the processes, and computes
+# word segments from the emission in one go.
def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
- frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
+ tokens = [dictionary[c] for c in transcript]
+ alignment, scores = align(emission, tokens)
+ frames = obtain_token_level_alignments(alignment, scores)
segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments, "|")
- plot_alignments(segments, word_segments, waveform, emission.shape[1], 1)
- plt.show()
- return word_segments, frame_alignment
-
+ plot_alignments(waveform, emission, segments, word_segments)
+ plt.xlim([0, None])
-# original:
-word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
######################################################################
+# **Original**
-# Demonstrate the effect of token for dealing with deletion errors
-# ("i had that curiosity beside me at" missing from the transcript):
-transcript = "THIS|MOMENT"
-word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
+compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform)
######################################################################
+# **With <star> token**
+#
+# Now we replace the first part of the transcript with the ``<star>`` token.
-# Replacing the missing transcript with the token:
-transcript = "*|THIS|MOMENT"
-word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
+compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
+
+######################################################################
+# **Without <star> token**
+#
+# As a comparison, the following aligns the partial transcript
+# without using the ``<star>`` token.
+# It demonstrates the effect of the ``<star>`` token for dealing with deletion errors.
+compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform)
######################################################################
# Conclusion
@@ -551,7 +575,7 @@ def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
#
# In this tutorial, we looked at how to use torchaudio’s forced alignment
# API to align and segment speech files, and demonstrated one advanced usage:
-# How introducing a token could improve alignment accuracy when
+# How introducing a ``<star>`` token could improve alignment accuracy when
# transcription errors exist.
#
diff --git a/examples/tutorials/device_avsr.py b/examples/tutorials/device_avsr.py
index f7a12ec943..0bb7a5792d 100644
--- a/examples/tutorials/device_avsr.py
+++ b/examples/tutorials/device_avsr.py
@@ -69,7 +69,7 @@
# -------------------
#
# Firstly, we define the function to collect videos from microphone and
-# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader`
+# camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
# class for the purpose of data collection, which supports capturing
# audio/video from microphone and camera. For the detailed usage of this
# class, please refer to the
diff --git a/examples/tutorials/filter_design_tutorial.py b/examples/tutorials/filter_design_tutorial.py
index 944a7df3f8..1637eb0cc2 100644
--- a/examples/tutorials/filter_design_tutorial.py
+++ b/examples/tutorials/filter_design_tutorial.py
@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape
half = window_size // 2
- fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5))
+ fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)"
)
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
- plt.tight_layout()
+ fig.tight_layout()
######################################################################
@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts
- fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
+ fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)"
)
- plt.tight_layout()
+ fig.tight_layout()
######################################################################
@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True)
axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
- axes[2].legend(loc="lower right")
+ axes[2].legend(loc="center right")
fig.tight_layout()
diff --git a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
index 01333d7175..6f78b0e5d3 100644
--- a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
+++ b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
@@ -6,15 +6,14 @@
This tutorial shows how to compute forced alignments for speech data
from multiple non-English languages using ``torchaudio``'s CTC forced alignment
-API described in `“CTC forced alignment
-tutorial” `__
-and the multilingual Wav2vec2 model proposed in the paper `“Scaling
+API described in `CTC forced alignment tutorial <./forced_alignment_tutorial.html>`__
+and the multilingual Wav2vec2 model proposed in the paper `Scaling
Speech Technology to 1,000+
-Languages” `__.
+Languages `__.
+
The model was trained on 23K of audio data from 1100+ languages using
-the `“uroman vocabulary” `__
+the `uroman vocabulary `__
as targets.
-
"""
import torch
@@ -23,53 +22,46 @@
print(torch.__version__)
print(torchaudio.__version__)
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
-
-try:
- from torchaudio.functional import forced_align
-except ModuleNotFoundError:
- print(
- "Failed to import the forced alignment API. "
- "Please install torchaudio nightly builds. "
- "Please refer to https://pytorch.org/get-started/locally "
- "for instructions to install a nightly build."
- )
- raise
-
######################################################################
# Preparation
# -----------
#
-# Here we import necessary packages, and define utility functions for
-# computing the frame-level alignments (using the API
-# ``functional.forced_align``), token-level and word-level alignments, and
-# also alignment visualization utilities.
-#
-# %matplotlib inline
from dataclasses import dataclass
import IPython
-
import matplotlib.pyplot as plt
+from torchaudio.functional import forced_align
-torch.random.manual_seed(0)
-sample_rate = 16000
+######################################################################
+#
+
+SAMPLE_RATE = 16000
+
+
+######################################################################
+#
+# Here we define utility functions for computing the frame-level
+# alignments (using the API :py:func:`torchaudio.functional.forced_align`),
+# token-level and word-level alignments.
+# For the details of these functions, please refer to
+# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__.
+#
@dataclass
class Frame:
- # This is the index of each token in the transcript,
- # i.e. the current frame aligns to the N-th character from the transcript.
token_index: int
time_index: int
score: float
+######################################################################
+#
@dataclass
class Segment:
label: str
@@ -78,39 +70,42 @@ class Segment:
score: float
def __repr__(self):
- return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
+ return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
- @property
- def length(self):
+ def __len__(self):
return self.end - self.start
-# compute frame-level and word-level alignments using torchaudio's forced alignment API
+######################################################################
+#
+
+
def compute_alignments(transcript, dictionary, emission):
- frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
- targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
- input_lengths = torch.tensor([emission.shape[1]])
- target_lengths = torch.tensor([targets.shape[1]])
+ targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
+ input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
+ target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
+
+ alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
- # This is the key step, where we call the forced alignment API functional.forced_align to compute frame alignments.
- frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
+ scores = scores.exp() # convert back to probability
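+    # remove the batch dimension for simplicity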
+ alignment, scores = alignment[0].tolist(), scores[0].tolist()
- assert frame_alignment.shape[1] == input_lengths[0].item()
- assert targets.shape[1] == target_lengths[0].item()
+ assert len(alignment) == len(scores) == emission.size(1)
token_index = -1
prev_hyp = 0
- for i in range(frame_alignment.shape[1]):
- if frame_alignment[0][i].item() == 0:
+ frames = []
+ for i, (ali, score) in enumerate(zip(alignment, scores)):
+ if ali == 0:
prev_hyp = 0
continue
- if frame_alignment[0][i].item() != prev_hyp:
+ if ali != prev_hyp:
token_index += 1
- frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
- prev_hyp = frame_alignment[0][i].item()
+ frames.append(Frame(token_index, i, score))
+ prev_hyp = ali
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
@@ -140,52 +135,59 @@ def compute_alignments(transcript, dictionary, emission):
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
- s = 0
- segs = segments[i1 + s : i2 + s]
- word = "".join([seg.label for seg in segs])
- score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
- words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
+ segs = segments[i1:i2]
+ word = "".join([s.label for s in segs])
+ score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
+ words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
i1 = i2
else:
i2 += 1
i3 += 1
+ return segments, words
- num_frames = frame_alignment.shape[1]
- return segments, words, num_frames
+######################################################################
+#
+
+
+def plot_emission(emission):
+ fig, ax = plt.subplots()
+ ax.imshow(emission.T, aspect="auto")
+ ax.set_title("Emission")
+ fig.tight_layout()
-# utility function for plotting word alignments
-def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
- fig, ax2 = plt.subplots(figsize=(64, 12))
- plt.rcParams.update({"font.size": 30})
- # The original waveform
- ratio = waveform.size(1) / input_lengths
- ax2.plot(waveform)
- ax2.set_ylim(-1.0 * scale, 1.0 * scale)
- ax2.set_xlim(0, waveform.size(-1))
+######################################################################
+#
+
+# utility function for plotting word alignments
+def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
+ fig, ax = plt.subplots()
+ ax.specgram(waveform[0], Fs=sample_rate)
+ xlim = ax.get_xlim()
+ ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
- x0 = ratio * word.start
- x1 = ratio * word.end
- ax2.axvspan(x0, x1, alpha=0.1, color="red")
- ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
+ t0, t1 = word.start * ratio, word.end * ratio
+ ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
+ ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
- ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale))
+ ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
- xticks = ax2.get_xticks()
- plt.xticks(xticks, xticks / sample_rate, fontsize=50)
- ax2.set_xlabel("time [second]", fontsize=40)
- ax2.set_yticks([])
+ ax.set_xlabel("time [second]")
+ ax.set_xlim(xlim)
+ fig.tight_layout()
+ return IPython.display.Audio(waveform, rate=sample_rate)
+
+
+######################################################################
+#
# utility function for playing audio segments.
-# A trick to embed the resulting audio to the generated file.
-# `IPython.display.Audio` has to be the last call in a cell,
-# and there should be only one call par cell.
-def display_segment(i, waveform, word_segments, num_frames):
+def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / num_frames
word = word_segments[i]
x0 = int(ratio * word.start)
@@ -241,26 +243,21 @@ def display_segment(i, waveform, word_segments, num_frames):
)
)
model.eval()
+model.to(device)
def get_emission(waveform):
- # NOTE: this step is essential
- waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
-
- emissions, _ = model(waveform)
- emissions = torch.log_softmax(emissions, dim=-1)
- emission = emissions.cpu().detach()
-
- # Append the extra dimension corresponding to the token
- extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
- emissions = torch.cat((emissions.cpu(), extra_dim), 2)
- emission = emissions.detach()
- return emission, waveform
+ with torch.inference_mode():
+ # NOTE: this step is essential
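+ # layer_norm standardizes the waveform to zero mean and unit variance;
+ # the bundled acoustic model was trained on normalized input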
+ waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
+ emission, _ = model(waveform)
+ return torch.log_softmax(emission, dim=-1)
# Construct the dictionary
-# '@' represents the OOV token, '*' represents the token.
+# '@' represents the OOV token
# and are fairseq's legacy tokens, which are not used.
+# The '*' (star) token is omitted as we do not use it in this tutorial
dictionary = {
"": 0,
"": 1,
@@ -293,7 +290,6 @@ def get_emission(waveform):
"'": 28,
"q": 29,
"x": 30,
- "*": 31,
}
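+
+######################################################################
+# For illustration, a normalized transcript can be mapped to token indices
+# with this dictionary (roughly what ``compute_alignments`` does internally);
+# a minimal sketch:
+#
+# .. code-block:: python
+#
+#    tokens = [dictionary[c] for c in "aber seit ich".replace(" ", "")]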
@@ -304,11 +300,8 @@ def get_emission(waveform):
# romanizer and using it to obtain romanized transcripts, and Python
# commands required for further normalizing the romanized transcript.
#
-
-# %%
# .. code-block:: bash
#
-# %%bash
# Save the raw transcript to a file
# echo 'raw text' > text.txt
# git clone https://github.com/isi-nlp/uroman
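+#
+# A minimal sketch of the Python normalization step (assuming the romanized
+# output from uroman was saved to ``text_romanized.txt``; the exact commands
+# may differ):
+#
+# .. code-block:: python
+#
+#    import re
+#
+#    def normalize_uroman(text):
+#        text = text.lower()
+#        text = re.sub("([^a-z' ])", " ", text)  # keep letters, apostrophes and spaces
+#        text = re.sub(" +", " ", text)  # collapse runs of spaces
+#        return text.strip()
+#
+#    with open("text_romanized.txt", "r") as f:
+#        text_normalized = normalize_uroman(f.read())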
@@ -334,141 +327,77 @@ def get_emission(waveform):
######################################################################
-# German example:
-# ~~~~~~~~~~~~~~~~
+# German
+# ~~~~~~
-text_raw = (
- "aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
-)
-text_normalized = (
- "aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
-)
speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False)
-waveform, _ = torchaudio.load(speech_file)
-
-emission, waveform = get_emission(waveform)
-assert len(dictionary) == emission.shape[2]
-transcript = text_normalized
-
-segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
-plot_alignments(segments, word_segments, waveform, emission.shape[1])
+text_raw = "aber seit ich bei ihnen das brot hole"
+text_normalized = "aber seit ich bei ihnen das brot hole"
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
-IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
-display_segment(0, waveform, word_segments, num_frames)
+waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE))
+emission = get_emission(waveform.to(device))
+num_frames = emission.size(1)
+plot_emission(emission[0].cpu())
######################################################################
#
-display_segment(1, waveform, word_segments, num_frames)
+segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
-######################################################################
-#
-
-display_segment(2, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(3, waveform, word_segments, num_frames)
-
-
-######################################################################
-#
-
-display_segment(4, waveform, word_segments, num_frames)
+plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
-display_segment(5, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(6, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(7, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(8, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(9, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(10, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(11, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(12, waveform, word_segments, num_frames)
+display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(13, waveform, word_segments, num_frames)
+display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(14, waveform, word_segments, num_frames)
+display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(15, waveform, word_segments, num_frames)
-
-######################################################################
-#
+display_segment(3, waveform, word_segments, num_frames)
-display_segment(16, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(17, waveform, word_segments, num_frames)
+display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(18, waveform, word_segments, num_frames)
+display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(19, waveform, word_segments, num_frames)
+display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(20, waveform, word_segments, num_frames)
-
+display_segment(7, waveform, word_segments, num_frames)
######################################################################
-# Chinese example:
-# ~~~~~~~~~~~~~~~~
+# Chinese
+# ~~~~~~~
#
# Chinese is a character-based language, and there is no explicit word-level
# tokenization (separated by spaces) in its raw written form. In order to
@@ -478,98 +407,36 @@ def get_emission(waveform):
# However, this is not needed if you only want character-level alignments.
#
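+# One possible way to obtain word-tokenized, romanized transcripts (not used
+# in this tutorial) is with third-party tools such as ``pypinyin``; for
+# unsegmented text, a word tokenizer such as ``jieba`` could be applied first.
+# A minimal sketch:
+#
+# .. code-block:: python
+#
+#    from pypinyin import lazy_pinyin
+#
+#    text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
+#    text_normalized = " ".join("".join(lazy_pinyin(word)) for word in text_raw.split())
+#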
-text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
-text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False)
-waveform, _ = torchaudio.load(speech_file)
-waveform = waveform[0:1]
-
-emission, waveform = get_emission(waveform)
-
-transcript = text_normalized
-segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
-plot_alignments(segments, word_segments, waveform, emission.shape[1])
+text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
+text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
-IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
-display_segment(0, waveform, word_segments, num_frames)
-
-
-######################################################################
-#
-
-display_segment(1, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(2, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(3, waveform, word_segments, num_frames)
-
-
-######################################################################
-#
-
-display_segment(4, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(5, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(6, waveform, word_segments, num_frames)
-
-######################################################################
-#
+waveform, _ = torchaudio.load(speech_file)
+waveform = waveform[0:1]
-display_segment(7, waveform, word_segments, num_frames)
+emission = get_emission(waveform.to(device))
+num_frames = emission.size(1)
+plot_emission(emission[0].cpu())
######################################################################
#
-display_segment(8, waveform, word_segments, num_frames)
-
-
-######################################################################
-# Polish example:
-# ~~~~~~~~~~~~~~~
-
-
-text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę dlaczego mi nie powiedziałeś szepnąłem ze łzami"
-text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane dlaczego mi nie powiedziales szepnalem ze lzami"
-speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
-waveform, _ = torchaudio.load(speech_file)
-
-emission, waveform = get_emission(waveform)
-
-transcript = text_normalized
+segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
-segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
-plot_alignments(segments, word_segments, waveform, emission.shape[1])
-
-print("Raw Transcript: ", text_raw)
-print("Normalized Transcript: ", text_normalized)
-IPython.display.Audio(waveform, rate=sample_rate)
+plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
-
######################################################################
#
@@ -585,7 +452,6 @@ def get_emission(waveform):
display_segment(3, waveform, word_segments, num_frames)
-
######################################################################
#
@@ -611,68 +477,40 @@ def get_emission(waveform):
display_segment(8, waveform, word_segments, num_frames)
-######################################################################
-#
-
-display_segment(9, waveform, word_segments, num_frames)
######################################################################
-#
+# Polish
+# ~~~~~~
-display_segment(10, waveform, word_segments, num_frames)
+speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
-######################################################################
-#
+text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę"
+text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane"
-display_segment(11, waveform, word_segments, num_frames)
+print("Raw Transcript: ", text_raw)
+print("Normalized Transcript: ", text_normalized)
######################################################################
#
-display_segment(12, waveform, word_segments, num_frames)
-
-######################################################################
-#
+waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
-display_segment(13, waveform, word_segments, num_frames)
+emission = get_emission(waveform.to(device))
+num_frames = emission.size(1)
+plot_emission(emission[0].cpu())
######################################################################
#
-display_segment(14, waveform, word_segments, num_frames)
+segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
-
-######################################################################
-# Portuguese example:
-# ~~~~~~~~~~~~~~~~~~~
-
-
-text_raw = (
- "mas na imensa extensão onde se esconde o inconsciente imortal só me responde um bramido um queixume e nada mais"
-)
-text_normalized = (
- "mas na imensa extensao onde se esconde o inconsciente imortal so me responde um bramido um queixume e nada mais"
-)
-speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
-waveform, _ = torchaudio.load(speech_file)
-
-emission, waveform = get_emission(waveform)
-
-transcript = text_normalized
-
-segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
-plot_alignments(segments, word_segments, waveform, emission.shape[1])
-
-print("Raw Transcript: ", text_raw)
-print("Normalized Transcript: ", text_normalized)
-IPython.display.Audio(waveform, rate=sample_rate)
+plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
-
######################################################################
#
@@ -688,7 +526,6 @@ def get_emission(waveform):
display_segment(3, waveform, word_segments, num_frames)
-
######################################################################
#
@@ -710,94 +547,38 @@ def get_emission(waveform):
display_segment(7, waveform, word_segments, num_frames)
######################################################################
-#
-
-display_segment(8, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(9, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(10, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(11, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(12, waveform, word_segments, num_frames)
+# Portuguese
+# ~~~~~~~~~~
-######################################################################
-#
-
-display_segment(13, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(14, waveform, word_segments, num_frames)
-
-######################################################################
-#
-
-display_segment(15, waveform, word_segments, num_frames)
+speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
-######################################################################
-#
+text_raw = "na imensa extensão onde se esconde o inconsciente imortal"
+text_normalized = "na imensa extensao onde se esconde o inconsciente imortal"
-display_segment(16, waveform, word_segments, num_frames)
+print("Raw Transcript: ", text_raw)
+print("Normalized Transcript: ", text_normalized)
######################################################################
#
-display_segment(17, waveform, word_segments, num_frames)
+waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE))
-######################################################################
-#
-
-display_segment(18, waveform, word_segments, num_frames)
+emission = get_emission(waveform.to(device))
+num_frames = emission.size(1)
+plot_emission(emission[0].cpu())
######################################################################
#
-display_segment(19, waveform, word_segments, num_frames)
+segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
-
-######################################################################
-# Italian example:
-# ~~~~~~~~~~~~~~~~
-
-text_raw = "elle giacean per terra tutte quante fuor d'una ch'a seder si levò ratto ch'ella ci vide passarsi davante"
-text_normalized = (
- "elle giacean per terra tutte quante fuor d'una ch'a seder si levo ratto ch'ella ci vide passarsi davante"
-)
-speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
-waveform, _ = torchaudio.load(speech_file)
-
-emission, waveform = get_emission(waveform)
-
-transcript = text_normalized
-
-segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
-plot_alignments(segments, word_segments, waveform, emission.shape[1])
-
-print("Raw Transcript: ", text_raw)
-print("Normalized Transcript: ", text_normalized)
-IPython.display.Audio(waveform, rate=sample_rate)
+plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
-
######################################################################
#
@@ -813,7 +594,6 @@ def get_emission(waveform):
display_segment(3, waveform, word_segments, num_frames)
-
######################################################################
#
@@ -840,50 +620,62 @@ def get_emission(waveform):
display_segment(8, waveform, word_segments, num_frames)
######################################################################
-#
+# Italian
+# ~~~~~~~
+
+speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
+
+text_raw = "elle giacean per terra tutte quante"
+text_normalized = "elle giacean per terra tutte quante"
-display_segment(9, waveform, word_segments, num_frames)
+print("Raw Transcript: ", text_raw)
+print("Normalized Transcript: ", text_normalized)
######################################################################
#
-display_segment(10, waveform, word_segments, num_frames)
+waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
+
+emission = get_emission(waveform.to(device))
+num_frames = emission.size(1)
+plot_emission(emission[0].cpu())
######################################################################
#
-display_segment(11, waveform, word_segments, num_frames)
+segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
+
+plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
-display_segment(12, waveform, word_segments, num_frames)
+display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(13, waveform, word_segments, num_frames)
+display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(14, waveform, word_segments, num_frames)
+display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(15, waveform, word_segments, num_frames)
+display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(16, waveform, word_segments, num_frames)
+display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
-display_segment(17, waveform, word_segments, num_frames)
-
+display_segment(5, waveform, word_segments, num_frames)
######################################################################
# Conclusion
@@ -894,7 +686,6 @@ def get_emission(waveform):
# speech data to transcripts in five languages.
#
-
######################################################################
# Acknowledgement
# ---------------
diff --git a/examples/tutorials/forced_alignment_tutorial.py b/examples/tutorials/forced_alignment_tutorial.py
index ab98908559..fef58e2e06 100644
--- a/examples/tutorials/forced_alignment_tutorial.py
+++ b/examples/tutorials/forced_alignment_tutorial.py
@@ -56,16 +56,11 @@
# First, we import the necessary packages and fetch the data that we work on.
#
-# %matplotlib inline
-
from dataclasses import dataclass
import IPython
-import matplotlib
import matplotlib.pyplot as plt
-matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
-
torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
@@ -99,17 +94,22 @@
emission = emissions[0].cpu().detach()
+print(labels)
+
################################################################################
# Visualization
-################################################################################
-print(labels)
-plt.imshow(emission.T)
-plt.colorbar()
-plt.title("Frame-wise class probability")
-plt.xlabel("Time")
-plt.ylabel("Labels")
-plt.show()
+# ~~~~~~~~~~~~~
+
+def plot():
+ plt.imshow(emission.T)
+ plt.colorbar()
+ plt.title("Frame-wise class probability")
+ plt.xlabel("Time")
+ plt.ylabel("Labels")
+
+
+plot()
######################################################################
# Generate alignment probability (trellis)
@@ -181,12 +181,17 @@ def get_trellis(emission, tokens, blank_id=0):
################################################################################
# Visualization
-################################################################################
-plt.imshow(trellis.T, origin="lower")
-plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
-plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
-plt.colorbar()
-plt.show()
+# ~~~~~~~~~~~~~
+
+
+def plot():
+ plt.imshow(trellis.T, origin="lower")
+ plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
+ plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
+ plt.colorbar()
+
+
+plot()
######################################################################
# In the above visualization, we can see that there is a trace of high
@@ -266,7 +271,9 @@ def backtrack(trellis, emission, tokens, blank_id=0):
################################################################################
# Visualization
-################################################################################
+# ~~~~~~~~~~~~~
+
+
def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
-plt.show()
######################################################################
-# Looking good. Now this path contains repetitions for the same labels, so
+# Looking good.
+
+######################################################################
+# Segment the path
+# ----------------
+# Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript.
#
# When merging the multiple path points, we simply take the average
@@ -297,7 +308,7 @@ class Segment:
score: float
def __repr__(self):
- return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
+ return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
@@ -330,7 +341,9 @@ def merge_repeats(path):
################################################################################
# Visualization
-################################################################################
+# ~~~~~~~~~~~~~
+
+
def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
- fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
+ fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label")
- ax1.imshow(trellis_with_path.T, origin="lower")
- ax1.set_xticks([])
+ ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
for i, seg in enumerate(segments):
if seg.label != "|":
- ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold")
- ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3))
+ ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
+ ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetition")
xs, hs, ws = [], [], []
@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
- ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
+ ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black")
- ax2.set_xlim(ax1.get_xlim())
+ ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1)
+ fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript)
-plt.tight_layout()
-plt.show()
######################################################################
-# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
+# Looks good.
+
+######################################################################
+# Merge the segments into words
+# -----------------------------
+# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of
# ``'|'``.
#
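+#
+# A minimal sketch of that merging logic (the ``merge_words`` defined below
+# may differ in details):
+#
+# .. code-block:: python
+#
+#    def merge_words(segments, separator="|"):
+#        words = []
+#        i1, i2 = 0, 0
+#        while i1 < len(segments):
+#            if i2 >= len(segments) or segments[i2].label == separator:
+#                if i1 != i2:
+#                    segs = segments[i1:i2]
+#                    word = "".join([seg.label for seg in segs])
+#                    score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
+#                    words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
+#                i1 = i2 + 1
+#                i2 = i1
+#            else:
+#                i2 += 1
+#        return words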
@@ -410,16 +426,16 @@ def merge_words(segments, separator="|"):
################################################################################
# Visualization
-################################################################################
+# ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
- fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
+ fig, [ax1, ax2] = plt.subplots(2, 1)
- ax1.imshow(trellis_with_path.T, origin="lower")
+ ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
ax1.set_yticks([])
@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments):
if seg.label != "|":
- ax1.annotate(seg.label, (seg.start, i - 0.7))
- ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8)
+ ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
+ ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform
ratio = waveform.size(0) / trellis.size(0)
@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1))
+ fig.tight_layout()
plot_alignments(
@@ -458,7 +475,6 @@ def plot_alignments(trellis, segments, word_segments, waveform):
word_segments,
waveform[0],
)
-plt.show()
################################################################################
diff --git a/examples/tutorials/hybrid_demucs_tutorial.py b/examples/tutorials/hybrid_demucs_tutorial.py
index 8be6c9903b..081534bfe4 100644
--- a/examples/tutorials/hybrid_demucs_tutorial.py
+++ b/examples/tutorials/hybrid_demucs_tutorial.py
@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
- figure, axis = plt.subplots(1, 1)
- img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
- figure.suptitle(title)
- plt.colorbar(img, ax=axis)
- plt.show()
+ _, axis = plt.subplots(1, 1)
+ axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
+ axis.set_title(title)
+ plt.tight_layout()
######################################################################
@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
)
- plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}")
+ plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate)
@@ -294,7 +293,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
#
# Mixture Clip
-plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture")
+plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)
######################################################################
diff --git a/examples/tutorials/mvdr_tutorial.py b/examples/tutorials/mvdr_tutorial.py
index 7c9013d180..442f6234a6 100644
--- a/examples/tutorials/mvdr_tutorial.py
+++ b/examples/tutorials/mvdr_tutorial.py
@@ -98,23 +98,21 @@
#
-def plot_spectrogram(stft, title="Spectrogram", xlim=None):
+def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
- figure.suptitle(title)
+ axis.set_title(title)
plt.colorbar(img, ax=axis)
- plt.show()
-def plot_mask(mask, title="Mask", xlim=None):
+def plot_mask(mask, title="Mask"):
mask = mask.numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
- figure.suptitle(title)
+ axis.set_title(title)
plt.colorbar(img, ax=axis)
- plt.show()
def si_snr(estimate, reference, epsilon=1e-8):
diff --git a/examples/tutorials/nvdec_tutorial.py b/examples/tutorials/nvdec_tutorial.py
index efeca53975..459b690a3f 100644
--- a/examples/tutorials/nvdec_tutorial.py
+++ b/examples/tutorials/nvdec_tutorial.py
@@ -33,12 +33,9 @@
import os
import time
-import matplotlib
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader
-matplotlib.rcParams["image.interpolation"] = "none"
-
######################################################################
#
# Check the prerequisites
diff --git a/examples/tutorials/speech_recognition_pipeline_tutorial.py b/examples/tutorials/speech_recognition_pipeline_tutorial.py
index 79bbae14c2..2d815a2e8e 100644
--- a/examples/tutorials/speech_recognition_pipeline_tutorial.py
+++ b/examples/tutorials/speech_recognition_pipeline_tutorial.py
@@ -160,8 +160,7 @@
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
-plt.tight_layout()
-plt.show()
+fig.tight_layout()
######################################################################
@@ -190,7 +189,7 @@
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
-plt.show()
+plt.tight_layout()
print("Class labels:", bundle.get_labels())
diff --git a/examples/tutorials/squim_tutorial.py b/examples/tutorials/squim_tutorial.py
index 640e2e79b8..9b9b55ac2e 100644
--- a/examples/tutorials/squim_tutorial.py
+++ b/examples/tutorials/squim_tutorial.py
@@ -82,19 +82,23 @@
from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError:
- import google.colab # noqa: F401
-
- print(
- """
- To enable running this notebook in Google Colab, install nightly
- torch and torchaudio builds by adding the following code block to the top
- of the notebook before running it:
- !pip3 uninstall -y torch torchvision torchaudio
- !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
- !pip3 install pesq
- !pip3 install pystoi
- """
- )
+ try:
+ import google.colab # noqa: F401
+
+ print(
+ """
+ To enable running this notebook in Google Colab, install nightly
+ torch and torchaudio builds by adding the following code block to the top
+ of the notebook before running it:
+ !pip3 uninstall -y torch torchvision torchaudio
+ !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
+ !pip3 install pesq
+ !pip3 install pystoi
+ """
+ )
+ except Exception:
+ pass
+ raise
import matplotlib.pyplot as plt
@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item()
-def plot_waveform(waveform, title):
+def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy()
sample_size = waveform.shape[1]
- time_axis = torch.arange(0, sample_size) / 16000
-
- figure, axes = plt.subplots(1, 1)
- axes = figure.gca()
- axes.plot(time_axis, wav_numpy[0], linewidth=1)
- axes.grid(True)
- figure.suptitle(title)
- plt.show(block=False)
+ time_axis = torch.arange(0, sample_size) / sample_rate
-
-def plot_specgram(waveform, sample_rate, title):
- wav_numpy = waveform.numpy()
- figure, axes = plt.subplots(1, 1)
- axes = figure.gca()
- axes.specgram(wav_numpy[0], Fs=sample_rate)
+ figure, axes = plt.subplots(2, 1)
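+ # waveform in the top panel, spectrogram in the bottom panel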
+ axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
+ axes[0].grid(True)
+ axes[1].specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title)
- plt.show(block=False)
######################################################################
@@ -238,32 +232,28 @@ def plot_specgram(waveform, sample_rate, title):
# Visualize speech sample
#
-plot_waveform(WAVEFORM_SPEECH, "Clean Speech")
-plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
+plot(WAVEFORM_SPEECH, "Clean Speech")
######################################################################
# Visualize noise sample
#
-plot_waveform(WAVEFORM_NOISE, "Noise")
-plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
+plot(WAVEFORM_NOISE, "Noise")
######################################################################
# Visualize distorted speech with 20dB SNR
#
-plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
-plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
+plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
######################################################################
# Visualize distorted speech with -5dB SNR
#
-plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
-plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
+plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
######################################################################
diff --git a/examples/tutorials/streamreader_advanced_tutorial.py b/examples/tutorials/streamreader_advanced_tutorial.py
index 84f113cc16..7d3a4bd09e 100644
--- a/examples/tutorials/streamreader_advanced_tutorial.py
+++ b/examples/tutorials/streamreader_advanced_tutorial.py
@@ -355,13 +355,14 @@
def _display(i):
print("filter_desc:", streamer.get_out_stream_info(i).filter_description)
- _, axs = plt.subplots(2, 1)
+ fig, axs = plt.subplots(2, 1)
waveform = chunks[i][:, 0]
axs[0].plot(waveform)
axs[0].grid(True)
axs[0].set_ylim([-1, 1])
plt.setp(axs[0].get_xticklabels(), visible=False)
axs[1].specgram(waveform, Fs=sample_rate)
+ fig.tight_layout()
return IPython.display.Audio(chunks[i].T, rate=sample_rate)
@@ -440,7 +441,6 @@ def _display(i):
axs[j].imshow(chunk[10 * j + 1].permute(1, 2, 0))
axs[j].set_axis_off()
plt.tight_layout()
- plt.show(block=False)
######################################################################
diff --git a/examples/tutorials/streamreader_basic_tutorial.py b/examples/tutorials/streamreader_basic_tutorial.py
index 29ba36aabf..ce94088c79 100644
--- a/examples/tutorials/streamreader_basic_tutorial.py
+++ b/examples/tutorials/streamreader_basic_tutorial.py
@@ -592,7 +592,6 @@
if i == 0 and j == 0:
ax.set_ylabel("Stream 2")
plt.tight_layout()
-plt.show(block=False)
######################################################################
#
diff --git a/examples/tutorials/tacotron2_pipeline_tutorial.py b/examples/tutorials/tacotron2_pipeline_tutorial.py
index 586cdb7d09..00687166e9 100644
--- a/examples/tutorials/tacotron2_pipeline_tutorial.py
+++ b/examples/tutorials/tacotron2_pipeline_tutorial.py
@@ -7,10 +7,6 @@
"""
-import IPython
-import matplotlib
-import matplotlib.pyplot as plt
-
######################################################################
# Overview
# --------
@@ -65,8 +61,6 @@
import torch
import torchaudio
-matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
-
torch.random.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -75,6 +69,13 @@
print(device)
+######################################################################
+#
+
+import IPython
+import matplotlib.pyplot as plt
+
+
######################################################################
# Text Processing
# ---------------
@@ -226,13 +227,17 @@ def text_to_sequence(text):
# therefore, the process of generating the spectrogram incurs randomness.
#
-fig, ax = plt.subplots(3, 1, figsize=(16, 4.3 * 3))
-for i in range(3):
- with torch.inference_mode():
- spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
- print(spec[0].shape)
- ax[i].imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
-plt.show()
+
+def plot():
+ fig, ax = plt.subplots(3, 1)
+ for i in range(3):
+ with torch.inference_mode():
+ spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
+ print(spec[0].shape)
+ ax[i].imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
+
+
+plot()
######################################################################
@@ -270,11 +275,22 @@ def text_to_sequence(text):
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)
-fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
-ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
-ax2.plot(waveforms[0].cpu().detach())
+######################################################################
+#
+
+
+def plot(waveforms, spec, sample_rate):
+ waveforms = waveforms.cpu().detach()
-IPython.display.Audio(waveforms[0:1].cpu(), rate=vocoder.sample_rate)
+ fig, [ax1, ax2] = plt.subplots(2, 1)
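+ # waveform in the top panel, generated mel-spectrogram in the bottom panel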
+ ax1.plot(waveforms[0])
+ ax1.set_xlim(0, waveforms.size(-1))
+ ax1.grid(True)
+ ax2.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
+ return IPython.display.Audio(waveforms[0:1], rate=sample_rate)
+
+
+plot(waveforms, spec, vocoder.sample_rate)
######################################################################
@@ -300,11 +316,10 @@ def text_to_sequence(text):
spec, spec_lengths, _ = tacotron2.infer(processed, lengths)
waveforms, lengths = vocoder(spec, spec_lengths)
-fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
-ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
-ax2.plot(waveforms[0].cpu().detach())
+######################################################################
+#
-IPython.display.Audio(waveforms[0:1].cpu(), rate=vocoder.sample_rate)
+plot(waveforms, spec, vocoder.sample_rate)
######################################################################
@@ -339,8 +354,7 @@ def text_to_sequence(text):
with torch.no_grad():
waveforms = waveglow.infer(spec)
-fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9))
-ax1.imshow(spec[0].cpu().detach(), origin="lower", aspect="auto")
-ax2.plot(waveforms[0].cpu().detach())
+######################################################################
+#
-IPython.display.Audio(waveforms[0:1].cpu(), rate=22050)
+plot(waveforms, spec, 22050)