Skip to content

Commit

Permalink
feat: support audio formats (#29)
Browse files Browse the repository at this point in the history
* mp3, wav, flac, ogg support
  • Loading branch information
rilshok authored Nov 28, 2024
2 parents 3c4303f + 067f290 commit 2118058
Show file tree
Hide file tree
Showing 21 changed files with 262 additions and 39 deletions.
15 changes: 8 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
default_stages:
- commit

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v5.0.0
hooks:
- id: check-added-large-files
args: [--maxkb=5000]
args: [--maxkb=64]
- id: trailing-whitespace
- id: check-json
- id: check-merge-conflict
- id: check-xml
- id: check-yaml
- id: detect-private-key
- id: mixed-line-ending
- id: check-json
- id: pretty-format-json
args: [--autofix]
exclude: \.ipynb$
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.10
rev: v0.7.3
hooks:
- id: ruff
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format
types_or: [ python, pyi ]


- repo: https://github.com/jendrikseipp/vulture
rev: v2.6
rev: v2.13
hooks:
- id: vulture
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include src/iokit/py.typed
17 changes: 16 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ dependencies = [
"requests>=2.32.3",
"typing-extensions>=4.8.0",
"cryptography>=41.0.7",
"numpy>=1.21.1",
"soundfile>=0.12.1",
]

[tool.setuptools.dynamic]
Expand All @@ -29,6 +31,9 @@ lint = [
"mypy>=1.7.1",
"ruff>=0.6.3",
"types-pytz>=2024.1",
"types-requests>=2.31.0",
"types-PyYAML>=6.0.12",
"types-python-dateutil>=2.8.19",
]
test = [
"pytest>=8.2.2",
Expand All @@ -49,9 +54,19 @@ line-length = 100
[tool.ruff.format]
docstring-code-format = true

[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"SIM",# flake8-simplify
"I", # isort
]

[tool.vulture]
make_whitelist = true
sort_by_size = true
verbose = true
verbose = false
min_confidence = 100
paths = ["src/iokit"]
25 changes: 23 additions & 2 deletions src/iokit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
"Txt",
"Yaml",
"Zip",
"Mp3",
"Wav",
"Waveform",
"Flac",
"Ogg",
"State",
"filter_states",
"find_state",
Expand All @@ -17,8 +22,24 @@
"save_file",
"save_temp",
]
__version__ = "0.1.9"
__version__ = "0.2.0"

from .extensions import Dat, Encryption, Env, Gzip, Json, Jsonl, Tar, Txt, Yaml, Zip
from .extensions import (
Dat,
Encryption,
Env,
Flac,
Gzip,
Json,
Jsonl,
Mp3,
Ogg,
Tar,
Txt,
Wav,
Waveform,
Yaml,
Zip,
)
from .state import State, filter_states, find_state
from .storage import download_file, load_file, save_file, save_temp
6 changes: 6 additions & 0 deletions src/iokit/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
__all__ = [
"Flac",
"Mp3",
"Ogg",
"Wav",
"Waveform",
"Dat",
"Encryption",
"Env",
Expand All @@ -11,6 +16,7 @@
"Zip",
]

from .audio import Flac, Mp3, Ogg, Wav, Waveform
from .dat import Dat
from .enc import Encryption
from .env import Env
Expand Down
106 changes: 106 additions & 0 deletions src/iokit/extensions/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from dataclasses import dataclass
from io import BytesIO
from typing import Any

import soundfile
from numpy import float32
from numpy.typing import NDArray

from iokit.state import State


class AudioState(State, suffix=""):
    """Base state for audio payloads that are encoded and decoded via ``soundfile``.

    The encoding format passed to ``soundfile`` is read from ``self._suffix``
    (presumably populated by ``State`` from the subclass's ``suffix`` argument;
    verify against ``State.__init_subclass__``).
    """

    def __init__(self, waveform: "Waveform", **kwargs: Any):
        """Encode ``waveform`` into bytes and initialise the underlying state.

        Args:
            waveform: Samples (``waveform.wave``) and sampling rate
                (``waveform.freq``) to encode.
            **kwargs: Forwarded unchanged to ``State.__init__``.
        """
        # Encode into an in-memory buffer so nothing touches the filesystem.
        buffer = BytesIO()
        soundfile.write(
            file=buffer,
            data=waveform.wave,
            samplerate=waveform.freq,
            format=self._suffix,
        )
        encoded = buffer.getvalue()
        super().__init__(data=encoded, **kwargs)

    def load(self) -> "Waveform":
        """Decode the stored audio and return it as a ``Waveform``.

        ``always_2d=True`` makes ``soundfile`` return a two-dimensional sample
        array even for single-channel audio.
        """
        samples, rate = soundfile.read(self.data, always_2d=True)
        return Waveform(wave=samples, freq=rate)


class Flac(AudioState, suffix="flac"):
    """Audio state registered under the ``flac`` suffix."""


class Wav(AudioState, suffix="wav"):
    """Audio state registered under the ``wav`` suffix."""


class Mp3(AudioState, suffix="mp3"):
    """Audio state registered under the ``mp3`` suffix."""


class Ogg(AudioState, suffix="ogg"):
    """Audio state registered under the ``ogg`` suffix."""


@dataclass
class Waveform:
    """In-memory audio signal: a sample matrix together with its sampling rate.

    Attributes:
        wave: Samples shaped ``(frames, channels)``. A 1-D input is treated as
            mono and reshaped to ``(frames, 1)``. Always stored as ``float32``;
            an input that is already ``float32`` is kept as-is (no copy), any
            other dtype is converted.
        freq: Sampling rate in frames per second, used to convert between
            times in seconds and frame indices.
    """

    wave: NDArray[float32]
    freq: int

    def __post_init__(self) -> None:
        """Normalise ``wave`` to a 2-D ``float32`` array and validate its shape.

        Raises:
            ValueError: If ``wave`` is neither 1-D nor 2-D, or if it does not
                have strictly more frames than channels.
        """
        if self.wave.ndim == 1:
            # Mono signal given as a flat vector: add the channel axis.
            self.wave = self.wave[:, None]
        if self.wave.ndim != 2:
            msg = f"Waveform must be 1D or 2D array, but got {self.wave.ndim}D"
            raise ValueError(msg)
        if self.wave.shape[1] >= self.wave.shape[0]:
            # Guards against a transposed (channels, frames) layout.
            msg = (
                "Waveform must have more frames than channels,"
                f" but got {self.wave.shape[0]} frames and {self.wave.shape[1]} channels."
            )
            raise ValueError(msg)
        # ``dtype`` is a ``numpy.dtype`` instance and is never *identical* to the
        # scalar type ``float32``; compare by equality so that data which is
        # already float32 is not needlessly copied.
        if self.wave.dtype != float32:
            self.wave = self.wave.astype(float32)

    @property
    def channels(self) -> int:
        """Number of audio channels (second axis of ``wave``)."""
        return self.wave.shape[1]

    @property
    def duration(self) -> float:
        """Length of the signal in seconds (frames divided by ``freq``)."""
        return self.wave.shape[0] / self.freq

    def copy(self) -> "Waveform":
        """Return a new ``Waveform`` backed by an independent copy of the samples."""
        return Waveform(self.wave.copy(), self.freq)

    def _position(self, time: float) -> int:
        """Convert a time in seconds to a frame index, truncating toward zero."""
        return int(time * self.freq)

    def cut(self, begin: float | None = None, end: float | None = None) -> "Waveform":
        """Return the part of the signal between ``begin`` and ``end`` seconds.

        Args:
            begin: Start time in seconds; ``None`` means the start of the signal.
            end: End time in seconds; ``None`` means the end of the signal.
                Values past the end are clamped to the signal length.

        Returns:
            A new ``Waveform`` holding an independent copy of the selected frames.

        Raises:
            ValueError: If the selected range does not contain more frames than
                channels (for example an empty range such as ``end=0.0``).
        """
        if begin is None and end is None:
            return self.copy()
        # Use explicit ``None`` checks: ``0.0`` is a legitimate boundary and
        # must not be mistaken for "not provided".
        begin = 0.0 if begin is None else begin
        end = self.duration if end is None else end
        start = self._position(begin)
        stop = min(self._position(end), self.wave.shape[0])
        # Copy so the result never shares memory with this waveform.
        return Waveform(self.wave[start:stop].copy(), self.freq)

    def display(self) -> Any:
        """Show an audio player for this waveform via ``IPython.display``.

        Returns:
            Whatever ``IPython.display.display`` returns.
        """
        # Imported lazily so IPython is only required when this method is called.
        from IPython.display import Audio, display

        # ``Audio`` receives the samples transposed, i.e. (channels, frames).
        return display(Audio(self.wave.T, rate=self.freq))

    def to_mono(self) -> "Waveform":
        """Return a single-channel waveform by averaging all channels.

        A waveform that is already mono is returned as a copy.
        """
        if self.channels == 1:
            return self.copy()
        # The 1-D mean is reshaped to (frames, 1) by ``__post_init__``.
        return Waveform(self.wave.mean(axis=1), self.freq)

    def to_flac(self, name: str, **kwargs: Any) -> Flac:
        """Encode this waveform as a ``Flac`` state called ``name``.

        Extra keyword arguments are forwarded to the ``Flac`` constructor.
        """
        return Flac(waveform=self, name=name, **kwargs)

    def to_wav(self, name: str, **kwargs: Any) -> Wav:
        """Encode this waveform as a ``Wav`` state called ``name``.

        Extra keyword arguments are forwarded to the ``Wav`` constructor.
        """
        return Wav(waveform=self, name=name, **kwargs)

    def to_mp3(self, name: str, **kwargs: Any) -> Mp3:
        """Encode this waveform as an ``Mp3`` state called ``name``.

        Extra keyword arguments are forwarded to the ``Mp3`` constructor.
        """
        return Mp3(waveform=self, name=name, **kwargs)

    def to_ogg(self, name: str, **kwargs: Any) -> Ogg:
        """Encode this waveform as an ``Ogg`` state called ``name``.

        Extra keyword arguments are forwarded to the ``Ogg`` constructor.
        """
        return Ogg(waveform=self, name=name, **kwargs)
17 changes: 13 additions & 4 deletions src/iokit/extensions/enc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import struct
from collections.abc import Iterator
from hashlib import sha256
from typing import Any, Iterator
from typing import Any

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.base import Cipher
from cryptography.hazmat.primitives.ciphers.modes import GCM
Expand Down Expand Up @@ -41,16 +43,23 @@ def encrypt(data: bytes, password: bytes, salt: bytes) -> bytes:
padded = padder.update(data) + padder.finalize()
ct = encryptor.update(padded) + encryptor.finalize()
tag = encryptor.tag
return ct + tag
result = ct + tag
assert isinstance(result, bytes)
return result


def decrypt(data: bytes, password: bytes, salt: bytes) -> bytes:
key = _generate_key(password=password, salt=salt)
unpadder = PKCS7(128).unpadder()
decryptor = _cipher(key=key, salt=salt).decryptor()
ct, tag = data[:-16], data[-16:]
padded = decryptor.update(ct) + decryptor.finalize_with_tag(tag)
return unpadder.update(padded) + unpadder.finalize()
try:
padded = decryptor.update(ct) + decryptor.finalize_with_tag(tag)
except InvalidTag as exc:
raise ValueError("Decryption failed") from exc
result = unpadder.update(padded) + unpadder.finalize()
assert isinstance(result, bytes)
return result


def _pack_arrays(*arrays: bytes) -> bytes:
Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
]

import json
from collections.abc import Callable
from functools import lru_cache
from typing import Any, Callable
from typing import Any

from iokit.state import State

Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
"Jsonl",
]

from collections.abc import Iterable
from io import BytesIO
from typing import Any, Iterable
from typing import Any

from jsonlines import Reader, Writer

Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/tar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tarfile
from collections.abc import Iterable
from io import BytesIO
from typing import Any, Iterable
from typing import Any

from iokit.state import State
from iokit.tools.time import fromtimestamp
Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/zip.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import zipfile
from collections.abc import Iterable
from datetime import datetime
from io import BytesIO
from typing import Any, Iterable
from typing import Any

from iokit.state import State

Expand Down
16 changes: 8 additions & 8 deletions src/iokit/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
"find_state",
]

from collections.abc import Generator, Iterable
from contextlib import suppress
from datetime import datetime
from fnmatch import fnmatch
from io import BytesIO
from typing import Any, Generator, Iterable, Type
from typing import Any

from humanize import naturalsize
from typing_extensions import Self
Expand Down Expand Up @@ -86,9 +87,8 @@ def __init_subclass__(
if suffix is not None and suffixes is None:
suffixes = (suffix,)

if suffix is not None and suffixes is not None:
if suffix not in suffixes:
suffixes = (suffix, *suffixes)
if suffix is not None and suffixes is not None and suffix not in suffixes:
suffixes = (suffix, *suffixes)

if suffix is None or suffixes is None:
raise ValueError("State subclasses must define a suffix or suffixes")
Expand Down Expand Up @@ -128,7 +128,7 @@ def __repr__(self) -> str:
return f"{self.name} ({size})"

@classmethod
def _by_suffix(cls, suffix: str) -> Type[Self]:
def _by_suffix(cls, suffix: str) -> type[Self]:
if suffix in cls._suffixes:
return cls
for kls in cls.__subclasses__():
Expand All @@ -140,9 +140,9 @@ def cast(self) -> "State":
with suppress(ValueError):
klass = self._by_suffix(self.name.suffix)
state = klass.__new__(klass)
setattr(state, "_data", self.data)
setattr(state, "_name", self.name)
setattr(state, "_time", self.time)
state._data = self.data
state._name = self.name
state._time = self.time
return state
return self

Expand Down
2 changes: 1 addition & 1 deletion src/iokit/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"save_temp",
]
import tempfile
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Generator

from iokit.state import State
from iokit.tools.time import fromtimestamp
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit 2118058

Please sign in to comment.