Skip to content

Commit

Permalink
switch to psycopg 3
Browse files Browse the repository at this point in the history
  • Loading branch information
kheina committed Oct 4, 2024
1 parent 8c68b2e commit c0bea54
Show file tree
Hide file tree
Showing 27 changed files with 203 additions and 468 deletions.
2 changes: 1 addition & 1 deletion account/account.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from re import IGNORECASE
from re import compile as re_compile

from psycopg2.errors import UniqueViolation
from psycopg.errors import UniqueViolation

from authenticator.authenticator import Authenticator
from authenticator.models import LoginResponse, TokenResponse
Expand Down
2 changes: 1 addition & 1 deletion account/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

@app.on_event('shutdown')
async def shutdown() :
account.close()
await account.close()


@app.post('/login', response_model=LoginResponse)
Expand Down
24 changes: 9 additions & 15 deletions authenticator/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from cache import AsyncLRU
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from psycopg2.errors import UniqueViolation
from psycopg.errors import UniqueViolation

from shared import logging
from shared.base64 import b64decode, b64encode
Expand Down Expand Up @@ -268,12 +268,6 @@ async def generate_token(self, user_id: int, token_data: dict) -> TokenResponse
)


async def logout(self, guid: RefId) :
# since this endpoint is behind user.authenticated, we already know that the
# token exists and all the information is correct. we just need to delete it.
await Authenticator.KVS.remove_async(guid)


async def fetchPublicKey(self, key_id, algorithm: Optional[AuthAlgorithm] = None) -> PublicKeyResponse :
algo = algorithm.name if algorithm else self._token_algorithm
lookup_key = (algo, key_id)
Expand All @@ -284,7 +278,7 @@ async def fetchPublicKey(self, key_id, algorithm: Optional[AuthAlgorithm] = None
public_key = self._public_keyring[lookup_key]

else :
data: tuple[memoryview, memoryview, datetime, datetime] = await self.query_async("""
data: tuple[bytes, bytes, datetime, datetime] = await self.query_async("""
SELECT public_key, signature, issued, expires
FROM kheina.auth.token_keys
WHERE algorithm = %s AND key_id = %s;
Expand Down Expand Up @@ -336,7 +330,7 @@ async def login(self, email: str, password: str, token_data:Dict[str, Any]={ })
try :
email_dict: Dict[str, str] = self._validateEmail(email)
email_hash = self._hash_email(email)
data: tuple[int, memoryview, int, str, str, bool] = await self.query_async("""
data: Optional[tuple[int, bytes, int, str, str, bool]] = await self.query_async("""
SELECT
user_login.user_id,
user_login.password,
Expand All @@ -357,7 +351,7 @@ async def login(self, email: str, password: str, token_data:Dict[str, Any]={ })
raise Unauthorized('login failed.')

user_id, pwhash, secret, handle, name, mod = data
password_hash = pwhash.tobytes().decode()
password_hash = pwhash.decode()

if not self._argon2.verify(password_hash, password.encode() + self._secrets[secret]) :
raise Unauthorized('login failed.')
Expand Down Expand Up @@ -402,7 +396,7 @@ async def login(self, email: str, password: str, token_data:Dict[str, Any]={ })


async def createBot(self, user: KhUser, bot_type: BotType) -> BotCreateResponse :
if type(bot_type) != BotType :
if type(bot_type) is not BotType :
# this should never run, thanks to pydantic/fastapi. just being extra careful.
raise BadRequest('bot_type must be a BotType value.')

Expand Down Expand Up @@ -464,7 +458,7 @@ async def botLogin(self, token: str) -> LoginResponse :
bot_type: BotType

try :
data: Tuple[int, memoryview, int, int] = await self.query_async("""
data: Tuple[int, bytes, int, int] = await self.query_async("""
SELECT
bot_login.user_id,
bot_login.password,
Expand All @@ -482,7 +476,7 @@ async def botLogin(self, token: str) -> LoginResponse :

bot_type_id: int
user_id, pw, secret, bot_type_id = data
password_hash = pw.tobytes().decode()
password_hash = pw.decode()
bot_type = await bot_type_map.get(bot_type_id)

if user_id != bot_login.user_id :
Expand Down Expand Up @@ -548,7 +542,7 @@ async def changePassword(self, email: str, old_password: str, new_password: str)
try :

email_hash = self._hash_email(email)
data: tuple[memoryview, int] = await self.query_async("""
data: tuple[bytes, int] = await self.query_async("""
SELECT password, secret
FROM kheina.auth.user_login
INNER JOIN kheina.public.users
Expand All @@ -563,7 +557,7 @@ async def changePassword(self, email: str, old_password: str, new_password: str)
raise Unauthorized('password change failed.')

pwhash, secret = data
password_hash = pwhash.tobytes()
password_hash = pwhash

if not self._argon2.verify(password_hash.decode(), old_password.encode() + self._secrets[secret]) :
raise Unauthorized('password change failed.')
Expand Down
4 changes: 2 additions & 2 deletions avro_schema_repository/schema_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def getSchema(self, fingerprint: bytes) -> bytes :
"""
fp: int = int_from_bytes(fingerprint)

data: List[memoryview] = await self.query_async("""
data: List[bytes] = await self.query_async("""
SELECT schema
FROM kheina.public.avro_schemas
WHERE fingerprint = %s;
Expand All @@ -46,7 +46,7 @@ async def getSchema(self, fingerprint: bytes) -> bytes :
if not data :
raise NotFound('no data was found for the provided schema fingerprint.')

return data[0].tobytes()
return data[0]


@HttpErrorHandler('saving schema')
Expand Down
2 changes: 1 addition & 1 deletion configs/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def getFunding(self) -> int :
@HttpErrorHandler('retrieving config')
@AerospikeCache('kheina', 'configs', '{config}', _kvs=KVS)
async def getConfig[T: BaseModel](self, config: ConfigType, _: Type[T]) -> T :
data: Optional[tuple[memoryview]] = await self.query_async("""
data: Optional[tuple[bytes]] = await self.query_async("""
SELECT bytes
FROM kheina.public.configs
WHERE key = %s;
Expand Down
2 changes: 1 addition & 1 deletion configs/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def startup() :

@app.on_event('shutdown')
async def shutdown() :
configs.close()
await configs.close()


################################################## INTERNAL ##################################################
Expand Down
2 changes: 1 addition & 1 deletion deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ docker push us-central1-docker.pkg.dev/kheinacom/fuzzly-repo/fuzzly-backend:$(gi
connect to gke from kubectl
https://cloud.google.com/kubernetes-engine/docs/deploy-app-cluster#get_authentication_credentials_for_the_cluster
```sh
gcloud container clusters get-credentials fuzzly-backend \
gcloud container clusters get-credentials fuzzly-backend \
--location us-central1
```

Expand Down
2 changes: 1 addition & 1 deletion emojis/emoji.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Self

from psycopg2.errors import UniqueViolation
from psycopg.errors import UniqueViolation

from shared.auth import KhUser
from shared.exceptions.http_error import BadRequest, Conflict, Forbidden, HttpErrorHandler, NotFound
Expand Down
69 changes: 34 additions & 35 deletions init.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def nukeCache() -> None :
'--file',
default='',
)
def execSql(unlock: bool = False, file: str = '') -> None :
async def execSql(unlock: bool = False, file: str = '') -> None :
"""
connects to the database and runs all files stored under the db folder
folders under db are sorted numberically and run in descending order
Expand All @@ -102,50 +102,49 @@ def execSql(unlock: bool = False, file: str = '') -> None :
nukeCache()

sql = SqlInterface()
with sql.pool.conn() as conn :
cur = conn.cursor()
async with sql.pool.connection() as conn :
async with conn.cursor() as cur :
sqllock = None
if not unlock and isfile('sql.lock') :
sqllock = int(open('sql.lock').read().strip())
click.echo(f'==> sql.lock: {sqllock}')

if file :
if not isfile(file) :
return

sqllock = None
if not unlock and isfile('sql.lock') :
sqllock = int(open('sql.lock').read().strip())
click.echo(f'==> sql.lock: {sqllock}')
if not file.endswith('.sql') :
return

if file :
if not isfile(file) :
return
with open(file) as f :
click.echo(f'==> exec: {file}')
await cur.execute(f.read()) # type: ignore

if not file.endswith('.sql') :
await conn.commit()
return

with open(file) as f :
click.echo(f'==> exec: {file}')
cur.execute(f.read())

conn.commit()
return

dirs = sorted(int(i) for i in listdir('db') if isdir(f'db/{i}') and i == str(isint(i)))
dir = ""
for dir in dirs :
if sqllock and sqllock >= dir :
continue

files = [join('db', str(dir), file) for file in sorted(listdir(join('db', str(dir))))]
for file in files :
if not isfile(file) :
dirs = sorted(int(i) for i in listdir('db') if isdir(f'db/{i}') and i == str(isint(i)))
dir = ""
for dir in dirs :
if sqllock and sqllock >= dir :
continue

if not file.endswith('.sql') :
continue
files = [join('db', str(dir), file) for file in sorted(listdir(join('db', str(dir))))]
for file in files :
if not isfile(file) :
continue

with open(file) as f :
click.echo(f'==> exec: {file}')
cur.execute(f.read())
if not file.endswith('.sql') :
continue

with open(file) as f :
click.echo(f'==> exec: {file}')
await cur.execute(f.read()) # type: ignore

conn.commit()
await conn.commit()

with open('sql.lock', 'w') as f :
f.write(str(dir))
with open('sql.lock', 'w') as f :
f.write(str(dir))


EmojiFontURL = r'https://github.com/PoomSmart/EmojiFonts/releases/download/15.1.0/AppleColorEmoji-HD.ttc'
Expand Down
4 changes: 4 additions & 0 deletions k8s.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
- name: pod_name
valueFrom:
fieldRef:
fieldPath: metadata.name
- name: pod_host
value: "*.fuzz.ly"
- name: kh_aes
Expand Down
6 changes: 0 additions & 6 deletions posts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ class TagGroups(Dict[TagGroupPortable, List[str]]) :

def _thumbhash_converter(value: Any) -> Optional[str] :
if value :
if isinstance(value, memoryview) :
value = bytes(value)

if isinstance(value, bytes) :
return b64encode(value).decode()

Expand Down Expand Up @@ -137,9 +134,6 @@ class SearchResults(BaseModel) :

def _bytes_converter(value: Any) -> Optional[bytes] :
if value :
if isinstance(value, memoryview) :
value = bytes(value)

if isinstance(value, bytes) :
return value

Expand Down
2 changes: 1 addition & 1 deletion posts/posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ async def fetchDrafts(self, user: KhUser) -> List[Post] :
Value(await privacy_map.get(Privacy.draft)),
),
).order(
Field('posts', 'created'),
Field('posts', 'updated'),
Order.descending_nulls_first,
)

Expand Down
2 changes: 1 addition & 1 deletion posts/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ async def _vote(self: Self, user: KhUser, post_id: PostId, upvote: Optional[bool
),
)

transaction.commit()
await transaction.commit()

score: InternalScore = InternalScore(
up = up,
Expand Down
8 changes: 4 additions & 4 deletions reporting/mod_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) ->
await kvs.put_async(f'ban={iban.user_id}', iban)
await kvs.put_async(f'user_id={user_id}', iaction)

t.commit()
await t.commit()

await reporting_kvs.put_async(str(ireport.report_id), ireport)
await kvs.put_async(f'report_id={iaction.report_id}', iaction)
Expand All @@ -288,7 +288,7 @@ async def create(self: Self, user: KhUser, response: str, action: ModAction) ->

@AerospikeCache('kheina', 'actions', 'report_id={report_id}', _kvs=kvs)
async def _read(self: Self, report_id: int) -> Optional[InternalModAction] :
data: Optional[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, memoryview]] = await self.query_async(Query(InternalModAction.__table_name__).select(
data: Optional[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async(Query(InternalModAction.__table_name__).select(
Field('mod_actions', 'action_id'),
Field('mod_actions', 'report_id'),
Field('mod_actions', 'post_id'),
Expand Down Expand Up @@ -417,7 +417,7 @@ async def bans(self: Self, user: KhUser, handle: str) -> list[Ban] :

@AerospikeCache('kheina', 'actions', 'active_action={post_id}', _kvs=kvs)
async def _active_action(self: Self, post_id: PostId) -> Optional[InternalModAction] :
data: Optional[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, memoryview]] = await self.query_async(Query(InternalModAction.__table_name__).select(
data: Optional[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async(Query(InternalModAction.__table_name__).select(
Field('mod_actions', 'action_id'),
Field('mod_actions', 'report_id'),
Field('mod_actions', 'post_id'),
Expand Down Expand Up @@ -455,7 +455,7 @@ async def _active_action(self: Self, post_id: PostId) -> Optional[InternalModAct

@AerospikeCache('kheina', 'actions', 'active_action={post_id}', _kvs=kvs)
async def _actions(self: Self, post_id: PostId) -> list[InternalModAction] :
data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, memoryview]] = await self.query_async(Query(InternalModAction.__table_name__).select(
data: list[tuple[int, int, Optional[int], Optional[int], Optional[int], datetime, Optional[datetime], str, int, bytes]] = await self.query_async(Query(InternalModAction.__table_name__).select(
Field('mod_actions', 'action_id'),
Field('mod_actions', 'report_id'),
Field('mod_actions', 'post_id'),
Expand Down
2 changes: 1 addition & 1 deletion reporting/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ async def close_response(self: Self, user: KhUser, queue_id: int, response: str)
ireport.assignee = user.user_id
ireport = await t.update(ireport)

t.commit()
await t.commit()

await kvs.put_async(str(ireport.report_id), ireport)
return await self.report(user, ireport)
Expand Down
4 changes: 4 additions & 0 deletions reporting/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,7 @@ async def v1Bans(req: Request, handle: str) -> list[Ban] :
app.include_router(actionsRouter)
app.include_router(queueRouter)
app.include_router(bansRouter)

@app.on_event('shutdown')
async def shutdown() :
await reporting.close()
4 changes: 3 additions & 1 deletion requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ multidict==6.0.5
numpy==2.0.0
packaging==24.1
patreon==0.5.0
psycopg2-binary==2.9.9
psycopg==3.2.3
psycopg-binary==3.2.3
psycopg-pool==3.2.3
pycparser==2.22
pycryptodome==3.20.0
pydantic==1.10.17
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pydantic_core~=2.18.4
python-multipart~=0.0.9
PyExifTool~=0.5.6
Wand~=0.6.13
psycopg2-binary~=2.9.3
psycopg[binary, pool]~=3.2.3
scipy~=1.14.0
ujson>=3.0.0
uvicorn>=0.11.6
Expand Down
Loading

0 comments on commit c0bea54

Please sign in to comment.