Skip to content

Commit

Permalink
🐛 Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KenyonY committed Dec 5, 2023
1 parent 6d250bb commit ba4698c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
11 changes: 11 additions & 0 deletions flaxkv/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from rich import print
from rich.text import Text

from .pack import encode

ENABLED_MEASURE_TIME_DECORATOR = True


Expand Down Expand Up @@ -35,3 +37,12 @@ def wrapper(self, *args, **kwargs):
return wrapper

return decorate


def msg_encoder(func):
@wraps(func)
async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs)
return encode(result)

return wrapper
24 changes: 13 additions & 11 deletions flaxkv/serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import msgspec
from litestar import Litestar, MediaType, Request, get, post

from ..decorators import msg_encoder
from ..pack import decode, decode_key, encode
from .interface import (
AttachRequest,
Expand Down Expand Up @@ -40,8 +41,6 @@ async def set_value(data: SetRequest) -> dict:
db = db_manager.get(data.db_name)
if db is None:
return {"success": False, "info": "db not found"}
print(data.key, data.value)
print(encode(data.key), encode(data.value))
db[encode(data.key)] = encode(data.value)
return {"success": True}

Expand Down Expand Up @@ -79,7 +78,7 @@ async def update_raw(db_name: str, request: Request) -> dict:
async def get_raw(db_name: str, request: Request) -> bytes:
db = db_manager.get(db_name)
if db is None:
return encode({"success": False, "info": "db not found"})
raise ValueError("db not found")
key = await request.body()
value = db.get(key)
if value is None:
Expand All @@ -98,12 +97,13 @@ async def contains(db_name: str, request: Request) -> bytes:


@post("/pop")
@msg_encoder
async def pop(data: PopKeyRequest) -> dict:
db = db_manager.get(data.db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return {"success": True, "value": db.pop(encode(data.key), None)}
return {"success": True, "data": db.pop(encode(data.key), None)}

except Exception as e:
traceback.print_exc()
Expand All @@ -116,31 +116,33 @@ async def get_keys(db_name: str) -> dict:
if db is None:
return {"success": False, "info": "db not found"}
try:
return {"keys": db.keys()}
return {"success": True, "data": db.keys()}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/values")
async def get_values(db_name: str) -> dict:
@get("/values", media_type=MediaType.TEXT)
@msg_encoder
async def get_values(db_name: str) -> bytes:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return {"values": db.values()}
return {"success": True, "data": db.values()}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/items")
async def get_items(db_name: str) -> dict:
@get("/items", media_type=MediaType.TEXT)
@msg_encoder
async def get_items(db_name: str) -> bytes:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return dict(db.items())
return {"success": True, "data": dict(db.items())}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}
Expand Down
13 changes: 8 additions & 5 deletions flaxkv/serve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def pop(self, key, default=None):
url = f"{self._url}/pop"
data = {"key": key, "db_name": self._db_name}
response = self._client.post(url, json=data)
result = response.json()
result = decode(response.read())
if result["success"]:
value = result["value"]
value = result["data"]
if value is None:
return default
return value
Expand All @@ -59,7 +59,7 @@ def pop(self, key, default=None):
def _items_dict(self):
url = f"{self._url}/items?db_name={self._db_name}"
response = self._client.get(url)
return response.json()
return decode(response.read())

def items(self):
return self._items_dict().items()
Expand All @@ -70,12 +70,12 @@ def __repr__(self):
def keys(self):
url = f"{self._url}/keys?db_name={self._db_name}"
response = self._client.get(url)
return response.json()["keys"]
return response.json()['data']

def values(self):
url = f"{self._url}/values?db_name={self._db_name}"
response = self._client.get(url)
return response.json()["values"]
return decode(response.read())['data']

def __contains__(self, key):
url = f"{self._url}/contains?db_name={self._db_name}"
Expand All @@ -97,3 +97,6 @@ def __getitem__(self, key):
if value is None:
raise KeyError(f"Key `{key}` not found in the database.")
return value

def __len__(self):
return len(self.keys())

0 comments on commit ba4698c

Please sign in to comment.