Skip to content

Commit

Permalink
Extend torchaudio support to 2.1.x (#3)
Browse files Browse the repository at this point in the history
* Update README.md

* Update README.md

* Update README.md

* Update README.md

* minor fixes for 4.1.0a1 (facebookresearch#552)

* minor fixes for 4.1.0a1

print out the exception when calling callback

ensures all threads can be stopped when interrupting separation

add release date for 4.0.1

* Fix model_idx_in_bag always zero

* fix linter

* Fix can't separate empty audio

* Calls callback when skipping empty audio

* Add description for aborting

* Does not ignore callback exception

* Fix linter

* Does not ignore exception

* Disable torchaudio 2.2+

* Uses epsilon to deal with empty audio

* Reraises exception in callback

* Ensure the pool stops when encountering exception

* Update windows.md for latest instructions

* Minor documentation updates (facebookresearch#565)

* Minor documentation updates

* Update readme

* Update api.md

* Fix segment defined in bag can't override model

* merge from adefossez/demucs

* Update README.md

* Extend torchaudio support to 2.1.x

* Use correct import statement

* Calculate FFT on CPU also when device is XPU (Intel GPU)

---------

Co-authored-by: Alexandre Défossez <[email protected]>
Co-authored-by: William Dye <[email protected]>
  • Loading branch information
3 people authored Jan 12, 2024
1 parent a4decbd commit b3398e4
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 14 deletions.
1 change: 1 addition & 0 deletions demucs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import subprocess

from . import audio_legacy
import torch as th
import torchaudio as ta

Expand Down
1 change: 1 addition & 0 deletions demucs/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lameenc
import julius
import numpy as np
from . import audio_legacy
import torch
import torchaudio as ta
import typing as tp
Expand Down
17 changes: 17 additions & 0 deletions demucs/audio_legacy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# This file is to extend support for torchaudio 2.1

import importlib
import os
import sys
import warnings

def _major_minor(version):
    """Return the leading ``(major, minor)`` numbers of a version string.

    Only the first two dot-separated components are considered, and only the
    leading digits of each one, so local/pre-release suffixes such as
    ``"2.1.0+cu118"`` or ``"2.1rc1"`` are tolerated.  Missing or non-numeric
    components count as ``0``.

    A numeric tuple is used instead of comparing the raw strings because a
    lexicographic comparison mis-orders multi-digit components
    (e.g. ``"10.0" < "2.1"`` as strings).
    """
    numbers = []
    for piece in str(version).split(".")[:2]:
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers)


# Importing this module has side effects: it forces
# TORCHAUDIO_USE_BACKEND_DISPATCHER to "0" (presumably read by torchaudio when
# it is imported -- confirm against the torchaudio docs), which is why the
# other demucs modules import it *before* importing torchaudio.
if "torchaudio" not in sys.modules:
    # torchaudio has not been imported yet: setting the variable is enough.
    os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0"
elif os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1":
    # torchaudio was already imported with the variable unset or set to "1".
    # For torchaudio >= 2.1, set it to "0" and reload the module so the new
    # value is taken into account.
    if _major_minor(sys.modules["torchaudio"].__version__) >= (2, 1):
        os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0"
        importlib.reload(sys.modules["torchaudio"])
        warnings.warn(
            "TORCHAUDIO_USE_BACKEND_DISPATCHER is set to 0 and torchaudio is reloaded.",
            ImportWarning,
        )
10 changes: 6 additions & 4 deletions demucs/hdemucs.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,16 +776,18 @@ def forward(self, mix):
# demucs issue #435 ##432
# NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps:
x_is_mps_xpu = x.device.type in ["mps", "xpu"]
x_device = x.device
if x_is_mps_xpu:
x = x.cpu()

zout = self._mask(z, x)
x = self._ispec(zout, length)

# back to the original device (mps or xpu)
if x_is_mps:
x = x.to('mps')
if x_is_mps_xpu:
x = x.to(x_device)


if self.hybrid:
xt = xt.view(B, S, -1, length)
Expand Down
9 changes: 5 additions & 4 deletions demucs/htdemucs.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,9 @@ def forward(self, mix):
# demucs issue #435 ##432
# NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps:
x_is_mps_xpu = x.device.type in ["mps", "xpu"]
x_device = x.device
if x_is_mps_xpu:
x = x.cpu()

zout = self._mask(z, x)
Expand All @@ -643,8 +644,8 @@ def forward(self, mix):
x = self._ispec(zout, length)

# back to the original device (mps or xpu)
if x_is_mps:
x = x.to("mps")
if x_is_mps_xpu:
x = x.to(x_device)

if self.use_train_segment:
if self.training:
Expand Down
1 change: 1 addition & 0 deletions demucs/repitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import subprocess as sp
import tempfile

from . import audio_legacy
import torch
import torchaudio as ta

Expand Down
8 changes: 4 additions & 4 deletions demucs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
def spectro(x, n_fft=512, hop_length=None, pad=0):
*other, length = x.shape
x = x.reshape(-1, length)
is_mps = x.device.type == 'mps'
if is_mps:
is_mps_xpu = x.device.type in ['mps', 'xpu']
if is_mps_xpu:
x = x.cpu()
z = th.stft(x,
n_fft * (1 + pad),
Expand All @@ -32,8 +32,8 @@ def ispectro(z, hop_length=None, length=None, pad=0):
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
is_mps = z.device.type == 'mps'
if is_mps:
is_mps_xpu = z.device.type in ['mps', 'xpu']
if is_mps_xpu:
z = z.cpu()
x = th.istft(z,
n_fft,
Expand Down
1 change: 1 addition & 0 deletions demucs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hydra
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
from . import audio_legacy
import torch
from torch import nn
import torchaudio
Expand Down
1 change: 1 addition & 0 deletions demucs/wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import musdb
import julius
from . import audio_legacy
import torch as th
from torch import distributed
import torchaudio as ta
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ openunmix
pyyaml
submitit
torch>=1.8.1
torchaudio>=0.8,<2.1
torchaudio>=0.8,<2.2
tqdm
treetable
soundfile>=0.10.3;sys_platform=="win32"
2 changes: 1 addition & 1 deletion requirements_minimal.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ lameenc>=1.2
openunmix
pyyaml
torch>=1.8.1
torchaudio>=0.8,<2.1
torchaudio>=0.8,<2.2
tqdm

0 comments on commit b3398e4

Please sign in to comment.