Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support audio formats #29

Merged
merged 22 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
default_stages:
- commit

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v5.0.0
hooks:
- id: check-added-large-files
args: [--maxkb=5000]
args: [--maxkb=64]
- id: trailing-whitespace
- id: check-json
- id: check-merge-conflict
- id: check-xml
- id: check-yaml
- id: detect-private-key
- id: mixed-line-ending
- id: check-json
- id: pretty-format-json
args: [--autofix]
exclude: \.ipynb$
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.10
rev: v0.7.3
hooks:
- id: ruff
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format
types_or: [ python, pyi ]


- repo: https://github.com/jendrikseipp/vulture
rev: v2.6
rev: v2.13
hooks:
- id: vulture
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include src/iokit/py.typed
17 changes: 16 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ dependencies = [
"requests>=2.32.3",
"typing-extensions>=4.8.0",
"cryptography>=41.0.7",
"numpy>=1.21.1",
"soundfile>=0.12.1",
]

[tool.setuptools.dynamic]
Expand All @@ -29,6 +31,9 @@ lint = [
"mypy>=1.7.1",
"ruff>=0.6.3",
"types-pytz>=2024.1",
"types-requests>=2.31.0",
"types-PyYAML>=6.0.12",
"types-python-dateutil>=2.8.19",
]
test = [
"pytest>=8.2.2",
Expand All @@ -49,9 +54,19 @@ line-length = 100
[tool.ruff.format]
docstring-code-format = true

[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"B", # flake8-bugbear
"SIM",# flake8-simplify
"I", # isort
]

[tool.vulture]
make_whitelist = true
sort_by_size = true
verbose = true
verbose = false
min_confidence = 100
paths = ["src/iokit"]
25 changes: 23 additions & 2 deletions src/iokit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
"Txt",
"Yaml",
"Zip",
"Mp3",
"Wav",
"Waveform",
"Flac",
"Ogg",
"State",
"filter_states",
"find_state",
Expand All @@ -17,8 +22,24 @@
"save_file",
"save_temp",
]
__version__ = "0.1.9"
__version__ = "0.2.0"

from .extensions import Dat, Encryption, Env, Gzip, Json, Jsonl, Tar, Txt, Yaml, Zip
from .extensions import (
Dat,
Encryption,
Env,
Flac,
Gzip,
Json,
Jsonl,
Mp3,
Ogg,
Tar,
Txt,
Wav,
Waveform,
Yaml,
Zip,
)
from .state import State, filter_states, find_state
from .storage import download_file, load_file, save_file, save_temp
6 changes: 6 additions & 0 deletions src/iokit/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
__all__ = [
"Flac",
"Mp3",
"Ogg",
"Wav",
"Waveform",
"Dat",
"Encryption",
"Env",
Expand All @@ -11,6 +16,7 @@
"Zip",
]

from .audio import Flac, Mp3, Ogg, Wav, Waveform
from .dat import Dat
from .enc import Encryption
from .env import Env
Expand Down
106 changes: 106 additions & 0 deletions src/iokit/extensions/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from dataclasses import dataclass
from io import BytesIO
from typing import Any

import soundfile
from numpy import float32
from numpy.typing import NDArray

from iokit.state import State


class AudioState(State, suffix=""):
def __init__(self, waveform: "Waveform", **kwargs: Any):
soundfile.write(
file=(target := BytesIO()),
data=waveform.wave,
samplerate=waveform.freq,
format=self._suffix,
)
super().__init__(data=target.getvalue(), **kwargs)

def load(self) -> "Waveform":
wave, freq = soundfile.read(self.data, always_2d=True)
return Waveform(wave=wave, freq=freq)


class Flac(AudioState, suffix="flac"):
pass


class Wav(AudioState, suffix="wav"):
pass


class Mp3(AudioState, suffix="mp3"):
pass


class Ogg(AudioState, suffix="ogg"):
pass


@dataclass
class Waveform:
wave: NDArray[float32]
freq: int

def __post_init__(self) -> None:
if self.wave.ndim == 1:
self.wave = self.wave[:, None]
if self.wave.ndim != 2:
msg = f"Waveform must be 1D or 2D array, but got {self.wave.ndim}D"
raise ValueError(msg)
if self.wave.shape[1] >= self.wave.shape[0]:
msg = (
"Waveform must have more frames than channels,"
f" but got {self.wave.shape[0]} frames and {self.wave.shape[1]} channels."
)
raise ValueError(msg)
if self.wave.dtype is not float32:
self.wave = self.wave.astype(float32)

@property
def channels(self) -> int:
return self.wave.shape[1]

@property
def duration(self) -> float:
return self.wave.shape[0] / self.freq

def copy(self) -> "Waveform":
return Waveform(self.wave.copy(), self.freq)

def _position(self, time: float) -> int:
return int(time * self.freq)

def cut(self, begin: float | None = None, end: float | None = None) -> "Waveform":
if begin is None and end is None:
return self.copy()
begin, end = begin or 0.0, end or self.duration
start, stop = self._position(begin), self._position(end)
if stop > self.wave.shape[0]:
stop = self.wave.shape[0]
return Waveform(self.wave[start:stop], self.freq)

def display(self):
from IPython.display import Audio, display

return display(Audio(self.wave.T, rate=self.freq))

def to_mono(self) -> "Waveform":
if self.channels == 1:
return self.copy()
return Waveform(self.wave.mean(axis=1), self.freq)

def to_flac(self, name: str, **kwargs: Any) -> Flac:
return Flac(waveform=self, name=name, **kwargs)

def to_wav(self, name: str, **kwargs: Any) -> Wav:
return Wav(waveform=self, name=name, **kwargs)

def to_mp3(self, name: str, **kwargs: Any) -> Mp3:
return Mp3(waveform=self, name=name, **kwargs)

def to_ogg(self, name: str, **kwargs: Any) -> Ogg:
return Ogg(waveform=self, name=name, **kwargs)
17 changes: 13 additions & 4 deletions src/iokit/extensions/enc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import struct
from collections.abc import Iterator
from hashlib import sha256
from typing import Any, Iterator
from typing import Any

from cryptography.exceptions import InvalidTag
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.base import Cipher
from cryptography.hazmat.primitives.ciphers.modes import GCM
Expand Down Expand Up @@ -41,16 +43,23 @@ def encrypt(data: bytes, password: bytes, salt: bytes) -> bytes:
padded = padder.update(data) + padder.finalize()
ct = encryptor.update(padded) + encryptor.finalize()
tag = encryptor.tag
return ct + tag
result = ct + tag
assert isinstance(result, bytes)
return result


def decrypt(data: bytes, password: bytes, salt: bytes) -> bytes:
key = _generate_key(password=password, salt=salt)
unpadder = PKCS7(128).unpadder()
decryptor = _cipher(key=key, salt=salt).decryptor()
ct, tag = data[:-16], data[-16:]
padded = decryptor.update(ct) + decryptor.finalize_with_tag(tag)
return unpadder.update(padded) + unpadder.finalize()
try:
padded = decryptor.update(ct) + decryptor.finalize_with_tag(tag)
except InvalidTag as exc:
raise ValueError("Decryption failed") from exc
result = unpadder.update(padded) + unpadder.finalize()
assert isinstance(result, bytes)
return result


def _pack_arrays(*arrays: bytes) -> bytes:
Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
]

import json
from collections.abc import Callable
from functools import lru_cache
from typing import Any, Callable
from typing import Any

from iokit.state import State

Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
"Jsonl",
]

from collections.abc import Iterable
from io import BytesIO
from typing import Any, Iterable
from typing import Any

from jsonlines import Reader, Writer

Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/tar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tarfile
from collections.abc import Iterable
from io import BytesIO
from typing import Any, Iterable
from typing import Any

from iokit.state import State
from iokit.tools.time import fromtimestamp
Expand Down
3 changes: 2 additions & 1 deletion src/iokit/extensions/zip.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import zipfile
from collections.abc import Iterable
from datetime import datetime
from io import BytesIO
from typing import Any, Iterable
from typing import Any

from iokit.state import State

Expand Down
16 changes: 8 additions & 8 deletions src/iokit/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
"find_state",
]

from collections.abc import Generator, Iterable
from contextlib import suppress
from datetime import datetime
from fnmatch import fnmatch
from io import BytesIO
from typing import Any, Generator, Iterable, Type
from typing import Any

from humanize import naturalsize
from typing_extensions import Self
Expand Down Expand Up @@ -86,9 +87,8 @@ def __init_subclass__(
if suffix is not None and suffixes is None:
suffixes = (suffix,)

if suffix is not None and suffixes is not None:
if suffix not in suffixes:
suffixes = (suffix, *suffixes)
if suffix is not None and suffixes is not None and suffix not in suffixes:
suffixes = (suffix, *suffixes)

if suffix is None or suffixes is None:
raise ValueError("State subclasses must define a suffix or suffixes")
Expand Down Expand Up @@ -128,7 +128,7 @@ def __repr__(self) -> str:
return f"{self.name} ({size})"

@classmethod
def _by_suffix(cls, suffix: str) -> Type[Self]:
def _by_suffix(cls, suffix: str) -> type[Self]:
if suffix in cls._suffixes:
return cls
for kls in cls.__subclasses__():
Expand All @@ -140,9 +140,9 @@ def cast(self) -> "State":
with suppress(ValueError):
klass = self._by_suffix(self.name.suffix)
state = klass.__new__(klass)
setattr(state, "_data", self.data)
setattr(state, "_name", self.name)
setattr(state, "_time", self.time)
state._data = self.data
state._name = self.name
state._time = self.time
return state
return self

Expand Down
2 changes: 1 addition & 1 deletion src/iokit/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"save_temp",
]
import tempfile
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Generator

from iokit.state import State
from iokit.tools.time import fromtimestamp
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading