Skip to content

Commit

Permalink
refactor: state buffer (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
rilshok authored Nov 29, 2024
2 parents f7e0e91 + 8e1773e commit 45b8e10
Show file tree
Hide file tree
Showing 22 changed files with 126 additions and 110 deletions.
36 changes: 21 additions & 15 deletions src/iokit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
__version__ = "0.2.0"
__all__ = [
"Dat",
"Encryption",
"decrypt",
"download_file",
"Enc",
"encrypt",
"Env",
"filter_states",
"find_state",
"Flac",
"Gzip",
"Json",
"Jsonl",
"Tar",
"Txt",
"Yaml",
"Zip",
"load_file",
"Mp3",
"Npy",
"Wav",
"Waveform",
"Flac",
"Ogg",
"State",
"filter_states",
"find_state",
"download_file",
"load_file",
"save_file",
"save_temp",
"SecretState",
"State",
"Tar",
"Txt",
"Waveform",
"Wav",
"Yaml",
"Zip",
]
__version__ = "0.2.0"

from .extensions import (
Dat,
Encryption,
Enc,
Env,
Flac,
Gzip,
Expand All @@ -36,12 +39,15 @@
Mp3,
Npy,
Ogg,
SecretState,
Tar,
Txt,
Wav,
Waveform,
Yaml,
Zip,
decrypt,
encrypt,
)
from .state import State, filter_states, find_state
from .storage import download_file, load_file, save_file, save_temp
17 changes: 10 additions & 7 deletions src/iokit/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
__all__ = [
"Flac",
"Mp3",
"Ogg",
"Wav",
"Waveform",
"Dat",
"Encryption",
"Enc",
"Env",
"Flac",
"Gzip",
"Json",
"Jsonl",
"Mp3",
"Npy",
"Ogg",
"SecretState",
"Tar",
"Txt",
"Wav",
"Waveform",
"Yaml",
"Zip",
"decrypt",
"encrypt",
]

from .audio import Flac, Mp3, Ogg, Wav, Waveform
from .dat import Dat
from .enc import Encryption
from .enc import Enc, SecretState, decrypt, encrypt
from .env import Env
from .gz import Gzip
from .json import Json
Expand Down
19 changes: 11 additions & 8 deletions src/iokit/extensions/audio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["Waveform", "Flac", "Wav", "Mp3", "Ogg"]

from dataclasses import dataclass
from io import BytesIO
from typing import Any
Expand All @@ -11,16 +13,17 @@

class AudioState(State, suffix=""):
def __init__(self, waveform: "Waveform", **kwargs: Any):
soundfile.write(
file=(target := BytesIO()),
data=waveform.wave,
samplerate=waveform.freq,
format=self._suffix,
)
super().__init__(data=target.getvalue(), **kwargs)
with BytesIO() as buffer:
soundfile.write(
file=buffer,
data=waveform.wave,
samplerate=waveform.freq,
format=self._suffix,
)
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> "Waveform":
wave, freq = soundfile.read(self.data, always_2d=True)
wave, freq = soundfile.read(self.buffer, always_2d=True)
return Waveform(wave=wave, freq=freq)


Expand Down
10 changes: 4 additions & 6 deletions src/iokit/extensions/dat.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
__all__ = [
"Dat",
]
__all__ = ["Dat"]

from typing import Any

from iokit.state import Payload, State
from iokit.state import State


class Dat(State, suffix="dat"):
def __init__(self, data: Payload, **kwargs: Any):
def __init__(self, data: bytes, **kwargs: Any):
super().__init__(data=data, **kwargs)

def load(self) -> bytes:
return self._data.getvalue()
return self.data
8 changes: 5 additions & 3 deletions src/iokit/extensions/enc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["SecretState", "Enc", "encrypt", "decrypt"]

import struct
from collections.abc import Iterator
from hashlib import sha256
Expand Down Expand Up @@ -91,12 +93,12 @@ def __repr__(self) -> str:

@classmethod
def pack(cls, state: State, password: bytes | str, salt: bytes | str = b"42") -> Self:
payload = _pack_arrays(str(state.name).encode("utf-8"), state.data.getvalue())
payload = _pack_arrays(str(state.name).encode("utf-8"), state.data)
data = encrypt(data=payload, password=_to_bytes(password), salt=_to_bytes(salt))
return cls(data=data)


class Encryption(State, suffix="enc"):
class Enc(State, suffix="enc"):
def __init__(
self,
state: State,
Expand All @@ -112,4 +114,4 @@ def __init__(
super().__init__(data=data, name=name, **kwargs)

def load(self) -> SecretState:
return SecretState(data=self.data.getvalue())
return SecretState(data=self.data)
6 changes: 4 additions & 2 deletions src/iokit/extensions/env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["Env"]

from io import StringIO
from pathlib import Path
from tempfile import TemporaryDirectory
Expand All @@ -20,5 +22,5 @@ def __init__(self, data: dict[str, str], **kwargs: Any):
super().__init__(data=data_bytes, **kwargs)

def load(self) -> dict[str, str | None]:
stream = StringIO(self.data.getvalue().decode())
return dict(dotenv.dotenv_values(stream=stream))
with StringIO(self.data.decode()) as stream:
return dict(dotenv.dotenv_values(stream=stream))
19 changes: 10 additions & 9 deletions src/iokit/extensions/gz.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["Gzip"]

import gzip
from io import BytesIO
from typing import Any
Expand All @@ -7,14 +9,13 @@

class Gzip(State, suffix="gz"):
def __init__(self, state: State, *, compression: int = 1, **kwargs: Any):
data = BytesIO()
gzip_file = gzip.GzipFile(fileobj=data, mode="wb", compresslevel=compression, mtime=0)
with gzip_file as gzip_buffer:
gzip_buffer.write(state.data.getvalue())
super().__init__(data=data, name=state.name, **kwargs)
with BytesIO() as buffer:
gzip_file = gzip.GzipFile(fileobj=buffer, mode="wb", compresslevel=compression, mtime=0)
with gzip_file as gzip_buffer:
gzip_buffer.write(state.data)
super().__init__(data=buffer.getvalue(), name=state.name, **kwargs)

def load(self) -> State:
gzip_file = gzip.GzipFile(fileobj=self.data, mode="rb")
with gzip_file as file:
data = file.read()
return State(data=data, name=str(self.name).removesuffix(".gz")).cast()
# gzip_file =
with gzip.GzipFile(fileobj=self.buffer, mode="rb") as file:
return State(data=file.read(), name=str(self.name).removesuffix(".gz")).cast()
6 changes: 2 additions & 4 deletions src/iokit/extensions/json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
__all__ = [
"Json",
]
__all__ = ["Json"]

import json
from collections.abc import Callable
Expand Down Expand Up @@ -42,4 +40,4 @@ def __init__(
super().__init__(data=data_, **kwargs)

def load(self) -> Any:
return json.load(self.data)
return json.load(self.buffer)
18 changes: 8 additions & 10 deletions src/iokit/extensions/jsonl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
__all__ = [
"Jsonl",
]
__all__ = ["Jsonl"]

from collections.abc import Iterable
from io import BytesIO
Expand All @@ -23,13 +21,13 @@ def __init__(
allow_nan: bool = False,
**kwargs: Any,
):
buffer = BytesIO()
dumps = json_dumps(compact=compact, ensure_ascii=ensure_ascii, allow_nan=allow_nan)
with Writer(buffer, compact=compact, sort_keys=False, dumps=dumps) as writer:
for item in sequence:
writer.write(item)
super().__init__(data=buffer, **kwargs)
with BytesIO() as buffer:
dumps = json_dumps(compact=compact, ensure_ascii=ensure_ascii, allow_nan=allow_nan)
with Writer(buffer, compact=compact, sort_keys=False, dumps=dumps) as writer:
for item in sequence:
writer.write(item)
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> list[Any]:
with Reader(self.data) as reader:
with Reader(self.buffer) as reader:
return list(reader)
4 changes: 3 additions & 1 deletion src/iokit/extensions/npy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["Npy"]

from io import BytesIO
from typing import Any

Expand All @@ -14,4 +16,4 @@ def __init__(self, array: NDArray[Any], **kwargs: Any) -> None:
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> NDArray[Any]:
return np.load(self.data, allow_pickle=False, fix_imports=False)
return np.load(self.buffer, allow_pickle=False, fix_imports=False)
20 changes: 11 additions & 9 deletions src/iokit/extensions/tar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["Tar"]

import tarfile
from collections.abc import Iterable
from io import BytesIO
Expand All @@ -9,19 +11,19 @@

class Tar(State, suffix="tar"):
def __init__(self, states: Iterable[State], **kwargs: Any):
buffer = BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar_buffer:
for state in states:
file_data = tarfile.TarInfo(name=str(state.name))
file_data.size = state.size
file_data.mtime = int(state.time.timestamp())
tar_buffer.addfile(fileobj=state.data, tarinfo=file_data)
with BytesIO() as buffer:
with tarfile.open(fileobj=buffer, mode="w") as tar_buffer:
for state in states:
file_data = tarfile.TarInfo(name=str(state.name))
file_data.size = state.size
file_data.mtime = int(state.time.timestamp())
tar_buffer.addfile(fileobj=state.buffer, tarinfo=file_data)

super().__init__(data=buffer, **kwargs)
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> list[State]:
states: list[State] = []
with tarfile.open(fileobj=self.data, mode="r") as tar_buffer:
with tarfile.open(fileobj=self.buffer, mode="r") as tar_buffer:
assert tar_buffer is not None
for member in tar_buffer.getmembers():
if not member.isfile():
Expand Down
6 changes: 3 additions & 3 deletions src/iokit/extensions/txt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
__all__ = ["Txt"]

from typing import Any

from iokit.state import State


class Txt(State, suffix="txt"):
def __init__(self, data: str, **kwargs: Any):
if not isinstance(data, str):
raise TypeError(f"Expected str, got {type(data).__name__}")
super().__init__(data=data.encode("utf-8"), **kwargs)

def load(self) -> str:
return self.data.getvalue().decode("utf-8")
return self.data.decode("utf-8")
6 changes: 2 additions & 4 deletions src/iokit/extensions/yaml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
__all__ = [
"Yaml",
]
__all__ = ["Yaml"]

from typing import Any

Expand All @@ -15,4 +13,4 @@ def __init__(self, data: Any, **kwargs: Any):
super().__init__(data=data, **kwargs)

def load(self) -> Any:
return yaml.safe_load(self.data)
return yaml.safe_load(self.buffer)
14 changes: 8 additions & 6 deletions src/iokit/extensions/zip.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["Zip"]

import zipfile
from collections.abc import Iterable
from datetime import datetime
Expand All @@ -9,16 +11,16 @@

class Zip(State, suffix="zip"):
def __init__(self, states: Iterable[State], **kwargs: Any):
buffer = BytesIO()
with zipfile.ZipFile(buffer, mode="w") as zip_buffer:
for state in states:
zip_buffer.writestr(str(state.name), data=state.data.getvalue())
with BytesIO() as buffer:
with zipfile.ZipFile(buffer, mode="w") as zip_buffer:
for state in states:
zip_buffer.writestr(str(state.name), data=state.data)

super().__init__(data=buffer, **kwargs)
super().__init__(data=buffer.getvalue(), **kwargs)

def load(self) -> list[State]:
states: list[State] = []
with zipfile.ZipFile(self.data, mode="r") as zip_buffer:
with zipfile.ZipFile(self.buffer, mode="r") as zip_buffer:
for file in zip_buffer.namelist():
with zip_buffer.open(file) as member_buffer:
state = State(
Expand Down
Loading

0 comments on commit 45b8e10

Please sign in to comment.