Skip to content

Commit

Permalink
feat: Using singleton mode
Browse files Browse the repository at this point in the history
  • Loading branch information
KenyonY committed Apr 27, 2024
1 parent f810636 commit 7b096cb
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 24 deletions.
4 changes: 3 additions & 1 deletion flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@

from .core import LevelDBDict, LMDBDict, RemoteDBDict

__version__ = "0.2.7.1"
__version__ = "0.2.8"

__all__ = [
"FlaxKV",
"Flaxkv",
"dbdict",
"dictdb",
"LMDBDict",
Expand Down Expand Up @@ -79,3 +80,4 @@ def FlaxKV(

dbdict = FlaxKV
dictdb = FlaxKV
Flaxkv = FlaxKV
64 changes: 48 additions & 16 deletions flaxkv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,17 @@ def update(self, d: dict):
self._buffered_count = 0
self.write_immediately()

def from_dict(self, d: dict, clear=False):
"""
Updates the buffer with the given dictionary.
Args:
d (dict): A dictionary of key-value pairs to update.
"""
if clear:
self.clear(wait=True)
self.update(d)

# @class_measure_time()
def _write_buffer_to_db(
self,
Expand Down Expand Up @@ -831,6 +842,14 @@ class LMDBDict(BaseDBDict):
value: int, float, bool, str, list, dict, and np.ndarray,
"""

_instances = {}

def __new__(cls, db_name: str, root_path: str, rebuild=False, **kwargs):
name = db_name + str(root_path)
if name not in cls._instances:
cls._instances[name] = super().__new__(cls)
return cls._instances[name]

def __init__(
self,
db_name: str,
Expand All @@ -839,15 +858,17 @@ def __init__(
rebuild=False,
**kwargs,
):
super().__init__(
"lmdb",
root_path,
db_name,
max_dbs=1,
map_size=map_size,
rebuild=rebuild,
**kwargs,
)
if not hasattr(self, '_initialized'):
super().__init__(
"lmdb",
root_path,
db_name,
max_dbs=1,
map_size=map_size,
rebuild=rebuild,
**kwargs,
)
self._initialized = True

def _iter_db_view(self, view, include_key=True, include_value=True):
"""
Expand Down Expand Up @@ -899,14 +920,25 @@ class LevelDBDict(BaseDBDict):
value: int, float, bool, str, list, dict and np.ndarray,
"""

_instances = {}

def __new__(cls, db_name: str, root_path: str, rebuild=False, **kwargs):
name = db_name + str(root_path)
if name not in cls._instances:
cls._instances[name] = super().__new__(cls)
return cls._instances[name]

def __init__(self, db_name: str, root_path: str, rebuild=False, **kwargs):
super().__init__(
"leveldb",
root_path_or_url=root_path,
db_name=db_name,
rebuild=rebuild,
**kwargs,
)
if not hasattr(self, '_initialized'):
super().__init__(
"leveldb",
root_path_or_url=root_path,
db_name=db_name,
rebuild=rebuild,
**kwargs,
)

self._initialized = True

def _iter_db_view(self, view, include_key=True, include_value=True):
"""
Expand Down
2 changes: 1 addition & 1 deletion flaxkv/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import time
import traceback
from pathlib import Path
from typing import Dict
from typing import Dict, Literal
from uuid import uuid4

import msgspec
Expand Down
4 changes: 2 additions & 2 deletions flaxkv/serve/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ async def connect(data: AttachRequest) -> Stream:
db_name=data.db_name, backend=data.backend, rebuild=data.rebuild
)
elif data.rebuild:
db.destroy()
db.clear(wait=True)
_db_manager.set_db(
db_name=data.db_name, backend=data.backend, rebuild=False
db_name=data.db_name, backend=data.backend, rebuild=data.rebuild
)

async def stream(client: dict) -> AsyncGenerator:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_local_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def temp_db(request):
db = FlaxKV(**request.param)

yield db
db.destroy()
db.clear(wait=True)


def test_set_get_write(temp_db):
Expand Down
11 changes: 8 additions & 3 deletions tests/test_remote_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def start_server():
@pytest.fixture(
scope="function",
params=[
# dict(db_name="test_server_db", backend="lmdb", rebuild=True, cache=False),
dict(db_name="test_server_db", backend="lmdb", rebuild=True, cache=False),
dict(db_name="test_server_db", backend="leveldb", rebuild=True, cache=False),
# dict(db_name="test_server_db", backend="lmdb", rebuild=True, cache=True),
dict(db_name="test_server_db", backend="lmdb", rebuild=True, cache=True),
dict(db_name="test_server_db", backend="leveldb", rebuild=True, cache=True),
],
)
Expand All @@ -57,8 +57,13 @@ def temp_db(request):
db.close(wait=True)


from test_local_db import ( # test_large_value,; test_list_keys_values_items,; test_set_get_write,; test_setdefault,; test_update,
from test_local_db import (
test_buffered_writing,
test_key_checks_and_deletion,
test_large_value,
test_list_keys_values_items,
test_numpy_array,
test_set_get_write,
test_setdefault,
test_update,
)

0 comments on commit 7b096cb

Please sign in to comment.