diff --git a/requirements.txt b/requirements.txt
index 8d360ad..bcea830 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,4 @@
 numpy
 numba
 torch>=2
-torchaudio
 torchlpc
diff --git a/setup.py b/setup.py
index 1ac0022..e2e3dca 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
     long_description_content_type="text/markdown",
     url="https://github.com/yoyololicon/torchcomp",
     packages=["torchcomp"],
-    install_requires=["torch>=2", "torchaudio", "torchlpc", "numpy", "numba"],
+    install_requires=["torch>=2", "torchlpc", "numpy", "numba"],
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
diff --git a/torchcomp/__init__.py b/torchcomp/__init__.py
index cbd6093..e7e216c 100644
--- a/torchcomp/__init__.py
+++ b/torchcomp/__init__.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from typing import Union
-from torchaudio.functional import lfilter
+from torchlpc import sample_wise_lpc
 
 from .core import compressor_core
 
@@ -88,11 +88,9 @@ def avg(rms: torch.Tensor, avg_coef: Union[torch.Tensor, float]):
     ).broadcast_to(rms.shape[0])
 
     assert torch.all(avg_coef > 0) and torch.all(avg_coef <= 1)
-    return lfilter(
-        rms,
-        torch.stack([torch.ones_like(avg_coef), avg_coef - 1], 1),
-        torch.stack([avg_coef, torch.zeros_like(avg_coef)], 1),
-        False,
+    return sample_wise_lpc(
+        rms * avg_coef,
+        avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
     )
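
Both the removed lfilter call and the new sample_wise_lpc call realize the same one-pole smoother y[n] = avg_coef * x[n] + (1 - avg_coef) * y[n-1]; the change only swaps torchaudio's IIR filter for torchlpc's all-pole recursion (y[n] = x[n] - sum_i a[n, i] * y[n-i]) with coefficient avg_coef - 1 and input rms * avg_coef. Below is a minimal sketch of that equivalence check; the batch size, length, and test values are illustrative, and it assumes torchlpc is installed with the convention above.

# Sketch: verify the diff's sample_wise_lpc formulation against a direct
# evaluation of the one-pole recursion it is meant to implement.
import torch
from torchlpc import sample_wise_lpc

B, T = 2, 16  # illustrative batch size and signal length
avg_coef = torch.full((B,), 0.3, dtype=torch.double)
rms = torch.rand(B, T, dtype=torch.double)

# Reference: y[n] = avg_coef * rms[n] + (1 - avg_coef) * y[n-1], y[-1] = 0.
ref = torch.zeros_like(rms)
prev = torch.zeros(B, dtype=torch.double)
for n in range(T):
    prev = avg_coef * rms[:, n] + (1 - avg_coef) * prev
    ref[:, n] = prev

# Formulation from the diff: all-pole filter on avg_coef * rms with a
# (time-constant) first-order coefficient avg_coef - 1 per batch element.
out = sample_wise_lpc(
    rms * avg_coef,
    avg_coef[:, None, None].broadcast_to(rms.shape + (1,)) - 1,
)

print(torch.allclose(out, ref))  # expected: True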