Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

257 user model refactor the config handling #433

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 14 additions & 38 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from contextlib import asynccontextmanager
import os

from enum import Enum, auto
import uvicorn # type: ignore
from pathlib import Path
from fastapi import (
Expand All @@ -26,7 +25,7 @@

from .config import app_config
from .util import mquery_version
from .db import Database
from .db import Database, UserRole
from .lib.yaraparse import parse_yara
from .plugins import PluginManager
from .lib.ursadb import UrsaDb
Expand Down Expand Up @@ -71,24 +70,6 @@ def with_plugins() -> Iterable[PluginManager]:
plugins.cleanup()


# See docs/users.md for documentation on the permission model.
# Enum values are meaningless and may change. Make sure to not store them
# anywhere (for storing/transfer use role names instead).
class UserRole(Enum):
# "role groups", used to grant a collection of "action roles"
nobody = auto() # no permissions granted
user = auto() # can run yara queries and read the state
admin = auto() # can manage the system (and do everything else)

# "action roles", used to give permission to a specific thing
can_manage_all_queries = auto()
can_manage_queries = auto()
can_list_all_queries = auto()
can_list_queries = auto()
can_view_queries = auto()
can_download_files = auto()


class User:
def __init__(self, token: Optional[Dict]) -> None:
self.__token = token
Expand All @@ -114,8 +95,8 @@ def roles(self, client_id: Optional[str]) -> List[UserRole]:


async def current_user(authorization: Optional[str] = Header(None)) -> User:
auth_enabled = db.get_mquery_config_key("auth_enabled")
if not auth_enabled or auth_enabled == "false":
auth_enabled = db.config.auth_enabled
if not auth_enabled:
return User(None)

if not authorization:
Expand All @@ -134,7 +115,7 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User:

_bearer, token = token_parts

secret = db.get_mquery_config_key("openid_secret")
secret = db.config.openid_secret
if secret is None:
raise RuntimeError("Invalid configuration - missing_openid_secret.")

Expand Down Expand Up @@ -169,9 +150,9 @@ def __init__(self, need_permissions: List[UserRole]) -> None:
self.need_permissions = need_permissions

def __call__(self, user: User = Depends(current_user)):
auth_enabled = db.get_mquery_config_key("auth_enabled")
if not auth_enabled or auth_enabled == "false":
return
auth_enabled = db.config.auth_enabled
if not auth_enabled:
return User(None)

all_roles = get_user_roles(user)
if not any(role in self.need_permissions for role in all_roles):
Expand All @@ -198,15 +179,10 @@ def __call__(self, user: User = Depends(current_user)):
def get_user_roles(user: User) -> List[UserRole]:
"""Get all roles assigned to user, taking into account the
system configuration (like default configured roles)"""
client_id = db.get_mquery_config_key("openid_client_id")
client_id = db.config.openid_client_id
user_roles = user.roles(client_id)
auth_default_roles = db.get_mquery_config_key("auth_default_roles")
if not auth_default_roles:
auth_default_roles = "admin"
default_roles = [
UserRole[role.strip()] for role in auth_default_roles.split(",")
]
all_roles = set(user_roles + default_roles)
auth_default_roles = db.config.auth_default_roles
all_roles = set(user_roles + auth_default_roles)
return sum((expand_role(role) for role in all_roles), [])


Expand Down Expand Up @@ -455,7 +431,7 @@ def query(
]

degenerate_rules = [r.name for r in rules if r.parse().is_degenerate]
allow_slow = db.get_mquery_config_key("query_allow_slow") == "true"
allow_slow = db.config.query_allow_slow
if degenerate_rules and not (allow_slow and data.force_slow_queries):
if allow_slow:
# Warning: "You can force a slow query" literal is used to
Expand Down Expand Up @@ -601,9 +577,9 @@ def query_remove(
def server() -> ServerSchema:
return ServerSchema(
version=mquery_version(),
auth_enabled=db.get_mquery_config_key("auth_enabled"),
openid_url=db.get_mquery_config_key("openid_url"),
openid_client_id=db.get_mquery_config_key("openid_client_id"),
auth_enabled=str(db.config.auth_enabled).lower(),
openid_url=db.config.openid_url,
openid_client_id=db.config.openid_client_id,
about=app_config.mquery.about,
)

Expand Down
60 changes: 59 additions & 1 deletion src/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random
import string
from redis import StrictRedis
from enum import Enum
from enum import Enum, auto
from rq import Queue # type: ignore
from sqlmodel import (
Session,
Expand Down Expand Up @@ -40,10 +40,64 @@ class TaskType(Enum):
COMMAND = "command"


# See docs/users.md for documentation on the permission model.
# Enum values are meaningless and may change. Make sure to not store them
# anywhere (for storing/transfer use role names instead).
class UserRole(Enum):
# "role groups", used to grant a collection of "action roles"
nobody = auto() # no permissions granted
user = auto() # can run yara queries and read the state
admin = auto() # can manage the system (and do everything else)

# "action roles", used to give permission to a specific thing
can_manage_all_queries = auto()
can_manage_queries = auto()
can_list_all_queries = auto()
can_list_queries = auto()
can_view_queries = auto()
can_download_files = auto()


# Type alias for Job ids
JobId = str


class UserModelConfig:
def __init__(self, db_instance):
self.db = db_instance

@property
def auth_default_roles(self) -> List[UserRole]:
auth_default_roles = self.db.get_mquery_config_key(
"auth_default_roles"
)
if auth_default_roles is None:
auth_default_roles = "admin"
return [
UserRole[role.strip()] for role in auth_default_roles.split(",")
]

@property
def openid_client_id(self) -> str | None:
return self.db.get_mquery_config_key("openid_client_id")

@property
def query_allow_slow(self) -> bool:
return self.db.get_mquery_config_key("query_allow_slow") == "true"

@property
def auth_enabled(self) -> bool:
return self.db.get_mquery_config_key("auth_enabled") == "true"

@property
def openid_url(self) -> str | None:
return self.db.get_mquery_config_key("openid_url")

@property
def openid_secret(self) -> str | None:
return self.db.get_mquery_config_key("openid_secret")


class Database:
def __init__(self, redis_host: str, redis_port: int) -> None:
self.redis: Any = StrictRedis(
Expand All @@ -57,6 +111,10 @@ def __schedule(self, agent: str, task: Any, *args: Any) -> None:
task, *args, job_timeout=app_config.rq.job_timeout
)

@property
def config(self):
return UserModelConfig(self)

@contextmanager
def session(self):
with Session(self.engine) as session:
Expand Down
Loading