Set and tweak global matplotlib configuration in tutorials #3515

Closed · wants to merge 1 commit
17 changes: 17 additions & 0 deletions docs/source/conf.py
@@ -127,6 +127,22 @@ def _get_pattern():
     return ret


+def reset_mpl(gallery_conf, fname):
+    from sphinx_gallery.scrapers import _reset_matplotlib
+
+    _reset_matplotlib(gallery_conf, fname)
+    import matplotlib
+
+    matplotlib.rcParams.update(
+        {
+            "image.interpolation": "none",
+            "figure.figsize": (9.6, 4.8),
+            "font.size": 8.0,
+            "axes.axisbelow": True,
+        }
+    )
+
+
 sphinx_gallery_conf = {
     "examples_dirs": [
         "../../examples/tutorials",
@@ -139,6 +155,7 @@ def _get_pattern():
     "promote_jupyter_magic": True,
     "first_notebook_cell": None,
     "doc_module": ("torchaudio",),
+    "reset_modules": (reset_mpl, "seaborn"),
 }
 autosummary_generate = True
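
For context, sphinx-gallery's `reset_modules` option accepts both strings naming built-in resetters and callables that are invoked with `(gallery_conf, fname)` around each example. A rough sketch of that dispatch follows; `_BUILTIN_RESETTERS` and `apply_resetters` are illustrative names rather than sphinx-gallery internals, and `_reset_seaborn` is assumed to mirror `_reset_matplotlib`.

# Sketch only: not part of this PR, and not sphinx-gallery internals.
from sphinx_gallery.scrapers import _reset_matplotlib, _reset_seaborn

_BUILTIN_RESETTERS = {
    "matplotlib": _reset_matplotlib,
    "seaborn": _reset_seaborn,  # assumed helper, mirroring _reset_matplotlib
}


def apply_resetters(reset_modules, gallery_conf, fname):
    for entry in reset_modules:
        # Strings select a built-in resetter; callables such as reset_mpl
        # above are used as-is and receive the same two arguments.
        resetter = _BUILTIN_RESETTERS[entry] if isinstance(entry, str) else entry
        resetter(gallery_conf, fname)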
9 changes: 8 additions & 1 deletion docs/source/index.rst
@@ -71,8 +71,8 @@ model implementations and application components.
    tutorials/online_asr_tutorial
    tutorials/device_asr
    tutorials/device_avsr
-   tutorials/forced_alignment_for_multilingual_data_tutorial
    tutorials/forced_alignment_tutorial
+   tutorials/forced_alignment_for_multilingual_data_tutorial
    tutorials/tacotron2_pipeline_tutorial
    tutorials/mvdr_tutorial
    tutorials/hybrid_demucs_tutorial
@@ -147,6 +147,13 @@ Tutorials

 .. customcardstart::

+.. customcarditem::
+   :header: On device audio-visual automatic speech recognition
+   :card_description: Learn how to stream audio and video from laptop webcam and perform audio-visual automatic speech recognition using Emformer-RNNT model.
+   :image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
+   :link: tutorials/device_avsr.html
+   :tags: I/O,Pipelines,RNNT
+
 .. customcarditem::
    :header: Loading waveform Tensors from files and saving them
    :card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
16 changes: 8 additions & 8 deletions examples/tutorials/additive_synthesis_tutorial.py
@@ -85,7 +85,7 @@
 #


-def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
+def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
     t = torch.arange(waveform.size(0)) / sample_rate

     fig, axes = plt.subplots(4, 1, sharex=True)
@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
     for i in range(4):
         axes[i].grid(True)
     pos = axes[2].get_position()
-    plt.tight_layout()
+    fig.tight_layout()

     if zoom is not None:
         ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
 freq0 = torch.full((NUM_FRAMES, 1), F0)
 amp0 = torch.ones((NUM_FRAMES, 1))
 freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 #
@@ -183,7 +183,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
 freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)

 freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Square wave
@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
 freq0 = torch.full((NUM_FRAMES, 1), F0)
 amp0 = torch.ones((NUM_FRAMES, 1))
 freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Triangle wave
@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
 #

 freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Inharmonic Partials
@@ -296,7 +296,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):

 waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)

-show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
+plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)

 ######################################################################
 #
@@ -308,7 +308,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
 freq = extend_pitch(freq0, num_tones)
 waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)

-show(freq, amp, waveform, SAMPLE_RATE)
+plot(freq, amp, waveform, SAMPLE_RATE)

 ######################################################################
 # References
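
One nuance behind the `plt.tight_layout()` to `fig.tight_layout()` change: pyplot functions act on whatever figure is "current", which is fragile when a docs build keeps several figures alive at once, while the `Figure` method always targets one specific object. A small illustration, not taken from the tutorial:

# Illustration only: pyplot state vs. the object-oriented Figure API.
import matplotlib.pyplot as plt

fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()  # fig2 is now pyplot's "current" figure

ax1.plot([0, 1], [0, 1])
plt.tight_layout()   # acts on fig2, the current figure
fig1.tight_layout()  # unambiguous: always lays out fig1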
63 changes: 39 additions & 24 deletions examples/tutorials/asr_inference_with_ctc_decoder_tutorial.py
@@ -407,30 +407,45 @@ def forward(self, emission: torch.Tensor) -> List[str]:
 #


-def plot_alignments(waveform, emission, tokens, timesteps):
-    fig, ax = plt.subplots(figsize=(32, 10))
-
-    ax.plot(waveform)
-
-    ratio = waveform.shape[0] / emission.shape[1]
-    word_start = 0
-
-    for i in range(len(tokens)):
-        if i != 0 and tokens[i - 1] == "|":
-            word_start = timesteps[i]
-        if tokens[i] != "|":
-            plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
-        elif i != 0:
-            word_end = timesteps[i]
-            ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
-
-    xticks = ax.get_xticks()
-    plt.xticks(xticks, xticks / bundle.sample_rate)
-    ax.set_xlabel("time (sec)")
-    ax.set_xlim(0, waveform.shape[0])
-
-
-plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
+def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
+
+    t = torch.arange(waveform.size(0)) / sample_rate
+    ratio = waveform.size(0) / emission.size(1) / sample_rate
+
+    chars = []
+    words = []
+    word_start = None
+    for token, timestep in zip(tokens, timesteps * ratio):
+        if token == "|":
+            if word_start is not None:
+                words.append((word_start, timestep))
+            word_start = None
+        else:
+            chars.append((token, timestep))
+            if word_start is None:
+                word_start = timestep
+
+    fig, axes = plt.subplots(3, 1)
+
+    def _plot(ax, xlim):
+        ax.plot(t, waveform)
+        for token, timestep in chars:
+            ax.annotate(token.upper(), (timestep, 0.5))
+        for word_start, word_end in words:
+            ax.axvspan(word_start, word_end, alpha=0.1, color="red")
+        ax.set_ylim(-0.6, 0.7)
+        ax.set_yticks([0])
+        ax.grid(True, axis="y")
+        ax.set_xlim(xlim)
+
+    _plot(axes[0], (0.3, 2.5))
+    _plot(axes[1], (2.5, 4.7))
+    _plot(axes[2], (4.7, 6.9))
+    axes[2].set_xlabel("time (sec)")
+    fig.tight_layout()
+
+
+plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)


 ######################################################################
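
As a sanity check on the time-axis conversion in the new `plot_alignments`, here is the arithmetic with hypothetical shapes (not values from this tutorial):

# Hypothetical numbers, just to make the units of `ratio` concrete.
sample_rate = 16000       # Hz
waveform_len = 96000      # samples, i.e. 6.0 seconds of audio
num_emission_frames = 300

ratio = waveform_len / num_emission_frames / sample_rate
print(ratio)        # 0.02: each emission frame spans 20 ms
print(150 * ratio)  # 3.0: emission timestep 150 lands at 3.0 s on the plot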
2 changes: 0 additions & 2 deletions examples/tutorials/audio_data_augmentation_tutorial.py
@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
         if xlim:
             axes[c].set_xlim(xlim)
     figure.suptitle(title)
-    plt.show(block=False)


 ######################################################################
@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
         if xlim:
             axes[c].set_xlim(xlim)
     figure.suptitle(title)
-    plt.show(block=False)


 ######################################################################
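
Dropping `plt.show(block=False)` is safe in these tutorials because sphinx-gallery's matplotlib scraper collects every figure left open after a code cell runs; creating the figure is enough for it to appear in the built page. A minimal sketch of the pattern, assuming a sphinx-gallery build:

# Minimal sketch: no explicit show() needed under sphinx-gallery.
import matplotlib.pyplot as plt
import torch

waveform = torch.sin(torch.linspace(0, 6.28, 1000))
figure, ax = plt.subplots()
ax.plot(waveform)  # the still-open figure is scraped into the docs
figure.suptitle("Waveform")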
72 changes: 26 additions & 46 deletions examples/tutorials/audio_datasets_tutorial.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 """
 Audio Datasets
 ==============
@@ -10,79 +9,60 @@
 available datasets.
 """

-# When running this tutorial in Google Colab, install the required packages
-# with the following.
-# !pip install torchaudio
-
 import torch
 import torchaudio

 print(torch.__version__)
 print(torchaudio.__version__)

 ######################################################################
-# Preparing data and utility functions (skip this section)
-# --------------------------------------------------------
 #

-# @title Prepare data and utility functions. {display-mode: "form"}
-# @markdown
-# @markdown You do not need to look into this cell.
-# @markdown Just execute once and you are good to go.
-
-# -------------------------------------------------------------------------------
-# Preparation of data and helper functions.
-# -------------------------------------------------------------------------------
 import os

+import IPython
+
 import matplotlib.pyplot as plt
-from IPython.display import Audio, display


 _SAMPLE_DIR = "_assets"
 YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
 os.makedirs(YESNO_DATASET_PATH, exist_ok=True)


-def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
+def plot_specgram(waveform, sample_rate, title="Spectrogram"):
     waveform = waveform.numpy()

-    num_channels, _ = waveform.shape
-
-    figure, axes = plt.subplots(num_channels, 1)
-    if num_channels == 1:
-        axes = [axes]
-    for c in range(num_channels):
-        axes[c].specgram(waveform[c], Fs=sample_rate)
-        if num_channels > 1:
-            axes[c].set_ylabel(f"Channel {c+1}")
-        if xlim:
-            axes[c].set_xlim(xlim)
+    figure, ax = plt.subplots()
+    ax.specgram(waveform[0], Fs=sample_rate)
     figure.suptitle(title)
-    plt.show(block=False)
-
-
-def play_audio(waveform, sample_rate):
-    waveform = waveform.numpy()
-
-    num_channels, _ = waveform.shape
-    if num_channels == 1:
-        display(Audio(waveform[0], rate=sample_rate))
-    elif num_channels == 2:
-        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
-    else:
-        raise ValueError("Waveform with more than 2 channels are not supported.")
+    figure.tight_layout()


 ######################################################################
 # Here, we show how to use the
 # :py:class:`torchaudio.datasets.YESNO` dataset.
 #

-
 dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)

-for i in [1, 3, 5]:
-    waveform, sample_rate, label = dataset[i]
-    plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
-    play_audio(waveform, sample_rate)
+######################################################################
+#
+i = 1
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
+
+######################################################################
+#
+i = 3
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
+
+######################################################################
+#
+i = 5
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
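
The per-sample cells above end with `IPython.display.Audio(...)` as the last expression so that Jupyter renders an inline audio player; in a plain Python script nothing is shown that way. A small usage sketch, assuming the `dataset` object from above:

# Sketch: explicit display for non-notebook contexts.
import IPython

waveform, sample_rate, label = dataset[1]
audio = IPython.display.Audio(waveform, rate=sample_rate)
IPython.display.display(audio)  # renders a player in Jupyter frontends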